Skip to content

Commit d6f82c2

Browse files
authored
add ensure_time_dtype and handle pandas nullable dtypes in validate_format (#38)
1 parent 0c7db12 commit d6f82c2

File tree

4 files changed

+205
-25
lines changed

4 files changed

+205
-25
lines changed

nbs/validation.ipynb

Lines changed: 141 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,135 @@
3131
"import numpy as np\n",
3232
"import pandas as pd\n",
3333
"\n",
34-
"from utilsforecast.compat import DataFrame, pl_DataFrame"
34+
"from utilsforecast.compat import DataFrame, Series, pl_DataFrame, pl_Series, pl"
35+
]
36+
},
37+
{
38+
"cell_type": "code",
39+
"execution_count": null,
40+
"id": "7638f2a6-0303-41a9-a103-d3055e7bc572",
41+
"metadata": {},
42+
"outputs": [],
43+
"source": [
44+
"from fastcore.test import test_fail"
45+
]
46+
},
47+
{
48+
"cell_type": "code",
49+
"execution_count": null,
50+
"id": "e001379f-a552-4fd9-ac82-3f4d5c76f555",
51+
"metadata": {},
52+
"outputs": [],
53+
"source": [
54+
"#| polars\n",
55+
"import polars.testing"
56+
]
57+
},
58+
{
59+
"cell_type": "code",
60+
"execution_count": null,
61+
"id": "98e0f72a-3128-46d7-8d6c-cbd0c62ce68a",
62+
"metadata": {},
63+
"outputs": [],
64+
"source": [
65+
"#| exporti\n",
66+
"def _is_dt_or_int(s: Series) -> bool:\n",
67+
" dtype = s.head(1).to_numpy().dtype\n",
68+
" is_dt = np.issubdtype(dtype, np.datetime64)\n",
69+
" is_int = np.issubdtype(dtype, np.integer)\n",
70+
" return is_dt or is_int"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": null,
76+
"id": "6679513a-4f2b-4eda-9970-e5c6725dd761",
77+
"metadata": {},
78+
"outputs": [],
79+
"source": [
80+
"#| export\n",
81+
"def ensure_shallow_copy(df: pd.DataFrame) -> pd.DataFrame:\n",
82+
" from packaging.version import Version\n",
83+
"\n",
84+
" if Version(pd.__version__) < Version(\"1.4\"):\n",
85+
" # https://github.com/pandas-dev/pandas/pull/43406\n",
86+
" df = df.copy()\n",
87+
" return df"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"id": "ccb15f20-56e4-4d0d-96bc-830c5effff19",
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"#| export\n",
98+
"def ensure_time_dtype(df: DataFrame, time_col: str = 'ds') -> DataFrame:\n",
99+
" \"\"\"Make sure that `time_col` contains timestamps or integers.\n",
100+
" If it contains strings, try to cast them as timestamps.\"\"\"\n",
101+
" times = df[time_col]\n",
102+
" if _is_dt_or_int(times):\n",
103+
" return df\n",
104+
" parse_err_msg = (\n",
105+
" f\"Failed to parse '{time_col}' from string to datetime. \"\n",
106+
" 'Please make sure that it contains valid timestamps or integers.'\n",
107+
" )\n",
108+
" if isinstance(times, pd.Series) and pd.api.types.is_object_dtype(times):\n",
109+
" try:\n",
110+
" times = pd.to_datetime(times)\n",
111+
" except ValueError:\n",
112+
" raise ValueError(parse_err_msg)\n",
113+
" df = ensure_shallow_copy(df.copy(deep=False))\n",
114+
" df[time_col] = times\n",
115+
" elif isinstance(times, pl_Series) and times.dtype == pl.Utf8:\n",
116+
" try:\n",
117+
" times = times.str.to_datetime()\n",
118+
" except pl.exceptions.ComputeError:\n",
119+
" raise ValueError(parse_err_msg)\n",
120+
" df = df.with_columns(times)\n",
121+
" else:\n",
122+
" raise ValueError(f\"'{time_col}' should have valid timestamps or integers.\")\n",
123+
" return df"
124+
]
125+
},
126+
{
127+
"cell_type": "code",
128+
"execution_count": null,
129+
"id": "604ec21b-314f-42cb-9ee0-5950fd611b97",
130+
"metadata": {},
131+
"outputs": [],
132+
"source": [
133+
"pd.testing.assert_frame_equal(\n",
134+
" ensure_time_dtype(pd.DataFrame({'ds': ['2000-01-01']})),\n",
135+
" pd.DataFrame({'ds': pd.to_datetime(['2000-01-01'])})\n",
136+
")\n",
137+
"df = pd.DataFrame({'ds': [1, 2]})\n",
138+
"assert df is ensure_time_dtype(df)\n",
139+
"test_fail(\n",
140+
" lambda: ensure_time_dtype(pd.DataFrame({'ds': ['2000-14-14']})),\n",
141+
" contains='Please make sure that it contains valid timestamps',\n",
142+
")"
143+
]
144+
},
145+
{
146+
"cell_type": "code",
147+
"execution_count": null,
148+
"id": "5335d217-e240-46df-90a0-6aeeb07a0586",
149+
"metadata": {},
150+
"outputs": [],
151+
"source": [
152+
"#| polars\n",
153+
"pl.testing.assert_frame_equal(\n",
154+
" ensure_time_dtype(pl.DataFrame({'ds': ['2000-01-01']})),\n",
155+
" pl.DataFrame().with_columns(ds=pl.datetime(2000, 1, 1))\n",
156+
")\n",
157+
"df = pl.DataFrame({'ds': [1, 2]})\n",
158+
"assert df is ensure_time_dtype(df)\n",
159+
"test_fail(\n",
160+
" lambda: ensure_time_dtype(pl.DataFrame({'ds': ['hello']})),\n",
161+
" contains='Please make sure that it contains valid timestamps',\n",
162+
")"
35163
]
36164
},
37165
{
@@ -76,14 +204,18 @@
76204
" raise ValueError(f\"The following columns are missing: {missing_cols}\")\n",
77205
"\n",
78206
" # time col\n",
79-
" times_dtype = df[time_col].head(1).to_numpy().dtype\n",
80-
" if not (np.issubdtype(times_dtype, np.datetime64) or np.issubdtype(times_dtype, np.integer)):\n",
207+
" if not _is_dt_or_int(df[time_col]):\n",
208+
" times_dtype = df[time_col].head(1).to_numpy().dtype\n",
81209
" raise ValueError(f\"The time column ('{time_col}') should have either timestamps or integers, got '{times_dtype}'.\")\n",
82210
"\n",
83211
" # target col\n",
84-
" target_dtype = df[target_col].head(1).to_numpy().dtype\n",
85-
" if not np.issubdtype(target_dtype, np.number):\n",
86-
" raise ValueError(f\"The target column ('{target_col}') should have a numeric data type, got '{target_dtype}')\")"
212+
" target = df[target_col]\n",
213+
" if isinstance(target, pd.Series):\n",
214+
" is_numeric = np.issubdtype(target.dtype.type, np.number)\n",
215+
" else:\n",
216+
" is_numeric = target.is_numeric()\n",
217+
" if not is_numeric:\n",
218+
" raise ValueError(f\"The target column ('{target_col}') should have a numeric data type, got '{target.dtype}')\")"
87219
]
88220
},
89221
{
@@ -108,7 +240,7 @@
108240
"text/markdown": [
109241
"---\n",
110242
"\n",
111-
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L12){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
243+
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L57){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
112244
"\n",
113245
"### validate_format\n",
114246
"\n",
@@ -130,7 +262,7 @@
130262
"text/plain": [
131263
"---\n",
132264
"\n",
133-
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L12){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
265+
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L57){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
134266
"\n",
135267
"### validate_format\n",
136268
"\n",
@@ -168,10 +300,7 @@
168300
"source": [
169301
"import datetime\n",
170302
"\n",
171-
"import pandas as pd\n",
172-
"from fastcore.test import test_fail\n",
173-
"\n",
174-
"from utilsforecast.compat import POLARS_INSTALLED, pl\n",
303+
"from utilsforecast.compat import POLARS_INSTALLED\n",
175304
"from utilsforecast.data import generate_series"
176305
]
177306
},

settings.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ keywords = time-series analysis forecasting
2626
language = English
2727
status = 3
2828
user = Nixtla
29-
requirements = numpy pandas>=1.1.1
29+
requirements = numpy packaging pandas>=1.1.1
3030
plotting_requirements = matplotlib plotly plotly-resampler
3131
dev_requirements = matplotlib numba plotly polars pyarrow
3232
readme_nb = index.ipynb

utilsforecast/_modidx.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,5 +166,11 @@
166166
'utilsforecast/target_transforms.py'),
167167
'utilsforecast.target_transforms._transform': ( 'target_transforms.html#_transform',
168168
'utilsforecast/target_transforms.py')},
169-
'utilsforecast.validation': { 'utilsforecast.validation.validate_format': ( 'validation.html#validate_format',
169+
'utilsforecast.validation': { 'utilsforecast.validation._is_dt_or_int': ( 'validation.html#_is_dt_or_int',
170+
'utilsforecast/validation.py'),
171+
'utilsforecast.validation.ensure_shallow_copy': ( 'validation.html#ensure_shallow_copy',
172+
'utilsforecast/validation.py'),
173+
'utilsforecast.validation.ensure_time_dtype': ( 'validation.html#ensure_time_dtype',
174+
'utilsforecast/validation.py'),
175+
'utilsforecast.validation.validate_format': ( 'validation.html#validate_format',
170176
'utilsforecast/validation.py')}}}

utilsforecast/validation.py

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,59 @@
11
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/validation.ipynb.
22

33
# %% auto 0
4-
__all__ = ['validate_format']
4+
__all__ = ['ensure_shallow_copy', 'ensure_time_dtype', 'validate_format']
55

66
# %% ../nbs/validation.ipynb 2
77
import numpy as np
88
import pandas as pd
99

10-
from .compat import DataFrame, pl_DataFrame
10+
from .compat import DataFrame, Series, pl_DataFrame, pl_Series, pl
1111

12-
# %% ../nbs/validation.ipynb 3
12+
# %% ../nbs/validation.ipynb 5
13+
def _is_dt_or_int(s: Series) -> bool:
    """Tell whether the series' values are datetimes or integers.

    Works for both pandas and polars series by converting a one-element
    head to numpy and inspecting the resulting dtype.
    """
    # Only the first element is converted, keeping the numpy round-trip cheap.
    np_dtype = s.head(1).to_numpy().dtype
    return any(
        np.issubdtype(np_dtype, parent) for parent in (np.datetime64, np.integer)
    )
18+
19+
# %% ../nbs/validation.ipynb 6
20+
def ensure_shallow_copy(df: pd.DataFrame) -> pd.DataFrame:
    """Return a frame that is safe to assign new columns to.

    pandas < 1.4 could misbehave when setting a column on a shallow copy
    (https://github.com/pandas-dev/pandas/pull/43406), so on those versions
    a real copy is made; on newer versions the frame is returned unchanged.
    """
    # Local import keeps `packaging` an optional, lazily-loaded dependency.
    from packaging.version import Version

    needs_real_copy = Version(pd.__version__) < Version("1.4")
    return df.copy() if needs_real_copy else df
27+
28+
# %% ../nbs/validation.ipynb 7
29+
def ensure_time_dtype(df: DataFrame, time_col: str = "ds") -> DataFrame:
    """Make sure that `time_col` contains timestamps or integers.
    If it contains strings, try to cast them as timestamps.

    Returns the input frame unchanged when the column already holds
    datetimes or integers; otherwise returns a frame with the parsed
    column, raising ``ValueError`` if parsing is impossible.
    """
    time_values = df[time_col]
    if _is_dt_or_int(time_values):
        # Already a valid dtype; hand back the very same frame.
        return df
    parse_err_msg = (
        f"Failed to parse '{time_col}' from string to datetime. "
        'Please make sure that it contains valid timestamps or integers.'
    )
    if isinstance(time_values, pd.Series) and pd.api.types.is_object_dtype(time_values):
        try:
            parsed = pd.to_datetime(time_values)
        except ValueError:
            raise ValueError(parse_err_msg)
        # Shallow copy avoids mutating the caller's frame; ensure_shallow_copy
        # upgrades it to a real copy on pandas versions where assignment to a
        # shallow copy is unsafe.
        df = ensure_shallow_copy(df.copy(deep=False))
        df[time_col] = parsed
        return df
    if isinstance(time_values, pl_Series) and time_values.dtype == pl.Utf8:
        try:
            parsed = time_values.str.to_datetime()
        except pl.exceptions.ComputeError:
            raise ValueError(parse_err_msg)
        return df.with_columns(parsed)
    # Neither a parseable string column nor a supported dtype.
    raise ValueError(f"'{time_col}' should have valid timestamps or integers.")
55+
56+
# %% ../nbs/validation.ipynb 10
1357
def validate_format(
1458
df: DataFrame,
1559
id_col: str = "unique_id",
@@ -44,18 +88,19 @@ def validate_format(
4488
raise ValueError(f"The following columns are missing: {missing_cols}")
4589

4690
# time col
47-
times_dtype = df[time_col].head(1).to_numpy().dtype
48-
if not (
49-
np.issubdtype(times_dtype, np.datetime64)
50-
or np.issubdtype(times_dtype, np.integer)
51-
):
91+
if not _is_dt_or_int(df[time_col]):
92+
times_dtype = df[time_col].head(1).to_numpy().dtype
5293
raise ValueError(
5394
f"The time column ('{time_col}') should have either timestamps or integers, got '{times_dtype}'."
5495
)
5596

5697
# target col
57-
target_dtype = df[target_col].head(1).to_numpy().dtype
58-
if not np.issubdtype(target_dtype, np.number):
98+
target = df[target_col]
99+
if isinstance(target, pd.Series):
100+
is_numeric = np.issubdtype(target.dtype.type, np.number)
101+
else:
102+
is_numeric = target.is_numeric()
103+
if not is_numeric:
59104
raise ValueError(
60-
f"The target column ('{target_col}') should have a numeric data type, got '{target_dtype}')"
105+
f"The target column ('{target_col}') should have a numeric data type, got '{target.dtype}')"
61106
)

0 commit comments

Comments
 (0)