1919from sagemaker .fw_utils import (
2020 framework_name_from_image ,
2121 framework_version_from_tag ,
22- empty_framework_version_warning ,
2322 python_deprecation_warning ,
23+ validate_version_or_image_args ,
2424)
2525from sagemaker .chainer import defaults
2626from sagemaker .chainer .model import ChainerModel
@@ -51,8 +51,8 @@ def __init__(
5151 additional_mpi_options = None ,
5252 source_dir = None ,
5353 hyperparameters = None ,
54- py_version = "py3" ,
5554 framework_version = None ,
55+ py_version = None ,
5656 image_name = None ,
5757 ** kwargs
5858 ):
@@ -103,11 +103,12 @@ def __init__(
103103 and values, but ``str()`` will be called to convert them before
104104 training.
105105 py_version (str): Python version you want to use for executing your
106- model training code (default: 'py2'). One of 'py2' or 'py3'.
106+ model training code. Defaults to ``None``. Required unless ``image_name``
107+ is provided.
107108 framework_version (str): Chainer version you want to use for
108- executing your model training code. List of supported versions
109+ executing your model training code. Defaults to ``None``. Required unless
110+ ``image_name`` is provided. List of supported versions:
109111 https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators.
110- If not specified, this will default to 4.1.
111112 image_name (str): If specified, the estimator will use this image
112113 for training and hosting, instead of selecting the appropriate
113114 SageMaker official image based on framework_version and
@@ -117,6 +118,9 @@ def __init__(
117118 * ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
118119 * ``custom-image:latest``
119120
121+ If ``framework_version`` or ``py_version`` are ``None``, then
122+ ``image_name`` is required. If also ``None``, then a ``ValueError``
123+ will be raised.
120124 **kwargs: Additional kwargs passed to the
121125 :class:`~sagemaker.estimator.Framework` constructor.
122126
@@ -126,22 +130,18 @@ def __init__(
126130 :class:`~sagemaker.estimator.Framework` and
127131 :class:`~sagemaker.estimator.EstimatorBase`.
128132 """
129- if framework_version is None :
133+ validate_version_or_image_args (framework_version , py_version , image_name )
134+ if py_version == "py2" :
130135 logger .warning (
131- empty_framework_version_warning ( defaults . CHAINER_VERSION , self . LATEST_VERSION )
136+ python_deprecation_warning ( self . __framework_name__ , defaults . LATEST_PY2_VERSION )
132137 )
133- self .framework_version = framework_version or defaults .CHAINER_VERSION
138+ self .framework_version = framework_version
139+ self .py_version = py_version
134140
135141 super (Chainer , self ).__init__ (
136142 entry_point , source_dir , hyperparameters , image_name = image_name , ** kwargs
137143 )
138144
139- if py_version == "py2" :
140- logger .warning (
141- python_deprecation_warning (self .__framework_name__ , defaults .LATEST_PY2_VERSION )
142- )
143-
144- self .py_version = py_version
145145 self .use_mpi = use_mpi
146146 self .num_processes = num_processes
147147 self .process_slots_per_host = process_slots_per_host
@@ -262,15 +262,19 @@ class constructor
262262 image_name = init_params .pop ("image" )
263263 framework , py_version , tag , _ = framework_name_from_image (image_name )
264264
265+ if tag is None :
266+ framework_version = None
267+ else :
268+ framework_version = framework_version_from_tag (tag )
269+ init_params ["framework_version" ] = framework_version
270+ init_params ["py_version" ] = py_version
271+
265272 if not framework :
266273 # If we were unable to parse the framework name from the image it is not one of our
267274 # officially supported images, in this case just add the image to the init params.
268275 init_params ["image_name" ] = image_name
269276 return init_params
270277
271- init_params ["py_version" ] = py_version
272- init_params ["framework_version" ] = framework_version_from_tag (tag )
273-
274278 training_job_name = init_params ["base_job_name" ]
275279
276280 if framework != cls .__framework_name__ :
0 commit comments