Skip to content

Commit 872eadb

Browse files
authored
Don't fail validation for a float when provided an int (#59)
* Extra conditional in validation * Fix conditional. Add test
1 parent 78d4167 commit 872eadb

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

inference_schema/parameter_types/standard_py_parameter_type.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ def deserialize_input(self, input_data):
5757
input_data = parser.parse(input_data).timetz()
5858
elif self.sample_data_type is bytearray or (sys.version_info[0] == 3 and self.sample_data_type is bytes):
5959
input_data = base64.b64decode(input_data.encode('utf-8'))
60-
if not isinstance(input_data, self.sample_data_type):
60+
61+
if self.sample_data_type is float and type(input_data) is int:
62+
# Allow users to pass ints when expecting floats, for convenience
63+
pass
64+
elif not isinstance(input_data, self.sample_data_type):
6165
raise ValueError("Invalid input data type to parse. Expected: {0} but got {1}".format(
6266
self.sample_data_type, type(input_data)))
6367

tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,15 @@ def standard_py_func(param):
184184
return standard_py_func
185185

186186

187+
@pytest.fixture(scope="session")
188+
def decorated_float_func():
189+
@input_schema('param', StandardPythonParameterType(1.0))
190+
def standard_float_func(param):
191+
return param
192+
193+
return standard_float_func
194+
195+
187196
@pytest.fixture(scope="session")
188197
def decorated_nested_func(standard_sample_input, numpy_sample_input, pandas_sample_input, standard_sample_output,
189198
numpy_sample_output, pandas_sample_output):

tests/test_standard_parameter_type.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,12 @@ def test_standard_handling(self, decorated_standard_func):
1414
standard_input = {'param': {'name': ['Sarah'], 'state': ['WA']}}
1515
result = decorated_standard_func(**standard_input)
1616
assert state == result
17+
18+
def test_float_int_handling(self, decorated_float_func):
19+
float_input = 1.0
20+
result = decorated_float_func(float_input)
21+
assert float_input == result
22+
23+
int_input = 1
24+
result = decorated_float_func(int_input)
25+
assert int_input == result

0 commit comments

Comments
 (0)