Skip to content

Commit ecd6516

Browse files
authored
Some more fixes for column label types handling (#45)
* Some more fixes for column label types handling * typo
1 parent 28d0c97 commit ecd6516

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

inference_schema/parameter_types/pandas_parameter_type.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(self, sample_input, enforce_column_type=True, enforce_shape=True, a
2727
:param enforce_shape: Enforce that input shape must match that of the provided sample when `deserialize_input`
2828
is called.
2929
:type enforce_shape: bool
30-
:param apply_column_names: Apply column names fromt he provided sample onto the input when `deserialize_input`
30+
:param apply_column_names: Apply column names from the provided sample onto the input when `deserialize_input`
3131
is called.
3232
:type apply_column_names: bool
3333
:param orient: The Pandas orient to use when converting between a json object and a DataFrame. Possible orients
@@ -67,13 +67,13 @@ def deserialize_input(self, input_data):
6767

6868
data_frame = pd.read_json(json.dumps(input_data), orient=self.orient)
6969

70-
if self.apply_column_names and isinstance(input_data, list) and not isinstance(input_data[0], dict):
70+
if self.apply_column_names:
7171
data_frame.columns = self.sample_input.columns.copy()
7272

7373
if self.enforce_column_type:
7474
sample_input_column_types = self.sample_input.dtypes.to_dict()
7575
converted_types = {x: sample_input_column_types.get(x, object) for x in data_frame.columns}
76-
data_frame = data_frame.astype(dtype=converted_types, copy=False)
76+
data_frame = data_frame.astype(dtype=converted_types)
7777

7878
if self.enforce_shape:
7979
expected_shape = self.sample_input.shape

tests/test_pandas_parameter_type.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@ def test_pandas_timestamp_handling(self, decorated_pandas_datetime_func):
3939
result = decorated_pandas_datetime_func(**pandas_input)
4040
assert_frame_equal(result, datetime)
4141

42-
def test_pandas_multi_type_columns_labels_handling(self, decorated_pandas_func_multi_type_column_labels,
43-
pandas_sample_input_multi_type_column_labels):
44-
result = decorated_pandas_func_multi_type_column_labels(pandas_sample_input_multi_type_column_labels)
45-
assert_frame_equal(result, pandas_sample_input_multi_type_column_labels)
42+
def test_pandas_multi_type_columns_labels_handling(self, decorated_pandas_func_multi_type_column_labels):
43+
pandas_input = {'name': ['Sarah', 'John'], 1: ['WA', 'CA']}
44+
result = decorated_pandas_func_multi_type_column_labels(pandas_input)
45+
expected_result = pd.DataFrame(pandas_input)
46+
assert_frame_equal(result, expected_result)
4647

4748

4849
class TestNestedType(object):

0 commit comments

Comments
 (0)