From d9dd3e761f0f3a76c5f15a9d1ddc4d580df37ea8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?=
Date: Thu, 5 Dec 2024 23:22:21 +0000
Subject: [PATCH] updated inference script to cover context

---
 .../serve/model_server/multi_model_server/inference.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/sagemaker/serve/model_server/multi_model_server/inference.py b/src/sagemaker/serve/model_server/multi_model_server/inference.py
index 595b9d9c39..7346209c4a 100644
--- a/src/sagemaker/serve/model_server/multi_model_server/inference.py
+++ b/src/sagemaker/serve/model_server/multi_model_server/inference.py
@@ -21,7 +21,7 @@
 METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")
 
 
-def model_fn(model_dir):
+def model_fn(model_dir, context=None):
     """Overrides default method for loading a model"""
     shared_libs_path = Path(model_dir + "/shared_libs")
 
@@ -40,7 +40,7 @@ def model_fn(model_dir):
     return partial(inference_spec.invoke, model=inference_spec.load(model_dir))
 
 
-def input_fn(input_data, content_type):
+def input_fn(input_data, content_type, context=None):
     """Deserializes the bytes that were received from the model server"""
     try:
         if hasattr(schema_builder, "custom_input_translator"):
@@ -62,12 +62,12 @@ def input_fn(input_data, content_type):
         raise Exception("Encountered error in deserialize_request.") from e
 
 
-def predict_fn(input_data, predict_callable):
+def predict_fn(input_data, predict_callable, context=None):
     """Invokes the model that is taken in by model server"""
     return predict_callable(input_data)
 
 
-def output_fn(predictions, accept_type):
+def output_fn(predictions, accept_type, context=None):
     """Prediction is serialized to bytes and sent back to the customer"""
     try:
         if hasattr(inference_spec, "postprocess"):