12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
14
15
+ import json
15
16
import logging
16
17
17
18
import sagemaker
18
19
from sagemaker import fw_utils , local , session , utils
19
20
21
# Target device families listed as allowed for ``Model.compile(target_instance_family=...)``.
NEO_ALLOWED_TARGET_INSTANCE_FAMILY = {
    'ml_c5',
    'ml_m5',
    'ml_c4',
    'ml_m4',
    'jetson_tx1',
    'jetson_tx2',
    'ml_p2',
    'ml_p3',
    'deeplens',
    'rasp3b',
}

# Lower-case framework names that ``Model.compile`` accepts.
NEO_ALLOWED_FRAMEWORKS = {
    'mxnet',
    'tensorflow',
    'pytorch',
    'onnx',
    'xgboost',
}

# Region name -> account value handed to ``fw_utils.create_image_uri(account=...)`` when the Neo image URI is
# built. A region missing from this mapping is reported as unsupported by ``Model._neo_image_account``.
NEO_IMAGE_ACCOUNT = {
    'us-west-2': '301217895009',
    'us-east-1': '785573368785',
    'eu-west-1': '802834080501',
    'us-east-2': '007439368137',
}
31
+
20
32
21
33
class Model (object ):
22
34
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
@@ -53,6 +65,7 @@ def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, n
53
65
self .vpc_config = vpc_config
54
66
self .sagemaker_session = sagemaker_session
55
67
self ._model_name = None
68
+ self ._is_compiled_model = False
56
69
57
70
def prepare_container_def (self , instance_type ): # pylint: disable=unused-argument
58
71
"""Return a dict created by ``sagemaker.container_def()`` for deploying this model to a specified instance type.
@@ -68,6 +81,93 @@ def prepare_container_def(self, instance_type): # pylint: disable=unused-argume
68
81
"""
69
82
return sagemaker .container_def (self .image , self .model_data , self .env )
70
83
84
+ def _framework (self ):
85
+ return getattr (self , '__framework_name__' , None )
86
+
87
+ def _get_framework_version (self ):
88
+ return getattr (self , 'framework_version' , None )
89
+
90
+ def _compilation_job_config (self , target_instance_type , input_shape , output_path , role , compile_max_run ,
91
+ job_name , framework , tags ):
92
+ input_model_config = {
93
+ 'S3Uri' : self .model_data ,
94
+ 'DataInputConfig' : input_shape if type (input_shape ) != dict else json .dumps (input_shape ),
95
+ 'Framework' : framework
96
+ }
97
+ role = self .sagemaker_session .expand_role (role )
98
+ output_model_config = {
99
+ 'TargetDevice' : target_instance_type ,
100
+ 'S3OutputLocation' : output_path
101
+ }
102
+
103
+ return {'input_model_config' : input_model_config ,
104
+ 'output_model_config' : output_model_config ,
105
+ 'role' : role ,
106
+ 'stop_condition' : {
107
+ 'MaxRuntimeInSeconds' : compile_max_run
108
+ },
109
+ 'tags' : tags ,
110
+ 'job_name' : job_name }
111
+
112
def _neo_image_account(self, region):
    """Look up the ``NEO_IMAGE_ACCOUNT`` entry for ``region``.

    Args:
        region (str): Region name used as the lookup key.

    Returns:
        str: The value stored in ``NEO_IMAGE_ACCOUNT`` for ``region``.

    Raises:
        ValueError: If ``region`` is not a key of ``NEO_IMAGE_ACCOUNT``.
    """
    if region in NEO_IMAGE_ACCOUNT:
        return NEO_IMAGE_ACCOUNT[region]

    supported_regions = NEO_IMAGE_ACCOUNT.keys()
    raise ValueError("Neo is not currently supported in {}, valid regions: {}".format(region, supported_regions))
117
+
118
def _neo_image(self, region, target_instance_type, framework, framework_version):
    """Build the image URI for a Neo-compiled model via ``fw_utils.create_image_uri``.

    Args:
        region (str): Region name; also used to look up the image account.
        target_instance_type (str): Compilation target; underscores are replaced with dots before it is
            handed to ``create_image_uri``.
        framework (str): Framework name; lower-cased and prefixed with ``'neo-'``.
        framework_version (str): Passed through unchanged.

    Returns:
        Whatever ``fw_utils.create_image_uri`` returns for these arguments (``py_version`` is fixed to
        ``'py3'``).

    Raises:
        ValueError: If ``region`` has no entry in ``NEO_IMAGE_ACCOUNT`` (raised by ``_neo_image_account``).
    """
    neo_framework = 'neo-' + framework.lower()
    dotted_instance_type = target_instance_type.replace('_', '.')
    image_account = self._neo_image_account(region)
    return fw_utils.create_image_uri(
        region,
        neo_framework,
        dotted_instance_type,
        framework_version,
        py_version='py3',
        account=image_account,
    )
125
+
126
def compile(self, target_instance_family, input_shape, output_path, role,
            tags=None, job_name=None, compile_max_run=5 * 60, framework=None, framework_version=None):
    """Compile this ``Model`` with SageMaker Neo.

    Args:
        target_instance_family (str): Identifies the device that you want to run your model after compilation,
            for example: ml_c5. Allowed strings are: ml_c5, ml_m5, ml_c4, ml_m4, jetson_tx1, jetson_tx2, ml_p2,
            ml_p3, deeplens, rasp3b
        input_shape (dict): Specifies the name and shape of the expected inputs for your trained model in json
            dictionary form, for example: {'data': [1,3,1024,1024]}, or
            {'var1': [1,1,28,28], 'var2': [1,1,28,28]}
        output_path (str): Specifies where to store the compiled model
        role (str): Execution role
        tags (list[dict]): List of tags for labeling a compilation job. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
        job_name (str): The name of the compilation job. Required, despite the ``None`` default.
        compile_max_run (int): Timeout in seconds for compilation (default: 5 * 60).
            After this amount of time Amazon SageMaker Neo terminates the compilation job regardless of its
            current status.
        framework (str): The framework that is used to train the original model. Allowed values
            (case-insensitive): 'mxnet', 'tensorflow', 'pytorch', 'onnx', 'xgboost'. Only used when this
            object does not provide a ``__framework_name__`` attribute.
        framework_version (str): Framework version used to pick the Neo image. Only used when this object
            does not provide a (truthy) ``framework_version`` attribute.

    Returns:
        sagemaker.model.Model: This same ``Model`` object, with ``model_data`` and ``image`` replaced by the
            compiled artifacts and the Neo image. See :func:`~sagemaker.model.Model` for full details.

    Raises:
        ValueError: If no framework can be determined, the framework is not one of the allowed values, no
            ``job_name`` is given, or Neo is not supported in the session's region.
    """
    framework = self._framework() or framework
    if framework is None:
        raise ValueError("You must specify framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS))
    # Compare case-insensitively: the value is upper-cased below anyway, so 'MXNET' is as valid as 'mxnet'.
    if framework.lower() not in NEO_ALLOWED_FRAMEWORKS:
        raise ValueError("You must provide valid framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS))
    if job_name is None:
        raise ValueError("You must provide a compilation job name")

    framework = framework.upper()
    framework_version = self._get_framework_version() or framework_version

    config = self._compilation_job_config(target_instance_family, input_shape, output_path, role,
                                          compile_max_run, job_name, framework, tags)
    self.sagemaker_session.compile_model(**config)
    job_status = self.sagemaker_session.wait_for_compilation_job(job_name)

    # Resolve both new values before touching any attribute, so that a failure while building the image URI
    # (e.g. unsupported region) does not leave this model pointing at compiled artifacts with the old image.
    compiled_model_data = job_status['ModelArtifacts']['S3ModelArtifacts']
    neo_image = self._neo_image(self.sagemaker_session.boto_region_name, target_instance_family, framework,
                                framework_version)

    self.model_data = compiled_model_data
    self.image = neo_image
    self._is_compiled_model = True
    return self
170
+
71
171
def deploy (self , initial_instance_count , instance_type , endpoint_name = None , tags = None ):
72
172
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
73
173
@@ -98,13 +198,21 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags
98
198
else :
99
199
self .sagemaker_session = session .Session ()
100
200
201
+ compiled_model_suffix = '-' .join (instance_type .split ('.' )[:- 1 ])
101
202
container_def = self .prepare_container_def (instance_type )
102
203
self .name = self .name or utils .name_from_image (container_def ['Image' ])
103
204
if self .role is None :
104
205
raise ValueError ("Role can not be null for deploying a model" )
206
+ if self ._is_compiled_model :
207
+ self .name += compiled_model_suffix
105
208
self .sagemaker_session .create_model (self .name , self .role , container_def , vpc_config = self .vpc_config )
106
209
production_variant = sagemaker .production_variant (self .name , instance_type , initial_instance_count )
107
- self .endpoint_name = endpoint_name or self .name
210
+ if endpoint_name :
211
+ self .endpoint_name = endpoint_name
212
+ else :
213
+ self .endpoint_name = self .name
214
+ if self ._is_compiled_model and not self .endpoint_name .endswith (compiled_model_suffix ):
215
+ self .endpoint_name += compiled_model_suffix
108
216
self .sagemaker_session .endpoint_from_production_variants (self .endpoint_name , [production_variant ], tags )
109
217
if self .predictor_cls :
110
218
return self .predictor_cls (self .endpoint_name , self .sagemaker_session )
0 commit comments