
Commit 1e11c37

add backtest_splits (#36)
1 parent 74ed1d4 commit 1e11c37

File tree: 3 files changed, +336 −4 lines


nbs/processing.ipynb (236 additions, 2 deletions)
@@ -31,7 +31,9 @@
 "source": [
 "#| export\n",
 "import re\n",
-"from typing import Any, Dict, List, Optional, Tuple, Union\n",
+"import reprlib\n",
+"import warnings\n",
+"from typing import Any, Dict, Generator, List, Optional, Tuple, Union\n",
 "\n",
 "import numpy as np\n",
 "import pandas as pd\n",
@@ -48,7 +50,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from fastcore.test import test_eq\n",
+"from fastcore.test import test_eq, test_fail\n",
 "from nbdev import show_doc\n",
 "\n",
 "from utilsforecast.compat import POLARS_INSTALLED\n",
@@ -1074,6 +1076,238 @@
 "    test_eq(data, series_pl.select(pl.col(c).map_batches(lambda s: s.to_physical()) for c in ['y'] + static_features[:n_static_features]).to_numpy())\n",
 "    test_eq(np.diff(indptr), grouped.count().sort('unique_id')['count'].to_numpy())"
 ]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "37bc6aa4-ce39-4559-9964-01c06b7d7dbd",
+"metadata": {},
+"outputs": [],
+"source": [
+"#| exporti\n",
+"def _single_split(\n",
+"    df: DataFrame,\n",
+"    i_window: int,\n",
+"    n_windows: int,\n",
+"    h: int,\n",
+"    id_col: str,\n",
+"    time_col: str,\n",
+"    freq: Union[int, str, pd.offsets.BaseOffset],\n",
+"    max_dates: Series,\n",
+"    step_size: Optional[int] = None,\n",
+"    input_size: Optional[int] = None,\n",
+") -> Tuple[DataFrame, Series, Series]:\n",
+"    if step_size is None:\n",
+"        step_size = h\n",
+"    test_size = h + step_size * (n_windows - 1)\n",
+"    offset = test_size - i_window * step_size\n",
+"    train_ends = offset_dates(max_dates, freq, -offset)\n",
+"    valid_ends = offset_dates(train_ends, freq, h)\n",
+"    train_mask = df[time_col].le(train_ends)\n",
+"    valid_mask = df[time_col].gt(train_ends) & df[time_col].le(valid_ends)\n",
+"    if input_size is not None:\n",
+"        train_starts = offset_dates(train_ends, freq, -input_size)\n",
+"        train_mask &= df[time_col].gt(train_starts)\n",
+"    train_sizes = group_by(train_mask, df[id_col], maintain_order=True).sum()\n",
+"    if isinstance(train_sizes, pd.Series):\n",
+"        train_sizes = train_sizes.reset_index()\n",
+"    zeros_mask = train_sizes[time_col].eq(0)\n",
+"    if zeros_mask.all():\n",
+"        raise ValueError(\n",
+"            'All series are too short for the cross validation settings, '\n",
+"            f'at least {offset + 1} samples are required.\\n'\n",
+"            'Please reduce `n_windows` or `h`.'\n",
+"        )\n",
+"    elif zeros_mask.any():\n",
+"        ids = filter_with_mask(train_sizes[id_col], zeros_mask)\n",
+"        warnings.warn(\n",
+"            'The following series are too short for the window '\n",
+"            f'and will be dropped: {reprlib.repr(list(ids))}'\n",
+"        )\n",
+"        dropped_ids = is_in(df[id_col], ids)\n",
+"        valid_mask &= ~dropped_ids\n",
+"    if isinstance(train_ends, pd.Series):\n",
+"        cutoffs: DataFrame = (\n",
+"            train_ends\n",
+"            .set_axis(df[id_col])\n",
+"            .groupby(id_col, observed=True)\n",
+"            .head(1)\n",
+"            .rename(\"cutoff\")\n",
+"            .reset_index()\n",
+"        )\n",
+"    else:\n",
+"        cutoffs = train_ends.to_frame().with_columns(df[id_col])\n",
+"        cutoffs = (\n",
+"            group_by(cutoffs, id_col)\n",
+"            .agg(pl.col(time_col).head(1))\n",
+"            .explode(pl.col(time_col))\n",
+"            .rename({time_col: 'cutoff'})\n",
+"        )\n",
+"    return cutoffs, train_mask, valid_mask"
+]
+},
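Note on the window arithmetic: `_single_split` works backwards from each series' last timestamp. The final window's validation block ends exactly at the series end, and each earlier window moves the cutoff back by `step_size`. A minimal sketch of that arithmetic, with illustrative settings (`n_windows=3, h=14, step_size=7` are assumptions, not values from this commit):

# Sketch of the offsets computed by _single_split; values are illustrative.
n_windows, h, step_size = 3, 14, 7
test_size = h + step_size * (n_windows - 1)  # 14 + 7 * 2 = 28
for i_window in range(n_windows):
    offset = test_size - i_window * step_size  # 28, 21, 14
    # train: rows at or before max_date - offset periods
    # valid: the h periods immediately after that cutoff
    print(f'window {i_window}: cutoff = max_date - {offset} periods')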
1149+
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "c5c3370a-9a55-4436-9326-b459d03525dc",
+"metadata": {},
+"outputs": [],
+"source": [
+"#| export\n",
+"def backtest_splits(\n",
+"    df: DataFrame,\n",
+"    n_windows: int,\n",
+"    h: int,\n",
+"    id_col: str,\n",
+"    time_col: str,\n",
+"    freq: Union[int, str, pd.offsets.BaseOffset],\n",
+"    step_size: Optional[int] = None,\n",
+"    input_size: Optional[int] = None,\n",
+") -> Generator[Tuple[DataFrame, DataFrame, DataFrame], None, None]:\n",
+"    if isinstance(df, pd.DataFrame):\n",
+"        max_dates = df.groupby(id_col, observed=True)[time_col].transform('max')\n",
+"    else:\n",
+"        max_dates = df.select(pl.col(time_col).max().over(id_col))[time_col]\n",
+"    for i in range(n_windows):\n",
+"        cutoffs, train_mask, valid_mask = _single_split(\n",
+"            df,\n",
+"            i_window=i,\n",
+"            n_windows=n_windows,\n",
+"            h=h,\n",
+"            id_col=id_col,\n",
+"            time_col=time_col,\n",
+"            freq=freq,\n",
+"            max_dates=max_dates,\n",
+"            step_size=step_size,\n",
+"            input_size=input_size,\n",
+"        )\n",
+"        train = filter_with_mask(df, train_mask)\n",
+"        valid = filter_with_mask(df, valid_mask)\n",
+"        yield cutoffs, train, valid"
+]
+},
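`backtest_splits` is a generator: each iteration yields the per-series cutoffs plus the train and validation frames for one window. A minimal usage sketch, assuming utilsforecast is installed (the settings and series are illustrative):

import pandas as pd
from utilsforecast.data import generate_series
from utilsforecast.processing import backtest_splits

series = generate_series(5, freq='D', min_length=100)
splits = list(backtest_splits(
    series, n_windows=2, h=7,
    id_col='unique_id', time_col='ds', freq=pd.offsets.Day(),
))
cutoffs, train, valid = splits[0]  # first window
# one cutoff per series; validation covers the 7 days after each cutoff
print(cutoffs.head(), train.shape, valid.shape)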
1189+
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "ae3ef1ca-418c-4506-990f-0502481c6fef",
+"metadata": {},
+"outputs": [],
+"source": [
+"#| hide\n",
+"short_series = generate_series(100, max_length=50)\n",
+"backtest_results = list(\n",
+"    backtest_splits(\n",
+"        short_series,\n",
+"        n_windows=1,\n",
+"        h=49,\n",
+"        id_col='unique_id',\n",
+"        time_col='ds',\n",
+"        freq=pd.offsets.Day(),\n",
+"    )\n",
+")[0]\n",
+"test_fail(\n",
+"    lambda: list(\n",
+"        backtest_splits(\n",
+"            short_series,\n",
+"            n_windows=1,\n",
+"            h=50,\n",
+"            id_col='unique_id',\n",
+"            time_col='ds',\n",
+"            freq=pd.offsets.Day(),\n",
+"        )\n",
+"    ),\n",
+"    contains='at least 51 samples are required',\n",
+")\n",
+"some_short_series = generate_series(100, min_length=20, max_length=100)\n",
+"with warnings.catch_warnings(record=True) as issued_warnings:\n",
+"    warnings.simplefilter('always', UserWarning)\n",
+"    splits = list(\n",
+"        backtest_splits(\n",
+"            some_short_series,\n",
+"            n_windows=1,\n",
+"            h=50,\n",
+"            id_col='unique_id',\n",
+"            time_col='ds',\n",
+"            freq=pd.offsets.Day(),\n",
+"        )\n",
+"    )\n",
+"    assert any('will be dropped' in str(w.message) for w in issued_warnings)\n",
+"short_series_int = short_series.copy()\n",
+"short_series_int['ds'] = short_series.groupby('unique_id', observed=True).cumcount()\n",
+"backtest_int_results = list(\n",
+"    backtest_splits(\n",
+"        short_series_int,\n",
+"        n_windows=1,\n",
+"        h=40,\n",
+"        id_col='unique_id',\n",
+"        time_col='ds',\n",
+"        freq=1,\n",
+"    )\n",
+")[0]"
+]
+},
1249+
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "cbecf3fc-0354-4d3c-82ac-39929e50a01d",
+"metadata": {},
+"outputs": [],
+"source": [
+"#| hide\n",
+"def test_backtest_splits(df, n_windows, h, step_size, input_size):\n",
+"    max_dates = df.groupby('unique_id', observed=True)['ds'].max()\n",
+"    day_offset = pd.offsets.Day()\n",
+"    common_kwargs = dict(\n",
+"        n_windows=n_windows,\n",
+"        h=h,\n",
+"        id_col='unique_id',\n",
+"        time_col='ds',\n",
+"        freq=pd.offsets.Day(),\n",
+"        step_size=step_size,\n",
+"        input_size=input_size,\n",
+"    )\n",
+"    permuted_df = df.sample(frac=1.0)\n",
+"    splits = backtest_splits(df, **common_kwargs)\n",
+"    splits_on_permuted = list(backtest_splits(permuted_df, **common_kwargs))\n",
+"    if step_size is None:\n",
+"        step_size = h\n",
+"    test_size = h + step_size * (n_windows - 1)\n",
+"    for window, (cutoffs, train, valid) in enumerate(splits):\n",
+"        offset = test_size - window * step_size\n",
+"        expected_max_train_dates = max_dates - day_offset * offset\n",
+"        max_train_dates = train.groupby('unique_id', observed=True)['ds'].max()\n",
+"        pd.testing.assert_series_equal(max_train_dates, expected_max_train_dates)\n",
+"        pd.testing.assert_frame_equal(cutoffs, max_train_dates.rename('cutoff').reset_index())\n",
+"\n",
+"        if input_size is not None:\n",
+"            expected_min_train_dates = expected_max_train_dates - day_offset * (input_size - 1)\n",
+"            min_train_dates = train.groupby('unique_id', observed=True)['ds'].min()\n",
+"            pd.testing.assert_series_equal(min_train_dates, expected_min_train_dates)\n",
+"\n",
+"        expected_min_valid_dates = expected_max_train_dates + day_offset\n",
+"        min_valid_dates = valid.groupby('unique_id', observed=True)['ds'].min()\n",
+"        pd.testing.assert_series_equal(min_valid_dates, expected_min_valid_dates)\n",
+"\n",
+"        expected_max_valid_dates = expected_max_train_dates + day_offset * h\n",
+"        max_valid_dates = valid.groupby('unique_id', observed=True)['ds'].max()\n",
+"        pd.testing.assert_series_equal(max_valid_dates, expected_max_valid_dates)\n",
+"\n",
+"        if window == n_windows - 1:\n",
+"            pd.testing.assert_series_equal(max_valid_dates, max_dates)\n",
+"\n",
+"        _, permuted_train, permuted_valid = splits_on_permuted[window]\n",
+"        pd.testing.assert_frame_equal(train, permuted_train.sort_values(['unique_id', 'ds']))\n",
+"        pd.testing.assert_frame_equal(valid, permuted_valid.sort_values(['unique_id', 'ds']))\n",
+"\n",
+"n_series = 20\n",
+"min_length = 100\n",
+"max_length = 1000\n",
+"series = generate_series(n_series, freq='D', min_length=min_length, max_length=max_length)\n",
+"\n",
+"for step_size in (None, 1, 2):\n",
+"    for input_size in (None, 4):\n",
+"        test_backtest_splits(series, n_windows=3, h=14, step_size=step_size, input_size=input_size)"
+]
 }
 ],
 "metadata": {

utilsforecast/_modidx.py (4 additions, 0 deletions)
@@ -66,8 +66,12 @@
                                                              'utilsforecast/processing.py'),
     'utilsforecast.processing._polars_categorical_to_numerical': ( 'processing.html#_polars_categorical_to_numerical',
                                                                    'utilsforecast/processing.py'),
+    'utilsforecast.processing._single_split': ( 'processing.html#_single_split',
+                                                'utilsforecast/processing.py'),
     'utilsforecast.processing.assign_columns': ( 'processing.html#assign_columns',
                                                  'utilsforecast/processing.py'),
+    'utilsforecast.processing.backtest_splits': ( 'processing.html#backtest_splits',
+                                                  'utilsforecast/processing.py'),
     'utilsforecast.processing.between': ('processing.html#between', 'utilsforecast/processing.py'),
     'utilsforecast.processing.cast': ('processing.html#cast', 'utilsforecast/processing.py'),
     'utilsforecast.processing.copy_if_pandas': ( 'processing.html#copy_if_pandas',

utilsforecast/processing.py (96 additions, 2 deletions)
@@ -5,11 +5,13 @@
     'is_none', 'is_nan_or_none', 'match_if_categorical', 'vertical_concat', 'horizontal_concat',
     'copy_if_pandas', 'join', 'drop_index_if_pandas', 'rename', 'sort', 'offset_dates', 'group_by',
     'group_by_agg', 'is_in', 'between', 'fill_null', 'cast', 'value_cols_to_numpy', 'process_df',
-    'DataFrameProcessor']
+    'DataFrameProcessor', 'backtest_splits']

 # %% ../nbs/processing.ipynb 2
 import re
-from typing import Any, Dict, List, Optional, Tuple, Union
+import reprlib
+import warnings
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union

 import numpy as np
 import pandas as pd
@@ -441,3 +443,95 @@ def process(
         self, df: DataFrame
     ) -> Tuple[Series, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
         return process_df(df, self.id_col, self.time_col, self.target_col)
+
+# %% ../nbs/processing.ipynb 57
+def _single_split(
+    df: DataFrame,
+    i_window: int,
+    n_windows: int,
+    h: int,
+    id_col: str,
+    time_col: str,
+    freq: Union[int, str, pd.offsets.BaseOffset],
+    max_dates: Series,
+    step_size: Optional[int] = None,
+    input_size: Optional[int] = None,
+) -> Tuple[DataFrame, Series, Series]:
+    if step_size is None:
+        step_size = h
+    test_size = h + step_size * (n_windows - 1)
+    offset = test_size - i_window * step_size
+    train_ends = offset_dates(max_dates, freq, -offset)
+    valid_ends = offset_dates(train_ends, freq, h)
+    train_mask = df[time_col].le(train_ends)
+    valid_mask = df[time_col].gt(train_ends) & df[time_col].le(valid_ends)
+    if input_size is not None:
+        train_starts = offset_dates(train_ends, freq, -input_size)
+        train_mask &= df[time_col].gt(train_starts)
+    train_sizes = group_by(train_mask, df[id_col], maintain_order=True).sum()
+    if isinstance(train_sizes, pd.Series):
+        train_sizes = train_sizes.reset_index()
+    zeros_mask = train_sizes[time_col].eq(0)
+    if zeros_mask.all():
+        raise ValueError(
+            "All series are too short for the cross validation settings, "
+            f"at least {offset + 1} samples are required.\n"
+            "Please reduce `n_windows` or `h`."
+        )
+    elif zeros_mask.any():
+        ids = filter_with_mask(train_sizes[id_col], zeros_mask)
+        warnings.warn(
+            "The following series are too short for the window "
+            f"and will be dropped: {reprlib.repr(list(ids))}"
+        )
+        dropped_ids = is_in(df[id_col], ids)
+        valid_mask &= ~dropped_ids
+    if isinstance(train_ends, pd.Series):
+        cutoffs: DataFrame = (
+            train_ends.set_axis(df[id_col])
+            .groupby(id_col, observed=True)
+            .head(1)
+            .rename("cutoff")
+            .reset_index()
+        )
+    else:
+        cutoffs = train_ends.to_frame().with_columns(df[id_col])
+        cutoffs = (
+            group_by(cutoffs, id_col)
+            .agg(pl.col(time_col).head(1))
+            .explode(pl.col(time_col))
+            .rename({time_col: "cutoff"})
+        )
+    return cutoffs, train_mask, valid_mask
+
+# %% ../nbs/processing.ipynb 58
+def backtest_splits(
+    df: DataFrame,
+    n_windows: int,
+    h: int,
+    id_col: str,
+    time_col: str,
+    freq: Union[int, str, pd.offsets.BaseOffset],
+    step_size: Optional[int] = None,
+    input_size: Optional[int] = None,
+) -> Generator[Tuple[DataFrame, DataFrame, DataFrame], None, None]:
+    if isinstance(df, pd.DataFrame):
+        max_dates = df.groupby(id_col, observed=True)[time_col].transform("max")
+    else:
+        max_dates = df.select(pl.col(time_col).max().over(id_col))[time_col]
+    for i in range(n_windows):
+        cutoffs, train_mask, valid_mask = _single_split(
+            df,
+            i_window=i,
+            n_windows=n_windows,
+            h=h,
+            id_col=id_col,
+            time_col=time_col,
+            freq=freq,
+            max_dates=max_dates,
+            step_size=step_size,
+            input_size=input_size,
+        )
+        train = filter_with_mask(df, train_mask)
+        valid = filter_with_mask(df, valid_mask)
+        yield cutoffs, train, valid
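Because `backtest_splits` is exported from `utilsforecast.processing`, it can drive a cross-validation loop directly. A sketch under assumptions: the last-value forecaster below is purely illustrative and not part of this commit.

import pandas as pd
from utilsforecast.data import generate_series
from utilsforecast.processing import backtest_splits

series = generate_series(3, freq='D', min_length=60)
maes = []
for cutoffs, train, valid in backtest_splits(
    series, n_windows=2, h=7,
    id_col='unique_id', time_col='ds', freq=pd.offsets.Day(),
):
    # naive forecast: repeat each series' last training value over the horizon
    last_vals = train.groupby('unique_id', observed=True)['y'].last()
    preds = valid['unique_id'].map(last_vals)
    maes.append((valid['y'] - preds).abs().mean())  # one MAE per window
print(maes)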
