2121METADATA_PATH = Path (__file__ ).parent .joinpath ("metadata.json" )
2222
2323
24- def model_fn (model_dir ):
24+ def model_fn (model_dir , context = None ):
2525 """Overrides default method for loading a model"""
2626 shared_libs_path = Path (model_dir + "/shared_libs" )
2727
@@ -40,7 +40,7 @@ def model_fn(model_dir):
4040 return partial (inference_spec .invoke , model = inference_spec .load (model_dir ))
4141
4242
43- def input_fn (input_data , content_type ):
43+ def input_fn (input_data , content_type , context = None ):
4444 """Deserializes the bytes that were received from the model server"""
4545 try :
4646 if hasattr (schema_builder , "custom_input_translator" ):
@@ -72,12 +72,12 @@ def input_fn(input_data, content_type):
7272 raise Exception ("Encountered error in deserialize_request." ) from e
7373
7474
75- def predict_fn (input_data , predict_callable ):
75+ def predict_fn (input_data , predict_callable , context = None ):
7676 """Invokes the model that is taken in by model server"""
7777 return predict_callable (input_data )
7878
7979
80- def output_fn (predictions , accept_type ):
80+ def output_fn (predictions , accept_type , context = None ):
8181 """Prediction is serialized to bytes and sent back to the customer"""
8282 try :
8383 if hasattr (inference_spec , "postprocess" ):
0 commit comments