Skip to content

Commit cf71ce5

Browse files
authored
Fix Pandas datetime and timedelta handling (#48)
* Change pandas handling to use string formatting for datetime. Add corresponding test. * Fixing timedelta handling, and corresponding test
1 parent 4938541 commit cf71ce5

File tree

5 files changed

+82
-3
lines changed

5 files changed

+82
-3
lines changed

inference_schema/parameter_types/pandas_parameter_type.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ def deserialize_input(self, input_data):
7373
if self.enforce_column_type:
7474
sample_input_column_types = self.sample_input.dtypes.to_dict()
7575
converted_types = {x: sample_input_column_types.get(x, object) for x in data_frame.columns}
76+
for column_name, column_type in converted_types.items():
77+
if str(column_type).startswith('timedelta'):
78+
data_frame[column_name] = pd.to_timedelta(data_frame[column_name])
7679
data_frame = data_frame.astype(dtype=converted_types)
7780

7881
if self.enforce_shape:
@@ -101,11 +104,44 @@ def input_to_swagger(self):
101104
:rtype: dict
102105
"""
103106
LIST_LIKE_ORIENTS = ('records', 'values')
104-
json_sample = json.loads(self.sample_input.to_json(orient=self.orient))
107+
json_sample = json.loads(self.sample_input.to_json(orient=self.orient, date_format='iso'))
105108

106109
if self.orient in LIST_LIKE_ORIENTS:
107110
swagger_schema = get_swagger_for_list(json_sample)
108111
else:
109112
swagger_schema = get_swagger_for_nested_dict(json_sample)
110113

114+
if self.orient == 'records':
115+
for column_name in self.sample_input.columns:
116+
data_type = str(self.sample_input.dtypes[column_name])
117+
if data_type.startswith('datetime'):
118+
swagger_schema['items']['properties'][str(column_name)]['format'] = 'date-time'
119+
elif data_type.startswith('timedelta'):
120+
swagger_schema['items']['properties'][str(column_name)]['format'] = 'timedelta'
121+
elif self.orient == 'index':
122+
for row in swagger_schema['properties'].values():
123+
for column_name in self.sample_input.columns:
124+
data_type = str(self.sample_input.dtypes[column_name])
125+
if data_type.startswith('datetime'):
126+
row['properties'][str(column_name)]['format'] = 'date-time'
127+
elif data_type.startswith('timedelta'):
128+
row['properties'][str(column_name)]['format'] = 'timedelta'
129+
elif self.orient == 'columns':
130+
for column_name in self.sample_input.columns:
131+
for row_info in swagger_schema['properties'][str(column_name)]['properties'].values():
132+
data_type = str(self.sample_input.dtypes[column_name])
133+
if data_type.startswith('datetime'):
134+
row_info['format'] = 'date-time'
135+
elif data_type.startswith('timedelta'):
136+
row_info['format'] = 'timedelta'
137+
elif self.orient == 'table':
138+
for column_name in self.sample_input.columns:
139+
data_type = str(self.sample_input.dtypes[column_name])
140+
if data_type.startswith('datetime'):
141+
swagger_schema['properties']['data']['items']['properties'][str(column_name)]['format'] = \
142+
'date-time'
143+
elif data_type.startswith('timedelta'):
144+
swagger_schema['properties']['data']['items']['properties'][str(column_name)]['format'] = \
145+
'timedelta'
146+
111147
return swagger_schema

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def pandas_func(param):
8484
@pytest.fixture(scope="session")
8585
def decorated_pandas_datetime_func():
8686
pandas_sample_timestamp_input = pd.DataFrame({'datetime': pd.Series(['2013-12-31T00:00:00.000Z'],
87-
dtype='datetime64[ns]')})
87+
dtype='datetime64[ns]'),
88+
'days': pd.Series([pd.Timedelta(days=1)])})
8889

8990
@input_schema('param', PandasParameterType(pandas_sample_timestamp_input))
9091
def pandas_datetime_func(param):
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{
2+
"type": "object",
3+
"properties": {
4+
"param": {
5+
"type": "array",
6+
"items": {
7+
"type": "object",
8+
"required": [
9+
"datetime",
10+
"days"
11+
],
12+
"properties": {
13+
"datetime": {
14+
"type": "string",
15+
"format": "date-time"
16+
},
17+
"days": {
18+
"type": "string",
19+
"format": "timedelta"
20+
}
21+
}
22+
}
23+
}
24+
},
25+
"example": {
26+
"param": [
27+
{
28+
"datetime": "2013-12-31T00:00:00.000Z",
29+
"days": "P1DT0H0M0S"
30+
}
31+
]
32+
}
33+
}

tests/test_pandas_parameter_type.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def test_pandas_orient_handling(self, decorated_pandas_func_split_orient):
3333

3434
def test_pandas_timestamp_handling(self, decorated_pandas_datetime_func):
3535
datetime_str = '2013-12-31 00:00:00,000000'
36-
pandas_input = {'param': [{'datetime': datetime_str}]}
36+
timedelta_str = 'P1DT0H0M0S'
37+
pandas_input = {'param': [{'datetime': datetime_str, 'days': timedelta_str}]}
3738
datetime = pd.DataFrame(
3839
pd.DataFrame({'datetime': pd.Series([datetime_str], dtype='datetime64[ns]')})['datetime'])
3940
result = decorated_pandas_datetime_func(**pandas_input)

tests/test_schema_generation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ def test_pandas_handling(self, decorated_pandas_func):
3333
assert ordered(get_output_schema(decorated_pandas_func)) == ordered(self.pandas_sample_output_schema)
3434

3535

36+
class TestPandasDatetimeSchemaGeneration(object):
37+
pandas_sample_datetime_schema = json.loads(
38+
resource_string(__name__, os.path.join('resources', 'sample_pandas_datetime_schema.json')).decode('ascii'))
39+
40+
def test_pandas_datetime_handling(self, decorated_pandas_datetime_func):
41+
assert ordered(get_input_schema(decorated_pandas_datetime_func)) == ordered(self.pandas_sample_datetime_schema)
42+
43+
3644
class TestSparkSchemaGeneration(object):
3745
spark_sample_input_schema = json.loads(
3846
resource_string(__name__, os.path.join('resources', 'sample_spark_input_schema.json')).decode('ascii'))

0 commit comments

Comments
 (0)