@@ -104,27 +104,32 @@ def train(
104104 save (model_dir , mlp_model )
105105
106106
107- def neo_preprocess ( payload , content_type ):
108- logging . info ( "Invoking user-defined pre-processing function" )
107+ def model_fn ( path_to_model_files ):
108+ import neomxnet # noqa: F401
109109
110- if content_type != "application/vnd+python.numpy+binary" :
111- raise RuntimeError ("Content type must be application/vnd+python.numpy+binary" )
110+ ctx = mx .cpu ()
111+ sym , arg_params , aux_params = mx .model .load_checkpoint (
112+ os .path .join (path_to_model_files , "compiled" ), 0
113+ )
114+ mod = mx .mod .Module (symbol = sym , context = ctx , label_names = None )
115+ mod .bind (
116+ for_training = False , data_shapes = [("data" , (1 , 1 , 28 , 28 ))], label_shapes = mod ._label_shapes
117+ )
118+ mod .set_params (arg_params , aux_params , allow_missing = True )
119+ return mod
112120
113- return np .asarray (json .loads (payload .decode ("utf-8" )))
114121
122+ def transform_fn (mod , payload , input_content_type , requested_output_content_type ):
123+ import neomxnet # noqa: F401
115124
116- # NOTE: this function cannot use MXNet
117- def neo_postprocess (result ):
118- logging .info ("Invoking user-defined post-processing function" )
125+ if input_content_type != "application/vnd+python.numpy+binary" :
126+ raise RuntimeError ("Input content type must be application/vnd+python.numpy+binary" )
119127
120- # Softmax (assumes batch size 1)
128+ inference_payload = np .asarray (json .loads (payload .decode ("utf-8" )))
129+ result = mod .predict (inference_payload )
121130 result = np .squeeze (result )
122- result_exp = np .exp (result - np .max (result ))
123- result = result_exp / np .sum (result_exp )
124-
125- response_body = json .dumps (result .tolist ())
131+ response_body = json .dumps (result .asnumpy ().tolist ())
126132 content_type = "application/json"
127-
128133 return response_body , content_type
129134
130135
@@ -135,7 +140,7 @@ def neo_postprocess(result):
135140 parser = argparse .ArgumentParser ()
136141
137142 parser .add_argument ("--batch-size" , type = int , default = 100 )
138- parser .add_argument ("--epochs" , type = int , default = 10 )
143+ parser .add_argument ("--epochs" , type = int , default = 1 )
139144 parser .add_argument ("--learning-rate" , type = float , default = 0.1 )
140145
141146 parser .add_argument ("--model-dir" , type = str , default = os .environ ["SM_MODEL_DIR" ])
0 commit comments