21
21
METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")
22
22
23
23
24
- def model_fn(model_dir):
24
+ def model_fn(model_dir, context=None ):
25
25
"""Overrides default method for loading a model"""
26
26
shared_libs_path = Path(model_dir + "/shared_libs")
27
27
@@ -40,7 +40,7 @@ def model_fn(model_dir):
40
40
return partial(inference_spec.invoke, model=inference_spec.load(model_dir))
41
41
42
42
43
- def input_fn(input_data, content_type):
43
+ def input_fn(input_data, content_type, context=None ):
44
44
"""Deserializes the bytes that were received from the model server"""
45
45
try:
46
46
if hasattr(schema_builder, "custom_input_translator"):
@@ -72,12 +72,12 @@ def input_fn(input_data, content_type):
72
72
raise Exception("Encountered error in deserialize_request.") from e
73
73
74
74
75
- def predict_fn(input_data, predict_callable):
75
+ def predict_fn(input_data, predict_callable, context=None ):
76
76
"""Invokes the model that is taken in by model server"""
77
77
return predict_callable(input_data)
78
78
79
79
80
- def output_fn(predictions, accept_type):
80
+ def output_fn(predictions, accept_type, context=None ):
81
81
"""Prediction is serialized to bytes and sent back to the customer"""
82
82
try:
83
83
if hasattr(inference_spec, "postprocess"):
0 commit comments