Skip to content

Commit 289a59d

Browse files
authored
Remove block on OpenAPI 2.0, Add Support Function to Provide Info (#58)
* Removing OpenAPI 2.0 error, adding util - the commit is more of a proposal Signed-off-by: Walter Martin <[email protected]> * tests passing Signed-off-by: Walter Martin <[email protected]> * switch output function to be more useful, add more testing Signed-off-by: Walter Martin <[email protected]> * code quality fixes Signed-off-by: Walter Martin <[email protected]> * one more trailing whitespace Signed-off-by: Walter Martin <[email protected]> * consolidate user functions Signed-off-by: Walter Martin <[email protected]> * centralized function based on json rather than type dependent Signed-off-by: Walter Martin <[email protected]> * whitespace Signed-off-by: Walter Martin <[email protected]> * PR comments Signed-off-by: Walter Martin <[email protected]> * code quality Signed-off-by: Walter Martin <[email protected]> * more cleanup Signed-off-by: Walter Martin <[email protected]>
1 parent 382580f commit 289a59d

File tree

9 files changed

+164
-9
lines changed

9 files changed

+164
-9
lines changed

inference_schema/parameter_types/_util.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ def get_swagger_for_list(python_data):
6868
item_type = type(python_data[0])
6969

7070
for data in python_data:
71-
if type(data) != item_type:
72-
raise Exception('Error, OpenAPI 2.x does not support mixed type in array.')
73-
7471
if issubclass(item_type, AbstractParameterType):
7572
nested_item_swagger = data.input_to_swagger()
7673
else:

inference_schema/parameter_types/abstract_parameter_type.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,28 @@ def __init__(self, sample_input):
2424
self.sample_input = sample_input
2525
self.sample_data_type = type(sample_input)
2626

27+
def supported_versions(self):
28+
schema = self.input_to_swagger()
29+
supported_list = ['3.0', '3.1']
30+
if self._supports_swagger_2(schema['example']):
31+
supported_list += ['2.0']
32+
return sorted(supported_list)
33+
34+
def _supports_swagger_2(self, obj):
35+
if type(obj) is list:
36+
first_type = type(obj[0])
37+
for elt in obj:
38+
if type(elt) is not first_type:
39+
return False
40+
elif type(elt) is list:
41+
if not self._supports_swagger_2(elt):
42+
return False
43+
elif type(obj) is dict:
44+
for elt in obj.values():
45+
if not self._supports_swagger_2(elt):
46+
return False
47+
return True
48+
2749
@abstractmethod
2850
def deserialize_input(self, input_data):
2951
"""

inference_schema/schema_decorators.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import copy
88
from functools import partial
99

10-
from .schema_util import _get_decorators, _get_function_full_qual_name, __functions_schema__
10+
from .schema_util import _get_decorators, _get_function_full_qual_name, __functions_schema__, __versions__
1111
from .parameter_types.abstract_parameter_type import AbstractParameterType
1212
from ._constants import INPUT_SCHEMA_ATTR, OUTPUT_SCHEMA_ATTR
1313

@@ -39,8 +39,9 @@ def input_schema(param_name, param_type, convert_to_provided_type=True):
3939
'of the AbstractParameterType.')
4040

4141
swagger_schema = {param_name: param_type.input_to_swagger()}
42+
supported_versions = param_type.supported_versions()
4243

43-
@_schema_decorator(attr_name=INPUT_SCHEMA_ATTR, schema=swagger_schema)
44+
@_schema_decorator(attr_name=INPUT_SCHEMA_ATTR, schema=swagger_schema, supported_versions=supported_versions)
4445
def decorator_input(user_run, instance, args, kwargs):
4546
if convert_to_provided_type:
4647
args = list(args)
@@ -82,16 +83,17 @@ def output_schema(output_type):
8283
'of the AbstractParameterType.')
8384

8485
swagger_schema = output_type.input_to_swagger()
86+
supported_versions = output_type.supported_versions()
8587

86-
@_schema_decorator(attr_name=OUTPUT_SCHEMA_ATTR, schema=swagger_schema)
88+
@_schema_decorator(attr_name=OUTPUT_SCHEMA_ATTR, schema=swagger_schema, supported_versions=supported_versions)
8789
def decorator_input(user_run, instance, args, kwargs):
8890
return user_run(*args, **kwargs)
8991

9092
return decorator_input
9193

9294

9395
# Heavily based on the wrapt.decorator implementation
94-
def _schema_decorator(wrapper=None, enabled=None, attr_name=None, schema=None):
96+
def _schema_decorator(wrapper=None, enabled=None, attr_name=None, schema=None, supported_versions=None):
9597
"""
9698
Decorator to generate decorators, preserving the metadata passed to the
9799
decorator arguments, that is needed to be able to extact that information
@@ -107,6 +109,8 @@ def _schema_decorator(wrapper=None, enabled=None, attr_name=None, schema=None):
107109
:type attr_name: str | None
108110
:param schema:
109111
:type schema: dict | None
112+
:param supported_versions:
113+
:type supported_versions: List | None
110114
:return:
111115
:rtype: function | FunctionWrapper
112116
"""
@@ -134,6 +138,7 @@ def _capture(target_wrapped):
134138
return _capture
135139

136140
_add_schema_to_global_schema_dictionary(attr_name, schema, args[0])
141+
_add_versions_to_global_versions_dictionary(attr_name, supported_versions, args[0])
137142
target_wrapped = args[0]
138143

139144
_enabled = enabled
@@ -165,7 +170,8 @@ def _capture(target_wrapped):
165170
_schema_decorator,
166171
enabled=enabled,
167172
attr_name=attr_name,
168-
schema=schema
173+
schema=schema,
174+
supported_versions=supported_versions
169175
)
170176

171177

@@ -201,6 +207,35 @@ def _add_schema_to_global_schema_dictionary(attr_name, schema, user_func):
201207
pass
202208

203209

210+
def _add_versions_to_global_versions_dictionary(attr_name, versions, user_func):
211+
"""
212+
function to add supported swagger versions for 'attr_name', to the function versions dict
213+
214+
:param attr_name:
215+
:type attr_name: str
216+
:param versions:
217+
:type versions: List
218+
:param user_func:
219+
:type user_func: function | FunctionWrapper
220+
:return:
221+
:rtype:
222+
"""
223+
224+
if attr_name is None or versions is None:
225+
pass
226+
227+
decorators = _get_decorators(user_func)
228+
base_func_name = _get_function_full_qual_name(decorators[-1])
229+
230+
if base_func_name not in __versions__.keys():
231+
__versions__[base_func_name] = {}
232+
233+
if attr_name == INPUT_SCHEMA_ATTR or attr_name == OUTPUT_SCHEMA_ATTR:
234+
_add_attr_versions_to_global_schema_dictionary(base_func_name, versions, attr_name)
235+
else:
236+
pass
237+
238+
204239
def _add_input_schema_to_global_schema_dictionary(base_func_name, arg_names, schema):
205240
"""
206241
function to add a generated input schema, to the function schema dict
@@ -233,6 +268,29 @@ def _add_input_schema_to_global_schema_dictionary(base_func_name, arg_names, sch
233268
__functions_schema__[base_func_name][INPUT_SCHEMA_ATTR]["properties"][k] = item_swagger
234269

235270

271+
def _add_attr_versions_to_global_schema_dictionary(base_func_name, versions, attr):
272+
"""
273+
function to add supported swagger versions to the version dict
274+
275+
:param base_func_name: function full qualified name
276+
:type base_func_name: str
277+
:param versions:
278+
:type versions: list
279+
:param attr:
280+
:type attr: str
281+
:return:
282+
:rtype:
283+
"""
284+
285+
if attr not in __versions__[base_func_name].keys():
286+
__versions__[base_func_name][attr] = {
287+
"type": "object",
288+
"versions": {}
289+
}
290+
291+
__versions__[base_func_name][attr]["versions"] = versions
292+
293+
236294
def _add_output_schema_to_global_schema_dictionary(base_func_name, schema):
237295
"""
238296
function to add a generated output schema, to the function schema dict

inference_schema/schema_util.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from inference_schema._constants import INPUT_SCHEMA_ATTR, OUTPUT_SCHEMA_ATTR
99

1010
__functions_schema__ = {}
11+
__versions__ = {}
1112

1213

1314
def get_input_schema(func):
@@ -36,6 +37,24 @@ def get_output_schema(func):
3637
return _get_schema_from_dictionary(OUTPUT_SCHEMA_ATTR, func)
3738

3839

40+
def get_supported_versions(func):
41+
"""
42+
Extract supported swagger versions from the decorated function.
43+
44+
:param func:
45+
:type func: function | FunctionWrapper
46+
:return:
47+
:rtype: list
48+
"""
49+
decorators = _get_decorators(func)
50+
func_base_name = _get_function_full_qual_name(decorators[-1])
51+
52+
input_versions = __versions__.get(func_base_name, {}).get(INPUT_SCHEMA_ATTR, {}).get('versions', [])
53+
output_versions = __versions__.get(func_base_name, {}).get(OUTPUT_SCHEMA_ATTR, {}).get('versions', [])
54+
set_intersection = set(input_versions) & set(output_versions)
55+
return sorted(list(set_intersection))
56+
57+
3958
def get_schemas_dict():
4059
"""
4160
Retrieve a deepcopy of the dictionary that is used to track the provided function schemas

tests/conftest.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,27 @@ def standard_py_func(param):
184184
return standard_py_func
185185

186186

187+
@pytest.fixture(scope="session")
188+
def standard_sample_input_multitype_list():
189+
return ['foo', 1]
190+
191+
192+
@pytest.fixture(scope="session")
193+
def standard_sample_output_multitype_list():
194+
return 5
195+
196+
197+
@pytest.fixture(scope="session")
198+
def decorated_standard_func_multitype_list(standard_sample_input_multitype_list, standard_sample_output_multitype_list):
199+
@input_schema('param', StandardPythonParameterType(standard_sample_input_multitype_list))
200+
@output_schema(StandardPythonParameterType(standard_sample_output_multitype_list))
201+
def standard_py_func_multitype_list(param):
202+
assert type(param) is list
203+
return param[1]
204+
205+
return standard_py_func_multitype_list
206+
207+
187208
@pytest.fixture(scope="session")
188209
def decorated_float_func():
189210
@input_schema('param', StandardPythonParameterType(1.0))

tests/test_numpy_parameter_type.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# ---------------------------------------------------------
44

55
import numpy as np
6+
from inference_schema.schema_util import get_supported_versions
67

78

89
class TestNumpyParameterType(object):
@@ -21,3 +22,8 @@ def test_numpy_handling(self, decorated_numpy_func):
2122
numpy_input = {"param": [{"name": "Sarah", "grades": [8.0, 7.0]}]}
2223
result = decorated_numpy_func(**numpy_input)
2324
assert np.array_equal(result, grades)
25+
26+
version_list = get_supported_versions(decorated_numpy_func)
27+
assert '2.0' in version_list
28+
assert '3.0' in version_list
29+
assert '3.1' in version_list

tests/test_pandas_parameter_type.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pandas as pd
88

99
from pandas.testing import assert_frame_equal
10+
from inference_schema.schema_util import get_supported_versions
1011

1112

1213
class TestPandasParameterType(object):
@@ -25,6 +26,11 @@ def test_pandas_handling(self, decorated_pandas_func):
2526
result = decorated_pandas_func(**pandas_input)
2627
assert_frame_equal(result, state)
2728

29+
version_list = get_supported_versions(decorated_pandas_func)
30+
assert '2.0' in version_list
31+
assert '3.0' in version_list
32+
assert '3.1' in version_list
33+
2834
def test_pandas_orient_handling(self, decorated_pandas_func_split_orient):
2935
pandas_input = {"columns": ["name", "state"], "index": [0], "data": [["Sarah", "WA"]]}
3036
state = pd.DataFrame(pd.read_json(json.dumps(pandas_input), orient='split')['state'])

tests/test_spark_parameter_type.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pandas as pd
66

77
from pyspark.sql.session import SparkSession
8+
from inference_schema.schema_util import get_supported_versions
89

910

1011
class TestSparkParameterType(object):
@@ -25,3 +26,8 @@ def test_spark_handling(self, decorated_spark_func):
2526
spark_input = {'param': [{'name': 'Sarah', 'state': 'WA'}]}
2627
result = decorated_spark_func(**spark_input)
2728
assert state.subtract(result).count() == result.subtract(state).count() == 0
29+
30+
version_list = get_supported_versions(decorated_spark_func)
31+
assert '2.0' in version_list
32+
assert '3.0' in version_list
33+
assert '3.1' in version_list

tests/test_standard_parameter_type.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# ---------------------------------------------------------
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
4+
from inference_schema.parameter_types.standard_py_parameter_type import StandardPythonParameterType
5+
from inference_schema.schema_util import get_supported_versions
46

57

68
class TestStandardPythonParameterType(object):
79

8-
def test_standard_handling(self, decorated_standard_func):
10+
def test_standard_handling_unique(self, decorated_standard_func):
911
standard_input = {'name': ['Sarah'], 'state': ['WA']}
1012
state = {'state': ['WA']}
1113
result = decorated_standard_func(standard_input)
@@ -15,6 +17,24 @@ def test_standard_handling(self, decorated_standard_func):
1517
result = decorated_standard_func(**standard_input)
1618
assert state == result
1719

20+
version_list = get_supported_versions(decorated_standard_func)
21+
assert '2.0' in version_list
22+
assert '3.0' in version_list
23+
assert '3.1' in version_list
24+
25+
def test_standard_handling_list(self, decorated_standard_func_multitype_list):
26+
standard_input = ['foo', 1]
27+
assert 1 == decorated_standard_func_multitype_list(standard_input)
28+
29+
version_list = get_supported_versions(decorated_standard_func_multitype_list)
30+
assert '2.0' not in version_list
31+
assert '3.0' in version_list
32+
assert '3.1' in version_list
33+
34+
def test_supported_versions_string(self):
35+
assert '2.0' in StandardPythonParameterType({'name': ['Sarah'], 'state': ['WA']}).supported_versions()
36+
assert '2.0' not in StandardPythonParameterType(['foo', 1]).supported_versions()
37+
1838
def test_float_int_handling(self, decorated_float_func):
1939
float_input = 1.0
2040
result = decorated_float_func(float_input)

0 commit comments

Comments
 (0)