1919from sagemaker .fw_registry import default_framework_uri
2020from sagemaker .fw_utils import (
2121 framework_name_from_image ,
22- empty_framework_version_warning ,
23- python_deprecation_warning ,
22+ framework_version_from_tag ,
23+ validate_version_or_image_args ,
2424)
2525from sagemaker .sklearn import defaults
2626from sagemaker .sklearn .model import SKLearnModel
@@ -37,10 +37,10 @@ class SKLearn(Framework):
3737 def __init__ (
3838 self ,
3939 entry_point ,
40- framework_version = defaults .SKLEARN_VERSION ,
40+ framework_version = None ,
41+ py_version = "py3" ,
4142 source_dir = None ,
4243 hyperparameters = None ,
43- py_version = "py3" ,
4444 image_name = None ,
4545 ** kwargs
4646 ):
@@ -68,8 +68,13 @@ def __init__(
6868 If ``source_dir`` is specified, then ``entry_point``
6969 must point to a file located at the root of ``source_dir``.
7070 framework_version (str): Scikit-learn version you want to use for
71- executing your model training code. List of supported versions
71+ executing your model training code. Defaults to ``None``. Required
72+ unless ``image_name`` is provided. List of supported versions:
7273 https://github.com/aws/sagemaker-python-sdk#sklearn-sagemaker-estimators
74+ py_version (str): Python version you want to use for executing your
75+ model training code (default: 'py3'). Currently, 'py3' is the only
76+ supported version. If ``None`` is passed in, ``image_name`` must be
77+ provided.
7378 source_dir (str): Path (absolute, relative or an S3 URI) to a directory
7479 with any other training source code dependencies aside from the entry
7580 point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -81,15 +86,18 @@ def __init__(
8186 SageMaker. For convenience, this accepts other types for keys
8287 and values, but ``str()`` will be called to convert them before
8388 training.
84- py_version (str): Python version you want to use for executing your
85- model training code (default: 'py3'). One of 'py2' or 'py3'.
8689 image_name (str): If specified, the estimator will use this image
8790 for training and hosting, instead of selecting the appropriate
8891 SageMaker official image based on framework_version and
8992 py_version. It can be an ECR url or dockerhub image and tag.
93+
9094 Examples:
9195 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
9296 custom-image:latest.
97+
98+ If ``framework_version`` or ``py_version`` is ``None``, then
99+ ``image_name`` is required. If ``image_name`` is also ``None``,
100+ a ``ValueError`` will be raised.
93101 **kwargs: Additional kwargs passed to the
94102 :class:`~sagemaker.estimator.Framework` constructor.
95103
@@ -99,6 +107,14 @@ def __init__(
99107 :class:`~sagemaker.estimator.Framework` and
100108 :class:`~sagemaker.estimator.EstimatorBase`.
101109 """
110+ validate_version_or_image_args (framework_version , py_version , image_name )
111+ if py_version and py_version != "py3" :
112+ raise AttributeError (
113+ "Scikit-learn image only supports Python 3. Please use 'py3' for py_version."
114+ )
115+ self .framework_version = framework_version
116+ self .py_version = py_version
117+
102118 # SciKit-Learn does not support distributed training or training on GPU instance types.
103119 # Fail fast.
104120 train_instance_type = kwargs .get ("train_instance_type" )
@@ -112,6 +128,7 @@ def __init__(
112128 "Please remove the 'train_instance_count' argument or set "
113129 "'train_instance_count=1' when initializing SKLearn."
114130 )
131+
115132 super (SKLearn , self ).__init__ (
116133 entry_point ,
117134 source_dir ,
@@ -120,19 +137,6 @@ def __init__(
120137 ** dict (kwargs , train_instance_count = 1 )
121138 )
122139
123- if py_version == "py2" :
124- logger .warning (
125- python_deprecation_warning (self .__framework_name__ , defaults .LATEST_PY2_VERSION )
126- )
127-
128- self .py_version = py_version
129-
130- if framework_version is None :
131- logger .warning (
132- empty_framework_version_warning (defaults .SKLEARN_VERSION , defaults .SKLEARN_VERSION )
133- )
134- self .framework_version = framework_version or defaults .SKLEARN_VERSION
135-
136140 if image_name is None :
137141 image_tag = "{}-{}-{}" .format (framework_version , "cpu" , py_version )
138142 self .image_name = default_framework_uri (
@@ -216,28 +220,40 @@ class constructor
216220 Args:
217221 job_details: the returned job details from a describe_training_job
218222 API call.
219- model_channel_name:
223+ model_channel_name (str): Name of the channel where pre-trained
224+ model data will be downloaded (default: None).
220225
221226 Returns:
222227 dictionary: The transformed init_params
223228 """
224- init_params = super (SKLearn , cls )._prepare_init_params_from_job_description (job_details )
225-
229+ init_params = super (SKLearn , cls )._prepare_init_params_from_job_description (
230+ job_details , model_channel_name
231+ )
226232 image_name = init_params .pop ("image" )
227- framework , py_version , _ , _ = framework_name_from_image (image_name )
233+ framework , py_version , tag , _ = framework_name_from_image (image_name )
234+
235+ if tag is None :
236+ framework_version = None
237+ else :
238+ framework_version = framework_version_from_tag (tag )
239+ init_params ["framework_version" ] = framework_version
228240 init_params ["py_version" ] = py_version
229241
242+ if not framework :
243+ # If we were unable to parse the framework name from the image, it is not one of our
244+ # officially supported images; in this case, just add the image to the init params.
245+ init_params ["image_name" ] = image_name
246+ return init_params
247+
248+ training_job_name = init_params ["base_job_name" ]
249+
230250 if framework and framework != cls .__framework_name__ :
231- training_job_name = init_params ["base_job_name" ]
232251 raise ValueError (
233252 "Training job: {} didn't use image for requested framework" .format (
234253 training_job_name
235254 )
236255 )
237- if not framework :
238- # If we were unable to parse the framework name from the image it is not one of our
239- # officially supported images, in this case just add the image to the init params.
240- init_params ["image_name" ] = image_name
256+
241257 return init_params
242258
243259
0 commit comments