1919from sagemaker .fw_utils import (
2020 framework_name_from_image ,
2121 framework_version_from_tag ,
22- empty_framework_version_warning ,
22+ is_version_equal_or_higher ,
2323 python_deprecation_warning ,
2424 parameter_v2_rename_warning ,
25- is_version_equal_or_higher ,
25+ validate_version_or_image_args ,
2626 warn_if_parameter_server_with_multi_gpu ,
2727)
2828from sagemaker .mxnet import defaults
@@ -43,10 +43,10 @@ class MXNet(Framework):
4343 def __init__ (
4444 self ,
4545 entry_point ,
46+ framework_version = None ,
47+ py_version = None ,
4648 source_dir = None ,
4749 hyperparameters = None ,
48- py_version = "py2" ,
49- framework_version = None ,
5050 image_name = None ,
5151 distributions = None ,
5252 ** kwargs
@@ -73,6 +73,13 @@ def __init__(
7373 file which should be executed as the entry point to training.
7474 If ``source_dir`` is specified, then ``entry_point``
7575 must point to a file located at the root of ``source_dir``.
76+ framework_version (str): MXNet version you want to use for executing
77+ your model training code. Defaults to `None`. Required unless
78+ ``image_name`` is provided. List of supported versions.
79+ https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
80+ py_version (str): Python version you want to use for executing your
81+ model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
82+ unless ``image_name`` is provided.
7683 source_dir (str): Path (absolute, relative or an S3 URI) to a directory
7784 with any other training source code dependencies aside from the entry
7885 point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -84,12 +91,6 @@ def __init__(
8491 SageMaker. For convenience, this accepts other types for keys
8592 and values, but ``str()`` will be called to convert them before
8693 training.
87- py_version (str): Python version you want to use for executing your
88- model training code (default: 'py2'). One of 'py2' or 'py3'.
89- framework_version (str): MXNet version you want to use for executing
90- your model training code. List of supported versions
91- https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
92- If not specified, this will default to 1.2.1.
9394 image_name (str): If specified, the estimator will use this image for training and
9495 hosting, instead of selecting the appropriate SageMaker official image based on
9596 framework_version and py_version. It can be an ECR url or dockerhub image and tag.
@@ -98,6 +99,9 @@ def __init__(
9899 * ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
99100 * ``custom-image:latest``
100101
102+ If ``framework_version`` or ``py_version`` are ``None``, then
103+ ``image_name`` is required. If also ``None``, then a ``ValueError``
104+ will be raised.
101105 distributions (dict): A dictionary with information on how to run distributed
102106 training (default: None). To have parameter servers launched for training,
103107 set this value to be ``{'parameter_server': {'enabled': True}}``.
@@ -110,34 +114,32 @@ def __init__(
110114 :class:`~sagemaker.estimator.Framework` and
111115 :class:`~sagemaker.estimator.EstimatorBase`.
112116 """
113- if framework_version is None :
117+ validate_version_or_image_args (framework_version , py_version , image_name )
118+ if py_version and py_version == "py2" :
114119 logger .warning (
115- empty_framework_version_warning ( defaults . MXNET_VERSION , self . LATEST_VERSION )
120+ python_deprecation_warning ( self . __framework_name__ , defaults . LATEST_PY2_VERSION )
116121 )
117- self .framework_version = framework_version or defaults .MXNET_VERSION
122+ self .framework_version = framework_version
123+ self .py_version = py_version
118124
119125 if "enable_sagemaker_metrics" not in kwargs :
120126 # enable sagemaker metrics for MXNet v1.6 or greater:
121- if is_version_equal_or_higher ([1 , 6 ], self .framework_version ):
127+ if self .framework_version and is_version_equal_or_higher (
128+ [1 , 6 ], self .framework_version
129+ ):
122130 kwargs ["enable_sagemaker_metrics" ] = True
123131
124132 super (MXNet , self ).__init__ (
125133 entry_point , source_dir , hyperparameters , image_name = image_name , ** kwargs
126134 )
127135
128- if py_version == "py2" :
129- logger .warning (
130- python_deprecation_warning (self .__framework_name__ , defaults .LATEST_PY2_VERSION )
131- )
132-
133136 if distributions is not None :
134137 logger .warning (parameter_v2_rename_warning ("distributions" , "distribution" ))
135138 train_instance_type = kwargs .get ("train_instance_type" )
136139 warn_if_parameter_server_with_multi_gpu (
137140 training_instance_type = train_instance_type , distributions = distributions
138141 )
139142
140- self .py_version = py_version
141143 self ._configure_distribution (distributions )
142144
143145 def _configure_distribution (self , distributions ):
@@ -148,7 +150,10 @@ def _configure_distribution(self, distributions):
148150 if distributions is None :
149151 return
150152
151- if self .framework_version .split ("." ) < self ._LOWEST_SCRIPT_MODE_VERSION :
153+ if (
154+ self .framework_version
155+ and self .framework_version .split ("." ) < self ._LOWEST_SCRIPT_MODE_VERSION
156+ ):
152157 raise ValueError (
153158 "The distributions option is valid for only versions {} and higher" .format (
154159 "." .join (self ._LOWEST_SCRIPT_MODE_VERSION )
@@ -221,12 +226,12 @@ def create_model(
221226 self .model_data ,
222227 role or self .role ,
223228 entry_point or self .entry_point ,
229+ framework_version = self .framework_version ,
230+ py_version = self .py_version ,
224231 source_dir = (source_dir or self ._model_source_dir ()),
225232 enable_cloudwatch_metrics = self .enable_cloudwatch_metrics ,
226233 container_log_level = self .container_log_level ,
227234 code_location = self .code_location ,
228- py_version = self .py_version ,
229- framework_version = self .framework_version ,
230235 model_server_workers = model_server_workers ,
231236 sagemaker_session = self .sagemaker_session ,
232237 vpc_config = self .get_vpc_config (vpc_config_override ),
@@ -254,22 +259,25 @@ class constructor
254259 image_name = init_params .pop ("image" )
255260 framework , py_version , tag , _ = framework_name_from_image (image_name )
256261
262+ # We switched image tagging scheme from regular image version (e.g. '1.0') to more
263+ # expressive containing framework version, device type and python version
264+ # (e.g. '0.12-gpu-py2'). For backward compatibility map deprecated image tag '1.0' to a
265+ # '0.12' framework version otherwise extract framework version from the tag itself.
266+ if tag is None :
267+ framework_version = None
268+ elif tag == "1.0" :
269+ framework_version = "0.12"
270+ else :
271+ framework_version = framework_version_from_tag (tag )
272+ init_params ["framework_version" ] = framework_version
273+ init_params ["py_version" ] = py_version
274+
257275 if not framework :
258276 # If we were unable to parse the framework name from the image it is not one of our
259277 # officially supported images, in this case just add the image to the init params.
260278 init_params ["image_name" ] = image_name
261279 return init_params
262280
263- init_params ["py_version" ] = py_version
264-
265- # We switched image tagging scheme from regular image version (e.g. '1.0') to more
266- # expressive containing framework version, device type and python version
267- # (e.g. '0.12-gpu-py2'). For backward compatibility map deprecated image tag '1.0' to a
268- # '0.12' framework version otherwise extract framework version from the tag itself.
269- init_params ["framework_version" ] = (
270- "0.12" if tag == "1.0" else framework_version_from_tag (tag )
271- )
272-
273281 training_job_name = init_params ["base_job_name" ]
274282
275283 if framework != cls .__framework_name__ :
0 commit comments