
Commit 09dd417

Document data encoding / decoding in inference.md (#884)
Custom endpoints need to encode and decode the data passed to the callback functions. This is a little unclear in the current documentation, and one may assume the data is already in `dict` format (like I did). I therefore updated the examples to make the encoding and decoding of the JSON string arguments explicit.
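To illustrate the point the commit makes, here is a minimal sketch (plain `json`, independent of the toolkit; the payload value is invented for illustration) of why explicit decoding and encoding are needed:

```python
import json

# The endpoint hands the callback the request body as a JSON string, not a dict.
payload = '{"inputs": "I love this movie"}'

# Indexing the raw string as a dict would fail; decode it first.
data = json.loads(payload)
print(data["inputs"])  # -> I love this movie

# Likewise, the response must be serialized back to a string before returning.
response = json.dumps({"label": "POSITIVE", "score": 0.98})
print(response)
```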
1 parent ccaa63c commit 09dd417

File tree

1 file changed (+32 −7 lines)

docs/sagemaker/inference.md

Lines changed: 32 additions & 7 deletions
The changed hunk (`@@ -346,25 +346,50 @@`) replaces the placeholder bodies (`return "model"`, `return "data"`, and so on) with explicit decode/encode steps. After the change, the section reads:

Here is an example of a custom inference module with `model_fn`, `input_fn`, `predict_fn`, and `output_fn`:

```python
from sagemaker_huggingface_inference_toolkit import decoder_encoder

def model_fn(model_dir):
    # implement custom code to load the model
    loaded_model = ...

    return loaded_model

def input_fn(input_data, content_type):
    # decode the input data (e.g. JSON string -> dict)
    data = decoder_encoder.decode(input_data, content_type)
    return data

def predict_fn(data, model):
    # call your custom model with the data
    outputs = model(data, ...)
    return outputs

def output_fn(prediction, accept):
    # convert the model output to the desired output format (e.g. dict -> JSON string)
    response = decoder_encoder.encode(prediction, accept)
    return response
```

Customize your inference module with only `model_fn` and `transform_fn`:

```python
from sagemaker_huggingface_inference_toolkit import decoder_encoder

def model_fn(model_dir):
    # implement custom code to load the model
    loaded_model = ...

    return loaded_model

def transform_fn(model, input_data, content_type, accept):
    # decode the input data (e.g. JSON string -> dict)
    data = decoder_encoder.decode(input_data, content_type)

    # call your custom model with the data
    outputs = model(data, ...)

    # convert the model output to the desired output format (e.g. dict -> JSON string)
    response = decoder_encoder.encode(outputs, accept)

    return response
```
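To see how the pieces fit together end to end, here is a hedged, self-contained sketch of the `transform_fn` flow. Plain `json` stands in for `decoder_encoder`, and a trivial uppercasing function stands in for a real model; every name below is illustrative, not part of the toolkit:

```python
import json

# Stand-ins for decoder_encoder.decode / decoder_encoder.encode,
# assuming a JSON content type (illustrative only).
def decode(body, content_type):
    assert content_type == "application/json"
    return json.loads(body)

def encode(obj, accept):
    assert accept == "application/json"
    return json.dumps(obj)

def model_fn(model_dir):
    # A dummy "model" that uppercases its input.
    return lambda data: {"output": data["inputs"].upper()}

def transform_fn(model, input_data, content_type, accept):
    data = decode(input_data, content_type)   # JSON string -> dict
    outputs = model(data)                     # run the model
    return encode(outputs, accept)            # dict -> JSON string

model = model_fn("/opt/ml/model")
print(transform_fn(model, '{"inputs": "hello"}', "application/json", "application/json"))
# -> {"output": "HELLO"}
```

The same decode/predict/encode steps map one-to-one onto the separate `input_fn`, `predict_fn`, and `output_fn` handlers in the first example.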

0 commit comments
