Skip to content

Commit 1d1ca43

Browse files
ericangelokimJonathan Esterhazy
authored andcommitted
add SKLearn Estimator
1 parent 9e0791f commit 1d1ca43

File tree

17 files changed

+1610
-29
lines changed

17 files changed

+1610
-29
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def read(fname):
3636
description="Open source library for training and deploying models on Amazon SageMaker.",
3737
packages=find_packages('src'),
3838
package_dir={'': 'src'},
39-
py_modules=[os.splitext(os.basename(path))[0] for path in glob('src/*.py')],
39+
py_modules=[os.path.splitext(os.path.basename(path))[0] for path in glob('src/*.py')],
4040
long_description=read('README.rst'),
4141
author="Amazon Web Services",
4242
url='https://github.com/aws/sagemaker-python-sdk/',

src/sagemaker/fw_registry.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,9 @@ def registry(region_name, framework=None):
8484
except KeyError:
8585
logging.error("The specific image or region does not exist")
8686
raise
87+
88+
89+
def default_framework_uri(framework, region_name, image_tag):
90+
repository_name = "sagemaker-{}".format(framework)
91+
account_name = registry(region_name, framework)
92+
return "{}/{}:{}".format(account_name, repository_name, image_tag)

src/sagemaker/fw_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def framework_name_from_image(image_name):
193193
else:
194194
# extract framework, python version and image tag
195195
# We must support both the legacy and current image name format.
196-
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet|chainer|pytorch):(.*?)-(.*?)-(py2|py3)$')
196+
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet|chainer|pytorch|scikit-learn):(.*?)-(.*?)-(py2|py3)$')
197197
legacy_name_pattern = re.compile('^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$')
198198
name_match = name_pattern.match(sagemaker_match.group(8))
199199
legacy_match = legacy_name_pattern.match(sagemaker_match.group(8))

src/sagemaker/sklearn/README.rst

Lines changed: 652 additions & 0 deletions
Large diffs are not rendered by default.

src/sagemaker/sklearn/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker.sklearn.estimator import SKLearn
16+
from sagemaker.sklearn.model import SKLearnModel, SKLearnPredictor
17+
18+
__all__ = [SKLearn, SKLearnModel, SKLearnPredictor]

src/sagemaker/sklearn/defaults.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
SKLEARN_VERSION = '0.20.0'

src/sagemaker/sklearn/estimator.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import logging
16+
17+
from sagemaker.estimator import Framework
18+
from sagemaker.fw_registry import default_framework_uri
19+
from sagemaker.fw_utils import framework_name_from_image, empty_framework_version_warning
20+
from sagemaker.sklearn.defaults import SKLEARN_VERSION
21+
from sagemaker.sklearn.model import SKLearnModel
22+
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
23+
24+
logging.basicConfig()
25+
logger = logging.getLogger('sagemaker')
26+
27+
28+
class SKLearn(Framework):
29+
"""Handle end-to-end training and deployment of custom Scikit-learn code."""
30+
31+
__framework_name__ = "scikit-learn"
32+
33+
def __init__(self, entry_point, framework_version=SKLEARN_VERSION, source_dir=None, hyperparameters=None,
34+
py_version='py3', image_name=None, **kwargs):
35+
"""
36+
This ``Estimator`` executes an Scikit-learn script in a managed Scikit-learn execution environment, within a
37+
SageMaker Training Job. The managed Scikit-learn environment is an Amazon-built Docker container that executes
38+
functions defined in the supplied ``entry_point`` Python script.
39+
40+
Training is started by calling :meth:`~sagemaker.amazon.estimator.Framework.fit` on this Estimator.
41+
After training is complete, calling :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a
42+
hosted SageMaker endpoint and returns an :class:`~sagemaker.amazon.sklearn.model.SKLearnPredictor` instance
43+
that can be used to perform inference against the hosted model.
44+
45+
Technical documentation on preparing Scikit-learn scripts for SageMaker training and using the Scikit-learn
46+
Estimator is available on the project home-page: https://github.com/aws/sagemaker-python-sdk
47+
48+
Args:
49+
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
50+
as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5.
51+
source_dir (str): Path (absolute or relative) to a directory with any other training
52+
source code dependencies aside from tne entry point file (default: None). Structure within this
53+
directory are preserved when training on Amazon SageMaker.
54+
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
55+
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
56+
For convenience, this accepts other types for keys and values, but ``str()`` will be called
57+
to convert them before training.
58+
py_version (str): Python version you want to use for executing your model training code (default: 'py2').
59+
One of 'py2' or 'py3'.
60+
framework_version (str): Scikit-learn version you want to use for executing your model training code.
61+
List of supported versions https://github.com/aws/sagemaker-python-sdk#sklearn-sagemaker-estimators
62+
image_name (str): If specified, the estimator will use this image for training and hosting, instead of
63+
selecting the appropriate SageMaker official image based on framework_version and py_version. It can
64+
be an ECR url or dockerhub image and tag.
65+
Examples:
66+
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
67+
custom-image:latest.
68+
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
69+
"""
70+
# SciKit-Learn does not support distributed training or training on GPU instance types. Fail fast.
71+
train_instance_type = kwargs.get('train_instance_type')
72+
_validate_not_gpu_instance_type(train_instance_type)
73+
74+
train_instance_count = kwargs.get('train_instance_count')
75+
if train_instance_count:
76+
if train_instance_count != 1:
77+
raise AttributeError("SciKit-Learn does not support distributed training. "
78+
"Please remove the 'train_instance_count' argument or set "
79+
"'train_instance_count=1' when initializing SKLearn.")
80+
super(SKLearn, self).__init__(entry_point, source_dir, hyperparameters, image_name=image_name,
81+
**dict(kwargs, train_instance_count=1))
82+
83+
self.py_version = py_version
84+
85+
if framework_version is None:
86+
logger.warning(empty_framework_version_warning(SKLEARN_VERSION, SKLEARN_VERSION))
87+
self.framework_version = framework_version or SKLEARN_VERSION
88+
89+
if image_name is None:
90+
image_tag = "{}-{}-{}".format(framework_version, "cpu", py_version)
91+
self.image_name = default_framework_uri(
92+
SKLearn.__framework_name__,
93+
self.sagemaker_session.boto_region_name,
94+
image_tag)
95+
96+
def create_model(self, model_server_workers=None, role=None,
97+
vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
98+
"""Create a SageMaker ``SKLearnModel`` object that can be deployed to an ``Endpoint``.
99+
100+
Args:
101+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
102+
transform jobs. If not specified, the role from the Estimator will be used.
103+
model_server_workers (int): Optional. The number of worker processes used by the inference server.
104+
If None, server will use one worker per vCPU.
105+
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the model.
106+
Default: use subnets and security groups from this Estimator.
107+
* 'Subnets' (list[str]): List of subnet ids.
108+
* 'SecurityGroupIds' (list[str]): List of security group ids.
109+
**kwargs: Passed to initialization of ``SKLearnModel``.
110+
111+
Returns:
112+
sagemaker.sklearn.model.SKLearnModel: A SageMaker ``SKLearnModel`` object.
113+
See :func:`~sagemaker.sklearn.model.SKLearnModel` for full details.
114+
"""
115+
role = role or self.role
116+
return SKLearnModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
117+
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
118+
container_log_level=self.container_log_level, code_location=self.code_location,
119+
py_version=self.py_version, framework_version=self.framework_version,
120+
model_server_workers=model_server_workers, image=self.image_name,
121+
sagemaker_session=self.sagemaker_session,
122+
vpc_config=self.get_vpc_config(vpc_config_override),
123+
**kwargs)
124+
125+
@classmethod
126+
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
127+
"""Convert the job description to init params that can be handled by the class constructor
128+
129+
Args:
130+
job_details: the returned job details from a describe_training_job API call.
131+
132+
Returns:
133+
dictionary: The transformed init_params
134+
135+
"""
136+
init_params = super(SKLearn, cls)._prepare_init_params_from_job_description(job_details)
137+
138+
image_name = init_params.pop('image')
139+
framework, py_version, _ = framework_name_from_image(image_name)
140+
init_params['py_version'] = py_version
141+
142+
if framework and framework != cls.__framework_name__:
143+
training_job_name = init_params['base_job_name']
144+
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))
145+
elif not framework:
146+
# If we were unable to parse the framework name from the image it is not one of our
147+
# officially supported images, in this case just add the image to the init params.
148+
init_params['image_name'] = image_name
149+
return init_params
150+
151+
152+
def _validate_not_gpu_instance_type(training_instance_type):
153+
gpu_instance_types = ['ml.p2.xlarge', 'ml.p2.8xlarge', 'ml.p2.16xlarge',
154+
'ml.p3.xlarge', 'ml.p3.8xlarge', 'ml.p3.16xlarge']
155+
156+
if training_instance_type in gpu_instance_types:
157+
raise ValueError("GPU training in not supported for SciKit-Learn. "
158+
"Please pick a different instance type from here: "
159+
"https://aws.amazon.com/ec2/instance-types/")

src/sagemaker/sklearn/model.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import sagemaker
16+
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
17+
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
18+
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
19+
from sagemaker.sklearn.defaults import SKLEARN_VERSION
20+
21+
22+
class SKLearnPredictor(RealTimePredictor):
23+
"""A RealTimePredictor for inference against Scikit-learn Endpoints.
24+
25+
This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for Scikit-learn
26+
inference."""
27+
28+
def __init__(self, endpoint_name, sagemaker_session=None):
29+
"""Initialize an ``SKLearnPredictor``.
30+
31+
Args:
32+
endpoint_name (str): The name of the endpoint to perform inference on.
33+
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
34+
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
35+
using the default AWS configuration chain.
36+
"""
37+
super(SKLearnPredictor, self).__init__(endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer)
38+
39+
40+
class SKLearnModel(FrameworkModel):
41+
"""An Scikit-learn SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
42+
43+
__framework_name__ = 'scikit-learn'
44+
45+
def __init__(self, model_data, role, entry_point, image=None, py_version='py3', framework_version=SKLEARN_VERSION,
46+
predictor_cls=SKLearnPredictor, model_server_workers=None, **kwargs):
47+
"""Initialize an SKLearnModel.
48+
49+
Args:
50+
model_data (str): The S3 location of a SageMaker model data ``.tar.gz`` file.
51+
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs
52+
that create Amazon SageMaker endpoints use this role to access training data and model artifacts.
53+
After the endpoint is created, the inference code might use the IAM role,
54+
if it needs to access an AWS resource.
55+
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
56+
as the entry point to model hosting. This should be compatible with either Python 2.7 or Python 3.5.
57+
image (str): A Docker image URI (default: None). If not specified, a default image for Scikit-learn
58+
will be used.
59+
py_version (str): Python version you want to use for executing your model training code (default: 'py2').
60+
framework_version (str): Scikit-learn version you want to use for executing your model training code.
61+
predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a predictor
62+
with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of
63+
invoking this function on the created endpoint name.
64+
model_server_workers (int): Optional. The number of worker processes used by the inference server.
65+
If None, server will use one worker per vCPU.
66+
**kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.
67+
"""
68+
super(SKLearnModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
69+
**kwargs)
70+
self.py_version = py_version
71+
self.framework_version = framework_version
72+
self.model_server_workers = model_server_workers
73+
74+
def prepare_container_def(self, instance_type):
75+
"""Return a container definition with framework configuration set in model environment variables.
76+
77+
Args:
78+
instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
79+
80+
Returns:
81+
dict[str, str]: A container definition object usable with the CreateModel API.
82+
"""
83+
deploy_image = self.image
84+
if not deploy_image:
85+
region_name = self.sagemaker_session.boto_session.region_name
86+
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
87+
self.framework_version, self.py_version)
88+
89+
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
90+
self._upload_code(deploy_key_prefix)
91+
deploy_env = dict(self.env)
92+
deploy_env.update(self._framework_env_vars())
93+
94+
if self.model_server_workers:
95+
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
96+
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)

0 commit comments

Comments
 (0)