 import json
 import logging
 import os
+import io
 import struct
-
 import mxnet as mx
 import numpy as np
-
+from collections import namedtuple

 def load_data(path):
     with gzip.open(find_file(path, "labels.gz")) as flbl:
@@ -107,42 +107,59 @@ def parse_args():

     return parser.parse_args()

-### NOTE: this function cannot use MXNet
-def neo_preprocess(payload, content_type):
-    import logging
-    import numpy as np
-    import io
-
-    logging.info('Invoking user-defined pre-processing function')
-
-    if content_type != 'application/vnd+python.numpy+binary':
-        raise RuntimeError('Content type must be application/vnd+python.numpy+binary')
-
-    f = io.BytesIO(payload)
-    return np.load(f)
-
-### NOTE: this function cannot use MXNet
-def neo_postprocess(result):
-    import logging
-    import numpy as np
-    import json
-
-    logging.info('Invoking user-defined post-processing function')
-
-    # Softmax (assumes batch size 1)
+### NOTE: model_fn and transform_fn are used to load the model and serve inference
+def model_fn(model_dir):
+    import neomxnet  # noqa: F401
+
+    logging.info('Invoking user-defined model_fn')
+
+    # change context to mx.gpu() when optimizing and deploying with Neo for GPU endpoints
+    ctx = mx.cpu()
+
+    Batch = namedtuple('Batch', ['data'])
+    sym, arg_params, aux_params = mx.model.load_checkpoint(os.path.join(model_dir, 'compiled'), 0)
+    mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
+    exe = mod.bind(for_training=False,
+                   data_shapes=[('data', (1, 784))],
+                   label_shapes=mod._label_shapes)
+    mod.set_params(arg_params, aux_params, allow_missing=True)
+    # run warm-up inference on empty data
+    data = mx.nd.empty((1, 784), ctx=ctx)
+    mod.forward(Batch([data]))
+    return mod
+
+def transform_fn(mod, payload, input_content_type, output_content_type):
+
+    logging.info('Invoking user-defined transform_fn')
+    Batch = namedtuple('Batch', ['data'])
+
+    # change context to mx.gpu() when optimizing and deploying with Neo for GPU endpoints
+    ctx = mx.cpu()
+
+    if input_content_type != 'application/x-npy':
+        raise RuntimeError('Input content type must be application/x-npy')
+
+    # pre-processing
+    io_bytes_obj = io.BytesIO(payload)
+    npy_payload = np.load(io_bytes_obj)
+    mx_ndarray = mx.nd.array(npy_payload)
+    inference_payload = mx_ndarray.as_in_context(ctx)
+
+    # prediction/inference
+    mod.forward(Batch([inference_payload]))
+
+    # post-processing
+    result = mod.get_outputs()[0].asnumpy()
     result = np.squeeze(result)
     result_exp = np.exp(result - np.max(result))
     result = result_exp / np.sum(result_exp)
-
-    response_body = json.dumps(result.tolist())
-    content_type = 'application/json'
-
-    return response_body, content_type
-
+    output_json = json.dumps(result.tolist())
+    output_content_type = 'application/json'
+    return output_json, output_content_type

 if __name__ == '__main__':
     args = parse_args()
     num_gpus = int(os.environ['SM_NUM_GPUS'])
-
-    train(args.batch_size, args.epochs, args.learning_rate, num_gpus, args.train, args.test,
-          args.hosts, args.current_host, args.model_dir)
+    train(args.batch_size, args.epochs, args.learning_rate,
+          num_gpus, args.train, args.test, args.hosts,
+          args.current_host, args.model_dir)
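
As a sanity check, the new handlers can be exercised locally before deploying the Neo-compiled model. The sketch below is not part of the change: the module name `mnist` and the local `./model` directory holding the compiled checkpoint (`compiled-symbol.json`, `compiled-0000.params`) are assumptions for illustration; only the handler signatures and the `application/x-npy` payload format come from the diff above.

    # local smoke test (hypothetical module name and model directory)
    import io
    import json

    import numpy as np

    import mnist  # assumed: the training script above, importable as a module

    # expects compiled-symbol.json / compiled-0000.params under ./model
    mod = mnist.model_fn('./model')

    # serialize a dummy 1x784 input the way an application/x-npy client would
    buf = io.BytesIO()
    np.save(buf, np.zeros((1, 784), dtype=np.float32))

    body, content_type = mnist.transform_fn(
        mod, buf.getvalue(), 'application/x-npy', 'application/json')
    print(content_type, json.loads(body))  # JSON list of class probabilities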