Skip to content

Commit 1b61b1b

Browse files
authored
ignore params key in schema_decorators (#79)
* ignore params key in schema_decorators Signed-off-by: Walter Martin <[email protected]> * change input data fixture name Signed-off-by: Walter Martin <[email protected]> * nested parameter type saves the day Signed-off-by: Walter Martin <[email protected]> * nested schema for params, handle requests without params Signed-off-by: Walter Martin <[email protected]> * linting Signed-off-by: Walter Martin <[email protected]> * linting 2.0 Signed-off-by: Walter Martin <[email protected]> --------- Signed-off-by: Walter Martin <[email protected]>
1 parent 215fe33 commit 1b61b1b

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

inference_schema/schema_decorators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def _deserialize_input_argument(input_data, param_type, param_name):
327327
# parameters other than subclass of AbstractParameterType will not be handled
328328
for k, v in sample_data_type_map.items():
329329
if k not in input_data.keys():
330-
raise Exception('Invalid input. Expected: key "{0}" in "{1}"'.format(k, param_name))
330+
continue
331331
input_data[k] = _deserialize_input_argument(input_data[k], v, k)
332332
elif sample_data_type in (list, tuple):
333333
sample_data_type_list = param_type.sample_data_type_list

tests/conftest.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,26 @@ def pandas_sample_input_with_url():
6969
return pd.DataFrame(data=pandas_input_data)
7070

7171

72+
@pytest.fixture(scope="session")
73+
def pandas_sample_input_for_params():
74+
import json
75+
pandas_input_data = {
76+
"columns": [
77+
"sentence1"
78+
],
79+
"data": [
80+
["this is a string starting with"]
81+
],
82+
"index": [0]
83+
}
84+
return pd.read_json(json.dumps(pandas_input_data), orient='split')
85+
86+
87+
@pytest.fixture(scope="session")
88+
def sample_param_dict():
89+
return {"num_beams": 1, "max_length": 2}
90+
91+
7292
@pytest.fixture(scope="session")
7393
def decorated_pandas_func(pandas_sample_input, pandas_sample_output):
7494
@input_schema('param', PandasParameterType(pandas_sample_input))
@@ -164,6 +184,23 @@ def pandas_url_func(param):
164184
return pandas_url_func
165185

166186

187+
@pytest.fixture(scope="session")
188+
def decorated_pandas_func_parameters(pandas_sample_input_for_params, sample_param_dict):
189+
@input_schema('input_data', StandardPythonParameterType({
190+
'split_df': PandasParameterType(pandas_sample_input_for_params, orient='split'),
191+
'parameters': StandardPythonParameterType(sample_param_dict)
192+
}))
193+
def pandas_params_func(input_data):
194+
assert type(input_data) is dict
195+
assert type(input_data["split_df"]) is pd.DataFrame
196+
if 'parameters' in input_data:
197+
assert type(input_data["parameters"]) is dict
198+
beams = input_data['parameters']['num_beams'] if 'parameters' in input_data else 0
199+
return input_data["split_df"]["sentence1"], beams
200+
201+
return pandas_params_func
202+
203+
167204
@pytest.fixture(scope="session")
168205
def pandas_sample_input_with_categorical():
169206
pandas_input_data = {'state': ['characters'], 'cat': ['000']}

tests/test_pandas_parameter_type.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,42 @@ def test_pandas_categorical_handling(self, decorated_pandas_categorical_func):
7878
result = decorated_pandas_categorical_func(pandas_input)
7979
assert categorical == result
8080

81+
def test_pandas_params_handling(self, decorated_pandas_func_parameters):
82+
pandas_input_data = {"input_data": {
83+
"split_df": {
84+
"columns": [
85+
"sentence1"
86+
],
87+
"data": [
88+
["this is a string starting with"]
89+
],
90+
"index": [0]
91+
},
92+
"parameters": {
93+
"num_beams": 2,
94+
"max_length": 512
95+
}
96+
}}
97+
result = decorated_pandas_func_parameters(**pandas_input_data)
98+
assert result[0][0] == "this is a string starting with"
99+
assert result[1] == 2
100+
101+
def test_pandas_params_handling_without_params(self, decorated_pandas_func_parameters):
102+
pandas_input_data = {"input_data": {
103+
"split_df": {
104+
"columns": [
105+
"sentence1"
106+
],
107+
"data": [
108+
["this is a string starting with"]
109+
],
110+
"index": [0]
111+
}
112+
}}
113+
result = decorated_pandas_func_parameters(**pandas_input_data)
114+
assert result[0][0] == "this is a string starting with"
115+
assert result[1] == 0
116+
81117

82118
class TestNestedType(object):
83119

0 commit comments

Comments
 (0)