Skip to content

Commit 78c4e9c

Browse files
authored
Skip input_schema for hftransformersv2 back compat (#88)
* Skip input_schema for hftransformersv2 back compat Signed-off-by: Walter Martin <[email protected]> * lint Signed-off-by: Walter Martin <[email protected]> * only check for hftransformers if input arg is dict Signed-off-by: Walter Martin <[email protected]> --------- Signed-off-by: Walter Martin <[email protected]>
1 parent 9f5ace8 commit 78c4e9c

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

inference_schema/schema_decorators.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,13 @@ def input_schema(param_name, param_type, convert_to_provided_type=True, optional
4343

4444
@_schema_decorator(attr_name=INPUT_SCHEMA_ATTR, schema=swagger_schema, supported_versions=supported_versions)
4545
def decorator_input(user_run, instance, args, kwargs):
46-
if convert_to_provided_type:
46+
is_hftransformersv2 = False
47+
if len(args) > 0 and type(args[0]) is dict:
48+
args_keys = args[0].keys()
49+
is_hftransformersv2 = len(args_keys) == 2 and "parameters" in args_keys and "input_string" in args_keys
50+
# skip all of this for hftransformersv2
51+
if convert_to_provided_type and not is_hftransformersv2:
4752
args = list(args)
48-
4953
if param_name not in kwargs.keys() and not optional:
5054
decorators = _get_decorators(user_run)
5155
arg_names = inspect.getfullargspec(decorators[-1]).args

tests/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,19 @@ def standard_py_func(param):
266266
return standard_py_func
267267

268268

269+
@pytest.fixture(scope="session")
270+
def decorated_standard_func_parameters(standard_sample_input, sample_param_dict):
271+
@input_schema('input_data', StandardPythonParameterType(standard_sample_input))
272+
@input_schema('params', StandardPythonParameterType(sample_param_dict), optional=False)
273+
def standard_params_func(input_data, params=None):
274+
if params is not None:
275+
assert type(params) is dict
276+
beams = params['num_beams'] if params is not None else 0
277+
return input_data["input_string"], beams
278+
279+
return standard_params_func
280+
281+
269282
@pytest.fixture(scope="session")
270283
def standard_sample_input_multitype_list():
271284
return ['foo', 1]

tests/test_standard_parameter_type.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,15 @@ def test_float_int_handling(self, decorated_float_func):
5252
int_input = 1
5353
result = decorated_float_func(int_input)
5454
assert int_input == result
55+
56+
def test_standard_params_handling_hftransformersv2(self, decorated_standard_func_parameters):
57+
input_data = {
58+
"input_string": ["the meaning of life is"],
59+
"parameters": {
60+
"num_beams": 2,
61+
"max_length": 512
62+
}
63+
}
64+
result = decorated_standard_func_parameters(input_data)
65+
assert result[0][0] == "the meaning of life is"
66+
assert result[1] == 0

0 commit comments

Comments
 (0)