 METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")


-def model_fn(model_dir):
+def model_fn(model_dir, context=None):
     """Overrides default method for loading a model"""
     shared_libs_path = Path(model_dir + "/shared_libs")

@@ -40,7 +40,7 @@ def model_fn(model_dir):
     return partial(inference_spec.invoke, model=inference_spec.load(model_dir))


-def input_fn(input_data, content_type):
+def input_fn(input_data, content_type, context=None):
     """Deserializes the bytes that were received from the model server"""
     try:
         if hasattr(schema_builder, "custom_input_translator"):
@@ -62,12 +62,12 @@ def input_fn(input_data, content_type):
         raise Exception("Encountered error in deserialize_request.") from e


-def predict_fn(input_data, predict_callable):
+def predict_fn(input_data, predict_callable, context=None):
     """Invokes the model that is taken in by model server"""
     return predict_callable(input_data)


-def output_fn(predictions, accept_type):
+def output_fn(predictions, accept_type, context=None):
     """Prediction is serialized to bytes and sent back to the customer"""
     try:
         if hasattr(inference_spec, "postprocess"):
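For context, a minimal sketch of how a model server might chain these handlers after this change; the names `model_dir`, `raw_bytes`, `content_type`, `accept_type`, and `ctx` are illustrative placeholders, not names from the diff. Because `context` defaults to `None`, callers that still use the old positional call pattern keep working, while serving stacks that carry a per-request context object can now pass it through every stage:

# Hypothetical request flow through the updated handlers.
# `ctx` stands in for whatever context object the serving stack provides;
# omitting it (or passing None) preserves the previous behavior.
predict_callable = model_fn(model_dir, context=ctx)             # load the model once at startup
data = input_fn(raw_bytes, content_type, context=ctx)           # deserialize the incoming request
predictions = predict_fn(data, predict_callable, context=ctx)   # run inference
response = output_fn(predictions, accept_type, context=ctx)     # serialize the response bytes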