Skip to content

Commit 28ac404

Browse files
wweic authored and Jonathan Esterhazy committed
add SageMaker Neo support
1 parent 1d1ca43 commit 28ac404

File tree

7 files changed

+562
-4
lines changed

7 files changed

+562
-4
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,13 @@ def registry(region_name, algorithm=None):
353353
"eu-west-2": "644912444149",
354354
"us-west-1": "632365934929",
355355
}[region_name]
356+
elif algorithm in ['image-classification-neo', 'xgboost-neo']:
357+
account_id = {
358+
'us-west-2': '301217895009',
359+
'us-east-1': '785573368785',
360+
'eu-west-1': '802834080501',
361+
'us-east-2': '007439368137'
362+
}[region_name]
356363
else:
357364
raise ValueError("Algorithm class:{} doesn't have mapping to account_id with images".format(algorithm))
358365
return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)

src/sagemaker/estimator.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
validate_source_dir)
2727
from sagemaker.job import _Job
2828
from sagemaker.local import LocalSession
29-
from sagemaker.model import Model
29+
from sagemaker.model import Model, NEO_ALLOWED_TARGET_INSTANCE_FAMILY, NEO_ALLOWED_FRAMEWORKS
3030
from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME,
3131
CONTAINER_LOG_LEVEL_PARAM_NAME, JOB_NAME_PARAM_NAME, SAGEMAKER_REGION_PARAM_NAME)
3232
from sagemaker.predictor import RealTimePredictor
@@ -131,6 +131,8 @@ def __init__(self, role, train_instance_count, train_instance_type,
131131
self.output_kms_key = output_kms_key
132132
self.latest_training_job = None
133133

134+
self._compiled_models = {}
135+
134136
# VPC configurations
135137
self.subnets = subnets
136138
self.security_group_ids = security_group_ids
@@ -216,6 +218,57 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None):
216218
if wait:
217219
self.latest_training_job.wait(logs=logs)
218220

221+
def _compilation_job_name(self):
    """Return a name for a new compilation job.

    The base is ``self.base_job_name`` when that is set (truthy); otherwise it is derived from the
    training image via ``base_name_from_image``. The base is prefixed with ``compilation-`` and
    handed to ``name_from_base``.
    """
    prefix = self.base_job_name
    if not prefix:
        prefix = base_name_from_image(self.train_image())
    return name_from_base('compilation-' + prefix)
224+
225+
def compile_model(self, target_instance_family, input_shape, output_path, framework=None, framework_version=None,
                  compile_max_run=5 * 60, tags=None, **kwargs):
    """Compile a Neo model using the input model.

    The compiled model is also remembered in ``self._compiled_models`` under
    ``target_instance_family``.

    Args:
        target_instance_family (str): Identifies the device that you want to run your model after compilation, for
            example: ml_c5. Allowed strings are: ml_c5, ml_m5, ml_c4, ml_m4, jetson_tx1, jetson_tx2, ml_p2, ml_p3,
            deeplens, rasp3b
        input_shape (dict): Specifies the name and shape of the expected inputs for your trained model in json
            dictionary form, for example: {'data': [1,3,1024,1024]}, or {'var1': [1,1,28,28], 'var2': [1,1,28,28]}
        output_path (str): Specifies where to store the compiled model
        framework (str): The framework that is used to train the original model. Allowed values: 'mxnet',
            'tensorflow', 'pytorch', 'onnx', 'xgboost'. Must be given together with ``framework_version``.
        framework_version (str): The version of the framework. Must be given together with ``framework``.
        compile_max_run (int): Timeout in seconds for compilation (default: 5 * 60).
            After this amount of time Amazon SageMaker Neo terminates the compilation job regardless of its
            current status.
        tags (list[dict]): List of tags for labeling a compilation job. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
        **kwargs: Passed to invocation of ``create_model()``. Implementations may customize
            ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
            For more, see the implementation docs.

    Returns:
        sagemaker.model.Model: A SageMaker ``Model`` object. See :func:`~sagemaker.model.Model` for full details.

    Raises:
        ValueError: If ``target_instance_family`` or ``framework`` is not an allowed value, or if only one
            of ``framework`` and ``framework_version`` is provided.
    """
    if target_instance_family not in NEO_ALLOWED_TARGET_INSTANCE_FAMILY:
        raise ValueError("Please use valid target_instance_family, "
                         "allowed values: {}".format(NEO_ALLOWED_TARGET_INSTANCE_FAMILY))
    if framework and framework not in NEO_ALLOWED_FRAMEWORKS:
        raise ValueError("Please use valid framework, allowed values: {}".format(NEO_ALLOWED_FRAMEWORKS))

    # Exactly one of the pair being None means the caller supplied only half of it.
    if (framework is None) != (framework_version is None):
        raise ValueError("You should provide framework and framework_version at the same time.")

    model = self.create_model(**kwargs)

    compiled_model = model.compile(target_instance_family,
                                   input_shape,
                                   output_path,
                                   self.role,
                                   tags,
                                   self._compilation_job_name(),
                                   compile_max_run,
                                   framework=framework,
                                   framework_version=framework_version)
    # Cached per target family so a later deploy can look the compiled model up.
    self._compiled_models[target_instance_family] = compiled_model
    return compiled_model
271+
219272
@classmethod
220273
def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='model'):
221274
"""Attach to an existing training job.
@@ -257,7 +310,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
257310
estimator.latest_training_job.wait()
258311
return estimator
259312

260-
def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kwargs):
313+
def deploy(self, initial_instance_count, instance_type, endpoint_name=None, use_compiled_model=False, **kwargs):
261314
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
262315
263316
More information:
@@ -269,6 +322,7 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kw
269322
for example, 'ml.c4.xlarge'.
270323
endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified, the name of
271324
the training job is used.
325+
use_compiled_model (bool): Flag to select whether to use compiled (optimized) model. Default: False.
272326
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
273327
``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
274328
For more, see the implementation docs.
@@ -280,7 +334,15 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kw
280334
self._ensure_latest_training_job()
281335
endpoint_name = endpoint_name or self.latest_training_job.name
282336
self.deploy_instance_type = instance_type
283-
return self.create_model(**kwargs).deploy(
337+
if use_compiled_model:
338+
family = '_'.join(instance_type.split('.')[:-1])
339+
if family not in self._compiled_models:
340+
raise ValueError("No compiled model for {}. "
341+
"Please compile one with compile_model before deploying.".format(family))
342+
model = self._compiled_models[family]
343+
else:
344+
model = self.create_model(**kwargs)
345+
return model.deploy(
284346
instance_type=instance_type,
285347
initial_instance_count=initial_instance_count,
286348
endpoint_name=endpoint_name)

src/sagemaker/model.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,23 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import json
1516
import logging
1617

1718
import sagemaker
1819
from sagemaker import fw_utils, local, session, utils
1920

21+
# Target device families recognised for Neo compilation.
NEO_ALLOWED_TARGET_INSTANCE_FAMILY = {
    'ml_c5', 'ml_m5', 'ml_c4', 'ml_m4', 'jetson_tx1', 'jetson_tx2', 'ml_p2', 'ml_p3', 'deeplens', 'rasp3b',
}
# Source frameworks recognised for Neo compilation.
NEO_ALLOWED_FRAMEWORKS = {'mxnet', 'tensorflow', 'pytorch', 'onnx', 'xgboost'}

# Account id used when building a Neo image URI, keyed by AWS region name.
NEO_IMAGE_ACCOUNT = {
    'us-west-2': '301217895009',
    'us-east-1': '785573368785',
    'eu-west-1': '802834080501',
    'us-east-2': '007439368137',
}
31+
2032

2133
class Model(object):
2234
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
@@ -53,6 +65,7 @@ def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, n
5365
self.vpc_config = vpc_config
5466
self.sagemaker_session = sagemaker_session
5567
self._model_name = None
68+
self._is_compiled_model = False
5669

5770
def prepare_container_def(self, instance_type): # pylint: disable=unused-argument
5871
"""Return a dict created by ``sagemaker.container_def()`` for deploying this model to a specified instance type.
@@ -68,6 +81,93 @@ def prepare_container_def(self, instance_type): # pylint: disable=unused-argume
6881
"""
6982
return sagemaker.container_def(self.image, self.model_data, self.env)
7083

84+
def _framework(self):
85+
return getattr(self, '__framework_name__', None)
86+
87+
def _get_framework_version(self):
88+
return getattr(self, 'framework_version', None)
89+
90+
def _compilation_job_config(self, target_instance_type, input_shape, output_path, role, compile_max_run,
91+
job_name, framework, tags):
92+
input_model_config = {
93+
'S3Uri': self.model_data,
94+
'DataInputConfig': input_shape if type(input_shape) != dict else json.dumps(input_shape),
95+
'Framework': framework
96+
}
97+
role = self.sagemaker_session.expand_role(role)
98+
output_model_config = {
99+
'TargetDevice': target_instance_type,
100+
'S3OutputLocation': output_path
101+
}
102+
103+
return {'input_model_config': input_model_config,
104+
'output_model_config': output_model_config,
105+
'role': role,
106+
'stop_condition': {
107+
'MaxRuntimeInSeconds': compile_max_run
108+
},
109+
'tags': tags,
110+
'job_name': job_name}
111+
112+
def _neo_image_account(self, region):
    """Return the Neo image account id registered for ``region`` in ``NEO_IMAGE_ACCOUNT``.

    Args:
        region (str): AWS region name, for example 'us-west-2'.

    Returns:
        str: The account id mapped to ``region``.

    Raises:
        ValueError: If ``region`` has no entry in ``NEO_IMAGE_ACCOUNT``.
    """
    if region not in NEO_IMAGE_ACCOUNT:
        # sorted(...) yields a plain list, so the message reads the same on Python 2 and 3
        # instead of showing a ``dict_keys([...])`` view.
        raise ValueError("Neo is not currently supported in {}, "
                         "valid regions: {}".format(region, sorted(NEO_IMAGE_ACCOUNT)))
    return NEO_IMAGE_ACCOUNT[region]
117+
118+
def _neo_image(self, region, target_instance_type, framework, framework_version):
    """Build the image URI for a Neo-compiled model via ``fw_utils.create_image_uri``.

    The framework is lower-cased and prefixed with ``neo-``, underscores in
    ``target_instance_type`` are replaced by dots, and the account comes from
    ``self._neo_image_account(region)``.
    """
    neo_framework = 'neo-' + framework.lower()
    instance_type = target_instance_type.replace('_', '.')
    account_id = self._neo_image_account(region)
    return fw_utils.create_image_uri(region, neo_framework, instance_type, framework_version,
                                     py_version='py3', account=account_id)
125+
126+
def compile(self, target_instance_family, input_shape, output_path, role,
            tags=None, job_name=None, compile_max_run=5 * 60, framework=None, framework_version=None):
    """Compile this ``Model`` with SageMaker Neo.

    Starts a compilation job through ``self.sagemaker_session``, waits for it, then updates this
    model in place: ``model_data`` is set to the compiled artifact location, ``image`` to the Neo
    image URI, and the model is flagged as compiled.

    Args:
        target_instance_family (str): Identifies the device that you want to run your model after compilation, for
            example: ml_c5. The values listed in ``NEO_ALLOWED_TARGET_INSTANCE_FAMILY`` are: ml_c5, ml_m5, ml_c4,
            ml_m4, jetson_tx1, jetson_tx2, ml_p2, ml_p3, deeplens, rasp3b. This method does not itself check
            the value against that list.
        input_shape (dict): Specifies the name and shape of the expected inputs for your trained model in json
            dictionary form, for example: {'data': [1,3,1024,1024]}, or {'var1': [1,1,28,28], 'var2': [1,1,28,28]}
        output_path (str): Specifies where to store the compiled model
        role (str): Execution role
        tags (list[dict]): List of tags for labeling a compilation job. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
        job_name (str): The name of the compilation job. Required, despite the ``None`` default.
        compile_max_run (int): Timeout in seconds for compilation (default: 5 * 60).
            After this amount of time Amazon SageMaker Neo terminates the compilation job regardless of its
            current status.
        framework (str): The framework that is used to train the original model. Allowed values: 'mxnet',
            'tensorflow', 'pytorch', 'onnx', 'xgboost'. Only used when ``self._framework()`` returns a
            falsy value.
        framework_version (str): The version of the framework. Only used when
            ``self._get_framework_version()`` returns a falsy value.

    Returns:
        sagemaker.model.Model: This same ``Model`` object, updated in place.
            See :func:`~sagemaker.model.Model` for full details.

    Raises:
        ValueError: If no framework can be determined, the framework is not in ``NEO_ALLOWED_FRAMEWORKS``,
            or ``job_name`` is ``None``.
    """
    # The model's own framework name, when it has one, takes precedence over the argument.
    framework = self._framework() or framework
    if framework is None:
        raise ValueError("You must specify framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS))
    if framework not in NEO_ALLOWED_FRAMEWORKS:
        raise ValueError("You must provide valid framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS))
    if job_name is None:
        raise ValueError("You must provide a compilation job name")

    # Validation above uses the lower-case name; the job config is given the upper-case form.
    framework = framework.upper()
    framework_version = self._get_framework_version() or framework_version

    config = self._compilation_job_config(target_instance_family, input_shape, output_path, role,
                                          compile_max_run, job_name, framework, tags)
    self.sagemaker_session.compile_model(**config)
    job_status = self.sagemaker_session.wait_for_compilation_job(job_name)
    # From here on the model points at the compiled artifacts and the Neo image.
    self.model_data = job_status['ModelArtifacts']['S3ModelArtifacts']
    self.image = self._neo_image(self.sagemaker_session.boto_region_name, target_instance_family, framework,
                                 framework_version)
    self._is_compiled_model = True
    return self
170+
71171
def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags=None):
72172
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
73173
@@ -98,13 +198,21 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags
98198
else:
99199
self.sagemaker_session = session.Session()
100200

201+
compiled_model_suffix = '-'.join(instance_type.split('.')[:-1])
101202
container_def = self.prepare_container_def(instance_type)
102203
self.name = self.name or utils.name_from_image(container_def['Image'])
103204
if self.role is None:
104205
raise ValueError("Role can not be null for deploying a model")
206+
if self._is_compiled_model:
207+
self.name += compiled_model_suffix
105208
self.sagemaker_session.create_model(self.name, self.role, container_def, vpc_config=self.vpc_config)
106209
production_variant = sagemaker.production_variant(self.name, instance_type, initial_instance_count)
107-
self.endpoint_name = endpoint_name or self.name
210+
if endpoint_name:
211+
self.endpoint_name = endpoint_name
212+
else:
213+
self.endpoint_name = self.name
214+
if self._is_compiled_model and not self.endpoint_name.endswith(compiled_model_suffix):
215+
self.endpoint_name += compiled_model_suffix
108216
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
109217
if self.predictor_cls:
110218
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)

src/sagemaker/session.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,40 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
283283
LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4)))
284284
self.sagemaker_client.create_training_job(**train_request)
285285

286+
def compile_model(self, input_model_config, output_model_config, role,
                  job_name, stop_condition, tags):
    """Create an Amazon SageMaker Neo compilation job.

    Args:
        input_model_config (dict): the trained model and the Amazon S3 location where it is stored.
            Sent as ``InputConfig``.
        output_model_config (dict): Identifies the Amazon S3 location where you want Amazon SageMaker Neo to save
            the results of compilation job. Sent as ``OutputConfig``.
        role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker Neo compilation jobs use this
            role to access model artifacts. You must grant sufficient permissions to this role.
            Sent unchanged as ``RoleArn``.
        job_name (str): Name of the compilation job being created.
        stop_condition (dict): Defines when compilation job shall finish. Contains entries that can be understood
            by the service like ``MaxRuntimeInSeconds``.
        tags (list[dict]): List of tags for labeling a compile model job. Omitted from the request when
            ``None``. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.

    Returns:
        None: The response of ``create_compilation_job`` is discarded.
            NOTE(review): this docstring previously promised the job ARN; confirm whether callers need it.
    """

    compilation_job_request = {
        'InputConfig': input_model_config,
        'OutputConfig': output_model_config,
        'RoleArn': role,
        'StoppingCondition': stop_condition,
        'CompilationJobName': job_name
    }

    # Tags are optional; the key is only added when tags were supplied.
    if tags is not None:
        compilation_job_request['Tags'] = tags

    LOGGER.info('Creating compilation-job with name: {}'.format(job_name))
    self.sagemaker_client.create_compilation_job(**compilation_job_request)
319+
286320
def tune(self, job_name, strategy, objective_type, objective_metric_name,
287321
max_jobs, max_parallel_jobs, parameter_ranges,
288322
static_hyperparameters, image, input_mode, metric_definitions,
@@ -613,6 +647,23 @@ def wait_for_job(self, job, poll=5):
613647
self._check_job_status(job, desc, 'TrainingJobStatus')
614648
return desc
615649

650+
def wait_for_compilation_job(self, job, poll=5):
    """Block until an Amazon SageMaker Neo compilation job is no longer in progress.

    Args:
        job (str): Name of the compilation job to wait for.
        poll (int): Polling interval in seconds (default: 5).

    Returns:
        (dict): Return value from the ``DescribeCompilationJob`` API.

    Raises:
        ValueError: If the compilation job fails (presumably raised by ``_check_job_status``).
    """
    def _poll_status():
        return _compilation_job_status(self.sagemaker_client, job)

    description = _wait_until(_poll_status, poll)
    self._check_job_status(job, description, 'CompilationJobStatus')
    return description
666+
616667
def wait_for_tuning_job(self, job, poll=5):
617668
"""Wait for an Amazon SageMaker hyperparameter tuning job to complete.
618669
@@ -1164,6 +1215,28 @@ def _train_done(sagemaker_client, job_name, last_desc):
11641215
return desc, True
11651216

11661217

1218+
def _compilation_job_status(sagemaker_client, job_name):
1219+
compile_status_codes = {
1220+
'Completed': '!',
1221+
'InProgress': '.',
1222+
'Failed': '*',
1223+
'Stopped': 's',
1224+
'Stopping': '_'
1225+
}
1226+
in_progress_statuses = ['InProgress', 'Stopping', 'Starting']
1227+
1228+
desc = sagemaker_client.describe_compilation_job(CompilationJobName=job_name)
1229+
status = desc['CompilationJobStatus']
1230+
1231+
print(compile_status_codes.get(status, '?'), end='')
1232+
sys.stdout.flush()
1233+
1234+
if status in in_progress_statuses:
1235+
return None
1236+
1237+
return desc
1238+
1239+
11671240
def _tuning_job_status(sagemaker_client, job_name):
11681241
tuning_status_codes = {
11691242
'Completed': '!',

0 commit comments

Comments
 (0)