|
11 | 11 | "cell_type": "markdown", |
12 | 12 | "metadata": {}, |
13 | 13 | "source": [ |
14 | | - "Amazon SageMaker Neo is API to compile machine learning models to optimize them for our choice of hardward targets. Currently, Neo supports pre-trained PyTorch models from [TorchVision](https://pytorch.org/docs/stable/torchvision/models.html). General support for other PyTorch models is forthcoming." |
| 14 | + "Amazon SageMaker Neo is an API to compile machine learning models to optimize them for our choice of hardware targets. Currently, Neo supports pre-trained PyTorch models from [TorchVision](https://pytorch.org/docs/stable/torchvision/models.html). General support for other PyTorch models is forthcoming." |
15 | 15 | ] |
16 | 16 | }, |
17 | 17 | { |
|
20 | 20 | "metadata": {}, |
21 | 21 | "outputs": [], |
22 | 22 | "source": [ |
23 | | - "!~/anaconda3/envs/pytorch_p36/bin/pip install torch==1.2.0 torchvision==0.4.0" |
| 23 | + "!~/anaconda3/envs/pytorch_p36/bin/pip install torch==1.4.0 torchvision==0.5.0" |
| 24 | + ] |
| 25 | + }, |
| 26 | + { |
| 27 | + "cell_type": "code", |
| 28 | + "execution_count": null, |
| 29 | + "metadata": {}, |
| 30 | + "outputs": [], |
| 31 | + "source": [ |
| 32 | + "!~/anaconda3/envs/pytorch_p36/bin/pip install --upgrade sagemaker" |
24 | 33 | ] |
25 | 34 | }, |
26 | 35 | { |
|
34 | 43 | "cell_type": "markdown", |
35 | 44 | "metadata": {}, |
36 | 45 | "source": [ |
37 | | - "We'll import [ResNet18](https://arxiv.org/abs/1512.03385) model from TorchVision and create a model artifact `model.tar.gz`:" |
| 46 | + "We'll import the [ResNet18](https://arxiv.org/abs/1512.03385) model from TorchVision and create a model artifact `model.tar.gz`." |
38 | 47 | ] |
39 | 48 | }, |
40 | 49 | { |
|
60 | 69 | "cell_type": "markdown", |
61 | 70 | "metadata": {}, |
62 | 71 | "source": [ |
63 | | - "## Invoke Neo Compilation API" |
64 | | - ] |
65 | | - }, |
66 | | - { |
67 | | - "cell_type": "markdown", |
68 | | - "metadata": {}, |
69 | | - "source": [ |
70 | | - "We then forward the model artifact to Neo Compilation API:" |
| 72 | + "### Upload the model archive to S3" |
71 | 73 | ] |
72 | 74 | }, |
73 | 75 | { |
|
87 | 89 | "bucket = sess.default_bucket()\n", |
88 | 90 | "\n", |
89 | 91 | "compilation_job_name = name_from_base('TorchVision-ResNet18-Neo')\n", |
| 92 | + "prefix = compilation_job_name+'/model'\n", |
90 | 93 | "\n", |
91 | | - "model_key = '{}/model/model.tar.gz'.format(compilation_job_name)\n", |
92 | | - "model_path = 's3://{}/{}'.format(bucket, model_key)\n", |
93 | | - "boto3.resource('s3').Bucket(bucket).upload_file('model.tar.gz', model_key)\n", |
| 94 | + "model_path = sess.upload_data(path='model.tar.gz', key_prefix=prefix)\n", |
94 | 95 | "\n", |
95 | | - "sm_client = boto3.client('sagemaker')\n", |
96 | 96 | "data_shape = '{\"input0\":[1,3,224,224]}'\n", |
97 | 97 | "target_device = 'ml_c5'\n", |
98 | 98 | "framework = 'PYTORCH'\n", |
99 | | - "framework_version = '1.2.0'\n", |
| 99 | + "framework_version = '1.4.0'\n", |
100 | 100 | "compiled_model_path = 's3://{}/{}/output'.format(bucket, compilation_job_name)" |
101 | 101 | ] |
102 | 102 | }, |
103 | | - { |
104 | | - "cell_type": "code", |
105 | | - "execution_count": null, |
106 | | - "metadata": {}, |
107 | | - "outputs": [], |
108 | | - "source": [ |
109 | | - "response = sm_client.create_compilation_job(\n", |
110 | | - " CompilationJobName=compilation_job_name,\n", |
111 | | - " RoleArn=role,\n", |
112 | | - " InputConfig={\n", |
113 | | - " 'S3Uri': model_path,\n", |
114 | | - " 'DataInputConfig': data_shape,\n", |
115 | | - " 'Framework': framework\n", |
116 | | - " },\n", |
117 | | - " OutputConfig={\n", |
118 | | - " 'S3OutputLocation': compiled_model_path,\n", |
119 | | - " 'TargetDevice': target_device\n", |
120 | | - " },\n", |
121 | | - " StoppingCondition={\n", |
122 | | - " 'MaxRuntimeInSeconds': 300\n", |
123 | | - " }\n", |
124 | | - ")\n", |
125 | | - "print(response)\n", |
126 | | - "\n", |
127 | | - "# Poll every 30 sec\n", |
128 | | - "while True:\n", |
129 | | - " response = sm_client.describe_compilation_job(CompilationJobName=compilation_job_name)\n", |
130 | | - " if response['CompilationJobStatus'] == 'COMPLETED':\n", |
131 | | - " break\n", |
132 | | - " elif response['CompilationJobStatus'] == 'FAILED':\n", |
133 | | - " raise RuntimeError('Compilation failed')\n", |
134 | | - " print('Compiling ...')\n", |
135 | | - " time.sleep(30)\n", |
136 | | - "print('Done!')\n", |
137 | | - "\n", |
138 | | - "# Extract compiled model artifact\n", |
139 | | - "compiled_model_path = response['ModelArtifacts']['S3ModelArtifacts']" |
140 | | - ] |
141 | | - }, |
142 | | - { |
143 | | - "cell_type": "markdown", |
144 | | - "metadata": {}, |
145 | | - "source": [ |
146 | | - "## Create prediction endpoint" |
147 | | - ] |
148 | | - }, |
149 | | - { |
150 | | - "cell_type": "markdown", |
151 | | - "metadata": {}, |
152 | | - "source": [ |
153 | | - "To create a prediction endpoint, we first specify two additional functions, to be used with Neo Deep Learning Runtime:\n", |
154 | | - "\n", |
155 | | - "* `neo_preprocess(payload, content_type)`: Function that takes in the payload and Content-Type of each incoming request and returns a NumPy array. Here, the payload is byte-encoded NumPy array, so the function simply decodes the bytes to obtain the NumPy array.\n", |
156 | | - "* `neo_postprocess(result)`: Function that takes the prediction results produced by Deep Learining Runtime and returns the response body" |
157 | | - ] |
158 | | - }, |
159 | | - { |
160 | | - "cell_type": "code", |
161 | | - "execution_count": null, |
162 | | - "metadata": {}, |
163 | | - "outputs": [], |
164 | | - "source": [ |
165 | | - "!pygmentize resnet18.py" |
166 | | - ] |
167 | | - }, |
168 | 103 | { |
169 | 104 | "cell_type": "markdown", |
170 | 105 | "metadata": {}, |
171 | 106 | "source": [ |
172 | | - "Upload the Python script containing the two functions to S3:" |
173 | | - ] |
174 | | - }, |
175 | | - { |
176 | | - "cell_type": "code", |
177 | | - "execution_count": null, |
178 | | - "metadata": {}, |
179 | | - "outputs": [], |
180 | | - "source": [ |
181 | | - "source_key = '{}/source/sourcedir.tar.gz'.format(compilation_job_name)\n", |
182 | | - "source_path = 's3://{}/{}'.format(bucket, source_key)\n", |
183 | | - "\n", |
184 | | - "with tarfile.open('sourcedir.tar.gz', 'w:gz') as f:\n", |
185 | | - " f.add('resnet18.py')\n", |
186 | | - "\n", |
187 | | - "boto3.resource('s3').Bucket(bucket).upload_file('sourcedir.tar.gz', source_key)" |
| 107 | + "## Invoke Neo Compilation API" |
188 | 108 | ] |
189 | 109 | }, |
190 | 110 | { |
191 | 111 | "cell_type": "markdown", |
192 | 112 | "metadata": {}, |
193 | 113 | "source": [ |
194 | | - "We then create a SageMaker model record:" |
| 114 | + "### Create a PyTorch SageMaker model" |
195 | 115 | ] |
196 | 116 | }, |
197 | 117 | { |
|
200 | 120 | "metadata": {}, |
201 | 121 | "outputs": [], |
202 | 122 | "source": [ |
203 | | - "from sagemaker.model import NEO_IMAGE_ACCOUNT\n", |
204 | | - "from sagemaker.fw_utils import create_image_uri\n", |
205 | | - "\n", |
206 | | - "model_name = name_from_base('TorchVision-ResNet18-Neo')\n", |
| 123 | + "from sagemaker.pytorch.model import PyTorchModel\n", |
| 124 | + "from sagemaker.predictor import Predictor\n", |
207 | 125 | "\n", |
208 | | - "image_uri = create_image_uri(region, 'neo-' + framework.lower(), target_device.replace('_', '.'),\n", |
209 | | - " framework_version, py_version='py3', account=NEO_IMAGE_ACCOUNT[region])\n", |
210 | | - "\n", |
211 | | - "response = sm_client.create_model(\n", |
212 | | - " ModelName=model_name,\n", |
213 | | - " PrimaryContainer={\n", |
214 | | - " 'Image': image_uri,\n", |
215 | | - " 'ModelDataUrl': compiled_model_path,\n", |
216 | | - " 'Environment': { 'SAGEMAKER_SUBMIT_DIRECTORY': source_path }\n", |
217 | | - " },\n", |
218 | | - " ExecutionRoleArn=role\n", |
219 | | - ")\n", |
220 | | - "print(response)" |
| 126 | + "sagemaker_model = PyTorchModel(model_data=model_path,\n", |
| 127 | + " predictor_cls=Predictor,\n", |
| 128 | + " framework_version = framework_version,\n", |
| 129 | + " role=role,\n", |
| 130 | + " sagemaker_session=sess,\n", |
| 131 | + " entry_point='resnet18.py',\n", |
| 132 | + " source_dir='code',\n", |
| 133 | + " py_version='py3',\n", |
| 134 | + " env={'MMS_DEFAULT_RESPONSE_TIMEOUT': '500'}\n", |
| 135 | + " )" |
221 | 136 | ] |
222 | 137 | }, |
223 | 138 | { |
224 | 139 | "cell_type": "markdown", |
225 | 140 | "metadata": {}, |
226 | 141 | "source": [ |
227 | | - "Then we create an Endpoint Configuration:" |
| 142 | + "### Use Neo compiler to compile the model" |
228 | 143 | ] |
229 | 144 | }, |
230 | 145 | { |
|
233 | 148 | "metadata": {}, |
234 | 149 | "outputs": [], |
235 | 150 | "source": [ |
236 | | - "config_name = model_name\n", |
237 | | - "\n", |
238 | | - "response = sm_client.create_endpoint_config(\n", |
239 | | - " EndpointConfigName=config_name,\n", |
240 | | - " ProductionVariants=[\n", |
241 | | - " {\n", |
242 | | - " 'VariantName': 'default-variant-name',\n", |
243 | | - " 'ModelName': model_name,\n", |
244 | | - " 'InitialInstanceCount': 1,\n", |
245 | | - " 'InstanceType': 'ml.c5.xlarge',\n", |
246 | | - " 'InitialVariantWeight': 1.0\n", |
247 | | - " },\n", |
248 | | - " ],\n", |
249 | | - ")\n", |
250 | | - "print(response)" |
| 151 | + "compiled_model = sagemaker_model.compile(target_instance_family=target_device, \n", |
| 152 | + " input_shape=data_shape,\n", |
| 153 | + " job_name=compilation_job_name,\n", |
| 154 | + " role=role,\n", |
| 155 | + " framework=framework.lower(),\n", |
| 156 | + " framework_version=framework_version,\n", |
| 157 | + " output_path=compiled_model_path\n", |
| 158 | + " )" |
251 | 159 | ] |
252 | 160 | }, |
253 | 161 | { |
254 | 162 | "cell_type": "markdown", |
255 | 163 | "metadata": {}, |
256 | 164 | "source": [ |
257 | | - "Finally, we create an Endpoint:" |
| 165 | + "## Deploy the model" |
258 | 166 | ] |
259 | 167 | }, |
260 | 168 | { |
|
263 | 171 | "metadata": {}, |
264 | 172 | "outputs": [], |
265 | 173 | "source": [ |
266 | | - "endpoint_name = model_name + '-Endpoint'\n", |
267 | | - "\n", |
268 | | - "response = sm_client.create_endpoint(\n", |
269 | | - " EndpointName=endpoint_name,\n", |
270 | | - " EndpointConfigName=config_name,\n", |
271 | | - ")\n", |
272 | | - "print(response)\n", |
273 | | - "\n", |
274 | | - "print('Creating endpoint ...')\n", |
275 | | - "waiter = sm_client.get_waiter('endpoint_in_service')\n", |
276 | | - "waiter.wait(EndpointName=endpoint_name)\n", |
277 | | - "\n", |
278 | | - "response = sm_client.describe_endpoint(EndpointName=endpoint_name)\n", |
279 | | - "print(response)" |
| 174 | + "predictor = compiled_model.deploy(initial_instance_count = 1,\n", |
| 175 | + " instance_type = 'ml.c5.9xlarge'\n", |
| 176 | + " )" |
280 | 177 | ] |
281 | 178 | }, |
282 | 179 | { |
|
301 | 198 | "metadata": {}, |
302 | 199 | "outputs": [], |
303 | 200 | "source": [ |
304 | | - "import json\n", |
305 | 201 | "import numpy as np\n", |
306 | | - "\n", |
307 | | - "sm_runtime = boto3.Session().client('sagemaker-runtime')\n", |
| 202 | + "import json\n", |
308 | 203 | "\n", |
309 | 204 | "with open('cat.jpg', 'rb') as f:\n", |
310 | 205 | " payload = f.read()\n", |
| 206 | + " payload = bytearray(payload) \n", |
311 | 207 | "\n", |
312 | | - "response = sm_runtime.invoke_endpoint(EndpointName=endpoint_name,\n", |
313 | | - " ContentType='application/x-image',\n", |
314 | | - " Body=payload)\n", |
315 | | - "print(response)\n", |
316 | | - "result = json.loads(response['Body'].read().decode())\n", |
| 208 | + "response = predictor.predict(payload)\n", |
| 209 | + "result = json.loads(response.decode())\n", |
317 | 210 | "print('Most likely class: {}'.format(np.argmax(result)))" |
318 | 211 | ] |
319 | 212 | }, |
|
346 | 239 | "metadata": {}, |
347 | 240 | "outputs": [], |
348 | 241 | "source": [ |
349 | | - "sess.delete_endpoint(endpoint_name)" |
| 242 | + "sess.delete_endpoint(predictor.endpoint_name)" |
350 | 243 | ] |
351 | 244 | } |
352 | 245 | ], |
|
366 | 259 | "name": "python", |
367 | 260 | "nbconvert_exporter": "python", |
368 | 261 | "pygments_lexer": "ipython3", |
369 | | - "version": "3.6.5" |
| 262 | + "version": "3.6.10" |
370 | 263 | } |
371 | 264 | }, |
372 | 265 | "nbformat": 4, |
373 | | - "nbformat_minor": 2 |
| 266 | + "nbformat_minor": 4 |
374 | 267 | } |
0 commit comments