diff --git a/src/sagemaker/serve/model_server/multi_model_server/inference.py b/src/sagemaker/serve/model_server/multi_model_server/inference.py
index 1d2440f5f9..908ffcc7aa 100644
--- a/src/sagemaker/serve/model_server/multi_model_server/inference.py
+++ b/src/sagemaker/serve/model_server/multi_model_server/inference.py
@@ -21,7 +21,7 @@
 METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")
 
 
-def model_fn(model_dir):
+def model_fn(model_dir, context=None):
     """Overrides default method for loading a model"""
     shared_libs_path = Path(model_dir + "/shared_libs")
 
@@ -40,7 +40,7 @@ def model_fn(model_dir):
     return partial(inference_spec.invoke, model=inference_spec.load(model_dir))
 
 
-def input_fn(input_data, content_type):
+def input_fn(input_data, content_type, context=None):
     """Deserializes the bytes that were received from the model server"""
     try:
         if hasattr(schema_builder, "custom_input_translator"):
@@ -72,12 +72,12 @@ def input_fn(input_data, content_type):
         raise Exception("Encountered error in deserialize_request.") from e
 
 
-def predict_fn(input_data, predict_callable):
+def predict_fn(input_data, predict_callable, context=None):
     """Invokes the model that is taken in by model server"""
     return predict_callable(input_data)
 
 
-def output_fn(predictions, accept_type):
+def output_fn(predictions, accept_type, context=None):
     """Prediction is serialized to bytes and sent back to the customer"""
     try:
         if hasattr(inference_spec, "postprocess"):
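
The only behavioral change above is the extra optional context parameter on each handler, so a context-aware model server can pass its request/system context object while existing two-argument callers keep working. Below is a minimal, self-contained sketch of that calling convention; it is not part of the diff, and run_request and FakeContext are hypothetical stand-ins for whatever context object the real model server supplies.

# Sketch only: illustrates why ``context=None`` keeps both call styles valid.
# ``run_request`` and ``FakeContext`` are hypothetical, not SageMaker APIs.

def input_fn(input_data, content_type, context=None):
    # Handlers may ignore ``context`` entirely, or inspect it when provided.
    if context is not None:
        print("system properties:", getattr(context, "system_properties", {}))
    return input_data


def run_request(payload, content_type, context=None):
    # Legacy callers pass two arguments; newer callers add a context object.
    return input_fn(payload, content_type, context)


if __name__ == "__main__":
    class FakeContext:
        system_properties = {"gpu_id": 0}

    run_request(b"{}", "application/json")                  # legacy two-argument call
    run_request(b"{}", "application/json", FakeContext())   # context-aware call

The same pattern applies to model_fn, predict_fn, and output_fn: because the new parameter defaults to None, no existing caller has to change.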