2121METADATA_PATH = Path (__file__ ).parent .joinpath ("metadata.json" )
2222
2323
24- def model_fn (model_dir ):
24+ def model_fn (model_dir , context = None ):
2525 """Overrides default method for loading a model"""
2626 shared_libs_path = Path (model_dir + "/shared_libs" )
2727
@@ -40,7 +40,7 @@ def model_fn(model_dir):
4040 return partial (inference_spec .invoke , model = inference_spec .load (model_dir ))
4141
4242
43- def input_fn (input_data , content_type ):
43+ def input_fn (input_data , content_type , context = None ):
4444 """Deserializes the bytes that were received from the model server"""
4545 try :
4646 if hasattr (schema_builder , "custom_input_translator" ):
@@ -62,12 +62,12 @@ def input_fn(input_data, content_type):
6262 raise Exception ("Encountered error in deserialize_request." ) from e
6363
6464
65- def predict_fn (input_data , predict_callable ):
65+ def predict_fn (input_data , predict_callable , context = None ):
6666 """Invokes the model that is taken in by model server"""
6767 return predict_callable (input_data )
6868
6969
70- def output_fn (predictions , accept_type ):
70+ def output_fn (predictions , accept_type , context = None ):
7171 """Prediction is serialized to bytes and sent back to the customer"""
7272 try :
7373 if hasattr (inference_spec , "postprocess" ):
0 commit comments