Skip to content

Commit a58920c

Browse files
prats13bag, EC2 Default User, EC2 Default User
authored
updating mxnet_mnist notebook (#1588)
* updating mxnet_mnist notebook
* typo fix
* refactoring
* refactored mnist.py
* updated bucket paths in the notebook for better organization
* notebook updated to handle sdk upgrade

Co-authored-by: EC2 Default User <[email protected]>
Co-authored-by: EC2 Default User <[email protected]>
1 parent 6f86c68 commit a58920c

File tree

3 files changed

+114
-119
lines changed

3 files changed

+114
-119
lines changed

sagemaker_neo_compilation_jobs/mxnet_mnist/mnist.py

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import json
44
import logging
55
import os
6+
import io
67
import struct
7-
88
import mxnet as mx
99
import numpy as np
10-
10+
from collections import namedtuple
1111

1212
def load_data(path):
1313
with gzip.open(find_file(path, "labels.gz")) as flbl:
@@ -107,42 +107,59 @@ def parse_args():
107107

108108
return parser.parse_args()
109109

110-
### NOTE: this function cannot use MXNet
111-
def neo_preprocess(payload, content_type):
112-
import logging
113-
import numpy as np
114-
import io
115-
116-
logging.info('Invoking user-defined pre-processing function')
117-
118-
if content_type != 'application/vnd+python.numpy+binary':
119-
raise RuntimeError('Content type must be application/vnd+python.numpy+binary')
120-
121-
f = io.BytesIO(payload)
122-
return np.load(f)
123-
124-
### NOTE: this function cannot use MXNet
125-
def neo_postprocess(result):
126-
import logging
127-
import numpy as np
128-
import json
129-
130-
logging.info('Invoking user-defined post-processing function')
131-
132-
# Softmax (assumes batch size 1)
110+
### NOTE: model_fn and transform_fn are used to load the model and serve inference
111+
def model_fn(model_dir):
112+
import neomxnet # noqa: F401
113+
114+
logging.info('Invoking user-defined model_fn')
115+
116+
# change context to mx.gpu() when optimizing and deploying with Neo for GPU endpoints
117+
ctx = mx.cpu()
118+
119+
Batch = namedtuple('Batch', ['data'])
120+
sym, arg_params, aux_params = mx.model.load_checkpoint(os.path.join(model_dir, 'compiled'), 0)
121+
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
122+
exe = mod.bind(for_training=False,
123+
data_shapes=[('data', (1,784))],
124+
label_shapes=mod._label_shapes)
125+
mod.set_params(arg_params, aux_params, allow_missing=True)
126+
# run warm-up inference on empty data
127+
data = mx.nd.empty((1,784), ctx=ctx)
128+
mod.forward(Batch([data]))
129+
return mod
130+
131+
def transform_fn(mod, payload, input_content_type, output_content_type):
132+
133+
logging.info('Invoking user-defined transform_fn')
134+
Batch = namedtuple('Batch', ['data'])
135+
136+
# change context to mx.gpu() when optimizing and deploying with Neo for GPU endpoints
137+
ctx = mx.cpu()
138+
139+
if input_content_type != 'application/x-npy':
140+
raise RuntimeError('Input content type must be application/x-npy')
141+
142+
# pre-processing
143+
io_bytes_obj = io.BytesIO(payload)
144+
npy_payload = np.load(io_bytes_obj)
145+
mx_ndarray = mx.nd.array(npy_payload)
146+
inference_payload = mx_ndarray.as_in_context(ctx)
147+
148+
# prediction/inference
149+
mod.forward(Batch([inference_payload]))
150+
151+
# post-processing
152+
result = mod.get_outputs()[0].asnumpy()
133153
result = np.squeeze(result)
134154
result_exp = np.exp(result - np.max(result))
135155
result = result_exp / np.sum(result_exp)
136-
137-
response_body = json.dumps(result.tolist())
138-
content_type = 'application/json'
139-
140-
return response_body, content_type
141-
156+
output_json = json.dumps(result.tolist())
157+
output_content_type = 'application/json'
158+
return output_json, output_content_type
142159

143160
if __name__ == '__main__':
144161
args = parse_args()
145162
num_gpus = int(os.environ['SM_NUM_GPUS'])
146-
147-
train(args.batch_size, args.epochs, args.learning_rate, num_gpus, args.train, args.test,
148-
args.hosts, args.current_host, args.model_dir)
163+
train(args.batch_size, args.epochs, args.learning_rate,
164+
num_gpus, args.train, args.test, args.hosts,
165+
args.current_host, args.model_dir)
Binary file not shown.

0 commit comments

Comments
 (0)