|
31 | 31 | "source": [ |
32 | 32 | "#| export\n", |
33 | 33 | "import re\n", |
34 | | - "from typing import Any, Dict, List, Optional, Tuple, Union\n", |
| 34 | + "import reprlib\n", |
| 35 | + "import warnings\n", |
| 36 | + "from typing import Any, Dict, Generator, List, Optional, Tuple, Union\n", |
35 | 37 | "\n", |
36 | 38 | "import numpy as np\n", |
37 | 39 | "import pandas as pd\n", |
|
48 | 50 | "metadata": {}, |
49 | 51 | "outputs": [], |
50 | 52 | "source": [ |
51 | | - "from fastcore.test import test_eq\n", |
| 53 | + "from fastcore.test import test_eq, test_fail\n", |
52 | 54 | "from nbdev import show_doc\n", |
53 | 55 | "\n", |
54 | 56 | "from utilsforecast.compat import POLARS_INSTALLED\n", |
|
1074 | 1076 | " test_eq(data, series_pl.select(pl.col(c).map_batches(lambda s: s.to_physical()) for c in ['y'] + static_features[:n_static_features]).to_numpy())\n", |
1075 | 1077 | " test_eq(np.diff(indptr), grouped.count().sort('unique_id')['count'].to_numpy())" |
1076 | 1078 | ] |
| 1079 | + }, |
| 1080 | + { |
| 1081 | + "cell_type": "code", |
| 1082 | + "execution_count": null, |
| 1083 | + "id": "37bc6aa4-ce39-4559-9964-01c06b7d7dbd", |
| 1084 | + "metadata": {}, |
| 1085 | + "outputs": [], |
| 1086 | + "source": [ |
| 1087 | + "#| exporti\n", |
def _single_split(
    df: DataFrame,
    i_window: int,
    n_windows: int,
    h: int,
    id_col: str,
    time_col: str,
    freq: Union[int, str, pd.offsets.BaseOffset],
    max_dates: Series,
    step_size: Optional[int] = None,
    input_size: Optional[int] = None,
) -> Tuple[DataFrame, Series, Series]:
    """Compute the cutoffs and the train/validation row masks of one window.

    Args:
        df: Frame containing at least `id_col` and `time_col`.
        i_window: Zero-based index of the window to build.
        n_windows: Total number of windows.
        h: Number of steps covered by each validation window.
        id_col: Name of the column that identifies each series.
        time_col: Name of the column that holds the timestamps.
        freq: Frequency forwarded to `offset_dates`.
        max_dates: Last timestamp of the series each row belongs to; it is
            compared element-wise against `df[time_col]`, so it is expected
            to be aligned with the rows of `df`.
        step_size: Steps between consecutive windows. Defaults to `h`.
        input_size: When given, training rows are restricted to the
            `input_size` steps that end at the cutoff.

    Returns:
        A tuple `(cutoffs, train_mask, valid_mask)`. `cutoffs` has one row
        per id with the columns `id_col` and 'cutoff'; the masks are boolean
        and select the training and validation rows of `df`.

    Raises:
        ValueError: If no series has any training row for this window.

    Notes:
        Series without training rows only produce a warning and are removed
        from `valid_mask`. `cutoffs` is still derived from every row of
        `df`, so those ids remain in it.
    """
    stride = h if step_size is None else step_size
    test_size = h + stride * (n_windows - 1)
    # distance (in steps) between the end of each series and this window's cutoff
    offset = test_size - i_window * stride
    train_ends = offset_dates(max_dates, freq, -offset)
    valid_ends = offset_dates(train_ends, freq, h)

    times = df[time_col]
    train_mask = times.le(train_ends)
    valid_mask = times.gt(train_ends) & times.le(valid_ends)
    if input_size is not None:
        train_starts = offset_dates(train_ends, freq, -input_size)
        train_mask &= times.gt(train_starts)

    # number of training rows per series
    train_sizes = group_by(train_mask, df[id_col], maintain_order=True).sum()
    if isinstance(train_sizes, pd.Series):
        train_sizes = train_sizes.reset_index()
    empty_train = train_sizes[time_col].eq(0)
    if empty_train.all():
        raise ValueError(
            'All series are too short for the cross validation settings, '
            f'at least {offset + 1} samples are required.\n'
            'Please reduce `n_windows` or `h`.'
        )
    if empty_train.any():
        short_ids = filter_with_mask(train_sizes[id_col], empty_train)
        warnings.warn(
            'The following series are too short for the window '
            f'and will be dropped: {reprlib.repr(list(short_ids))}'
        )
        valid_mask &= ~is_in(df[id_col], short_ids)

    # keep a single cutoff per id (every row of a series shares the same one)
    if isinstance(train_ends, pd.Series):
        ends_by_id = train_ends.set_axis(df[id_col])
        first_ends = ends_by_id.groupby(id_col, observed=True).head(1)
        cutoffs: DataFrame = first_ends.rename("cutoff").reset_index()
    else:
        ends_with_ids = train_ends.to_frame().with_columns(df[id_col])
        first_per_id = group_by(ends_with_ids, id_col).agg(pl.col(time_col).head(1))
        cutoffs = first_per_id.explode(pl.col(time_col)).rename({time_col: 'cutoff'})
    return cutoffs, train_mask, valid_mask
| 1147 | + ] |
| 1148 | + }, |
| 1149 | + { |
| 1150 | + "cell_type": "code", |
| 1151 | + "execution_count": null, |
| 1152 | + "id": "c5c3370a-9a55-4436-9326-b459d03525dc", |
| 1153 | + "metadata": {}, |
| 1154 | + "outputs": [], |
| 1155 | + "source": [ |
| 1156 | + "#|export\n", |
def backtest_splits(
    df: DataFrame,
    n_windows: int,
    h: int,
    id_col: str,
    time_col: str,
    freq: Union[int, str, pd.offsets.BaseOffset],
    step_size: Optional[int] = None,
    input_size: Optional[int] = None,
) -> Generator[Tuple[DataFrame, DataFrame, DataFrame], None, None]:
    """Lazily generate the train/validation splits of a backtest.

    Args:
        df: Frame containing at least `id_col` and `time_col`
            (pandas, otherwise handled through the polars API).
        n_windows: Number of windows to generate.
        h: Number of steps in each validation window.
        id_col: Name of the column that identifies each series.
        time_col: Name of the column that holds the timestamps.
        freq: Frequency forwarded to `_single_split`.
        step_size: Steps between consecutive windows, forwarded to
            `_single_split`.
        input_size: Optional cap on the training window, forwarded to
            `_single_split`.

    Yields:
        One `(cutoffs, train, valid)` tuple per window, where `train` and
        `valid` are the rows of `df` selected by that window's masks.
    """
    # last timestamp of each series, repeated on every row of that series
    if isinstance(df, pd.DataFrame):
        series_ends = df.groupby(id_col, observed=True)[time_col].transform('max')
    else:
        series_ends = df.select(pl.col(time_col).max().over(id_col))[time_col]

    # everything except the window index is the same for all windows
    split_kwargs = dict(
        n_windows=n_windows,
        h=h,
        id_col=id_col,
        time_col=time_col,
        freq=freq,
        max_dates=series_ends,
        step_size=step_size,
        input_size=input_size,
    )
    for window in range(n_windows):
        cutoffs, train_mask, valid_mask = _single_split(
            df, i_window=window, **split_kwargs
        )
        train = filter_with_mask(df, train_mask)
        valid = filter_with_mask(df, valid_mask)
        yield cutoffs, train, valid
| 1187 | + ] |
| 1188 | + }, |
| 1189 | + { |
| 1190 | + "cell_type": "code", |
| 1191 | + "execution_count": null, |
| 1192 | + "id": "ae3ef1ca-418c-4506-990f-0502481c6fef", |
| 1193 | + "metadata": {}, |
| 1194 | + "outputs": [], |
| 1195 | + "source": [ |
| 1196 | + "#| hide\n", |
# Column names shared by every call in this cell.
_id_time_cols = dict(id_col='unique_id', time_col='ds')

short_series = generate_series(100, max_length=50)

# A single window with h=49 is expected to run without raising.
backtest_results = list(
    backtest_splits(
        short_series,
        n_windows=1,
        h=49,
        freq=pd.offsets.Day(),
        **_id_time_cols,
    )
)[0]

# With h=50 the cutoff sits 50 steps before the end of each series, so the
# error message must ask for 51 samples.
test_fail(
    lambda: list(
        backtest_splits(
            short_series,
            n_windows=1,
            h=50,
            freq=pd.offsets.Day(),
            **_id_time_cols,
        )
    ),
    contains='at least 51 samples are required'
)

# When only part of the series are too short, a warning is issued instead.
some_short_series = generate_series(100, min_length=20, max_length=100)
with warnings.catch_warnings(record=True) as issued_warnings:
    warnings.simplefilter('always', UserWarning)
    splits = list(
        backtest_splits(
            some_short_series,
            n_windows=1,
            h=50,
            freq=pd.offsets.Day(),
            **_id_time_cols,
        )
    )
    assert any('will be dropped' in str(w.message) for w in issued_warnings)

# Same data with an integer time column and an integer frequency.
short_series_int = short_series.copy()
short_series_int['ds'] = short_series.groupby('unique_id', observed=True).transform('cumcount')
backtest_int_results = list(
    backtest_splits(
        short_series_int,
        n_windows=1,
        h=40,
        freq=1,
        **_id_time_cols,
    )
)[0]
| 1247 | + ] |
| 1248 | + }, |
| 1249 | + { |
| 1250 | + "cell_type": "code", |
| 1251 | + "execution_count": null, |
| 1252 | + "id": "cbecf3fc-0354-4d3c-82ac-39929e50a01d", |
| 1253 | + "metadata": {}, |
| 1254 | + "outputs": [], |
| 1255 | + "source": [ |
| 1256 | + "#| hide\n", |
def test_backtest_splits(df, n_windows, h, step_size, input_size):
    """Check every split from `backtest_splits` against dates derived from `df`.

    For each window this verifies the last training date and the cutoffs,
    the first training date when `input_size` is set, the first and last
    validation dates, that the final window ends at the end of each series,
    and that a row-shuffled copy of `df` gives the same splits once sorted.
    Uses daily frequency and the 'unique_id' / 'ds' columns.
    """
    assert_series_equal = pd.testing.assert_series_equal
    assert_frame_equal = pd.testing.assert_frame_equal
    sort_cols = ['unique_id', 'ds']

    def ds_by_id(frame):
        # timestamps of `frame` grouped by series
        return frame.groupby('unique_id', observed=True)['ds']

    one_day = pd.offsets.Day()
    last_dates = ds_by_id(df).max()
    split_kwargs = dict(
        n_windows=n_windows,
        h=h,
        id_col='unique_id',
        time_col='ds',
        freq=one_day,
        step_size=step_size,
        input_size=input_size,
    )
    shuffled_df = df.sample(frac=1.0)
    splits = backtest_splits(df, **split_kwargs)
    shuffled_splits = list(backtest_splits(shuffled_df, **split_kwargs))

    stride = h if step_size is None else step_size
    test_size = h + stride * (n_windows - 1)
    for window, (cutoffs, train, valid) in enumerate(splits):
        offset = test_size - window * stride

        # the training set ends `offset` days before the end of each series
        expected_train_end = last_dates - one_day * offset
        train_end = ds_by_id(train).max()
        assert_series_equal(train_end, expected_train_end)
        assert_frame_equal(cutoffs, train_end.rename('cutoff').reset_index())

        if input_size is not None:
            expected_train_start = expected_train_end - one_day * (input_size - 1)
            train_start = ds_by_id(train).min()
            assert_series_equal(train_start, expected_train_start)

        # the validation set covers the h days right after the cutoff
        expected_valid_start = expected_train_end + one_day
        valid_start = ds_by_id(valid).min()
        assert_series_equal(valid_start, expected_valid_start)

        expected_valid_end = expected_train_end + one_day * h
        valid_end = ds_by_id(valid).max()
        assert_series_equal(valid_end, expected_valid_end)

        is_last_window = window == n_windows - 1
        if is_last_window:
            assert_series_equal(valid_end, last_dates)

        # row order of the input must not affect the content of the splits
        _, shuffled_train, shuffled_valid = shuffled_splits[window]
        assert_frame_equal(train, shuffled_train.sort_values(sort_cols))
        assert_frame_equal(valid, shuffled_valid.sort_values(sort_cols))
| 1301 | + "\n", |
# Series used by the split checks: 20 daily series with a length bounded by
# `min_length` and `max_length`.
n_series, min_length, max_length = 20, 100, 1000
series = generate_series(n_series, freq='D', min_length=min_length, max_length=max_length)

# Run the checks for each combination of step size and input size.
for step_size, input_size in [(s, w) for s in (None, 1, 2) for w in (None, 4)]:
    test_backtest_splits(
        series,
        n_windows=3,
        h=14,
        step_size=step_size,
        input_size=input_size,
    )
| 1310 | + ] |
1077 | 1311 | } |
1078 | 1312 | ], |
1079 | 1313 | "metadata": { |
|
0 commit comments