1919from sagemaker .fw_utils import (
2020 framework_name_from_image ,
2121 framework_version_from_tag ,
22- empty_framework_version_warning ,
23- python_deprecation_warning ,
2422 is_version_equal_or_higher ,
23+ python_deprecation_warning ,
24+ validate_version_or_image_args ,
2525)
2626from sagemaker .pytorch import defaults
2727from sagemaker .pytorch .model import PyTorchModel
@@ -40,10 +40,10 @@ class PyTorch(Framework):
4040 def __init__ (
4141 self ,
4242 entry_point ,
43+ framework_version = None ,
44+ py_version = None ,
4345 source_dir = None ,
4446 hyperparameters = None ,
45- py_version = defaults .PYTHON_VERSION ,
46- framework_version = None ,
4747 image_name = None ,
4848 ** kwargs
4949 ):
@@ -69,6 +69,13 @@ def __init__(
6969 file which should be executed as the entry point to training.
7070 If ``source_dir`` is specified, then ``entry_point``
7171 must point to a file located at the root of ``source_dir``.
72+ framework_version (str): PyTorch version you want to use for
73+ executing your model training code. Defaults to ``None``. Required unless
74+ ``image_name`` is provided. List of supported versions:
75+ https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
76+ py_version (str): Python version you want to use for executing your
77+ model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
78+ unless ``image_name`` is provided.
7279 source_dir (str): Path (absolute, relative or an S3 URI) to a directory
7380 with any other training source code dependencies aside from the entry
7481 point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -80,12 +87,6 @@ def __init__(
8087 SageMaker. For convenience, this accepts other types for keys
8188 and values, but ``str()`` will be called to convert them before
8289 training.
83- py_version (str): Python version you want to use for executing your
84- model training code (default: 'py3'). One of 'py2' or 'py3'.
85- framework_version (str): PyTorch version you want to use for
86- executing your model training code. List of supported versions
87- https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
88- If not specified, this will default to 0.4.
8990 image_name (str): If specified, the estimator will use this image
9091 for training and hosting, instead of selecting the appropriate
9192 SageMaker official image based on framework_version and
@@ -95,6 +96,9 @@ def __init__(
9596 * ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
9697 * ``custom-image:latest``
9798
99+ If ``framework_version`` or ``py_version`` are ``None``, then
100+ ``image_name`` is required. If also ``None``, then a ``ValueError``
101+ will be raised.
98102 **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
99103 constructor.
100104
@@ -104,28 +108,25 @@ def __init__(
104108 :class:`~sagemaker.estimator.Framework` and
105109 :class:`~sagemaker.estimator.EstimatorBase`.
106110 """
107- if framework_version is None :
111+ validate_version_or_image_args (framework_version , py_version , image_name )
112+ if py_version == "py2" :
108113 logger .warning (
109- empty_framework_version_warning ( defaults . PYTORCH_VERSION , self . LATEST_VERSION )
114+ python_deprecation_warning ( self . __framework_name__ , defaults . LATEST_PY2_VERSION )
110115 )
111- self .framework_version = framework_version or defaults .PYTORCH_VERSION
116+ self .framework_version = framework_version
117+ self .py_version = py_version
112118
113119 if "enable_sagemaker_metrics" not in kwargs :
114120 # enable sagemaker metrics for PT v1.3 or greater:
115- if is_version_equal_or_higher ([1 , 3 ], self .framework_version ):
121+ if self .framework_version and is_version_equal_or_higher (
122+ [1 , 3 ], self .framework_version
123+ ):
116124 kwargs ["enable_sagemaker_metrics" ] = True
117125
118126 super (PyTorch , self ).__init__ (
119127 entry_point , source_dir , hyperparameters , image_name = image_name , ** kwargs
120128 )
121129
122- if py_version == "py2" :
123- logger .warning (
124- python_deprecation_warning (self .__framework_name__ , defaults .LATEST_PY2_VERSION )
125- )
126-
127- self .py_version = py_version
128-
129130 def create_model (
130131 self ,
131132 model_server_workers = None ,
@@ -177,12 +178,12 @@ def create_model(
177178 self .model_data ,
178179 role or self .role ,
179180 entry_point or self .entry_point ,
181+ framework_version = self .framework_version ,
182+ py_version = self .py_version ,
180183 source_dir = (source_dir or self ._model_source_dir ()),
181184 enable_cloudwatch_metrics = self .enable_cloudwatch_metrics ,
182185 container_log_level = self .container_log_level ,
183186 code_location = self .code_location ,
184- py_version = self .py_version ,
185- framework_version = self .framework_version ,
186187 model_server_workers = model_server_workers ,
187188 sagemaker_session = self .sagemaker_session ,
188189 vpc_config = self .get_vpc_config (vpc_config_override ),
@@ -210,15 +211,19 @@ class constructor
210211 image_name = init_params .pop ("image" )
211212 framework , py_version , tag , _ = framework_name_from_image (image_name )
212213
214+ if tag is None :
215+ framework_version = None
216+ else :
217+ framework_version = framework_version_from_tag (tag )
218+ init_params ["framework_version" ] = framework_version
219+ init_params ["py_version" ] = py_version
220+
213221 if not framework :
214222 # If we were unable to parse the framework name from the image it is not one of our
215223 # officially supported images, in this case just add the image to the init params.
216224 init_params ["image_name" ] = image_name
217225 return init_params
218226
219- init_params ["py_version" ] = py_version
220- init_params ["framework_version" ] = framework_version_from_tag (tag )
221-
222227 training_job_name = init_params ["base_job_name" ]
223228
224229 if framework != cls .__framework_name__ :
0 commit comments