|
72 | 72 | " else:\n", |
73 | 73 | " if isinstance(freq, str):\n", |
74 | 74 | " # this raises a nice error message if it isn't a valid datetime\n", |
| 75 | + " if isinstance(bound, pd.Timestamp) and bound.tz is not None:\n", |
| 76 | + " bound = bound.tz_localize(None)\n", |
75 | 77 | " val = np.datetime64(bound)\n", |
76 | 78 | " else:\n", |
77 | 79 | " val = bound\n", |
|
203 | 205 | " # such as MS = 'Month Start' -> 'M', YS = 'Year Start' -> 'Y'\n", |
204 | 206 | " freq = freq[0]\n", |
205 | 207 | " delta: Union[np.timedelta64, int] = np.timedelta64(n, freq)\n", |
| 208 | + " tz = df[time_col].dt.tz\n", |
| 209 | + " if tz is not None:\n", |
| 210 | + " df = df.copy(deep=False)\n", |
| 211 | + " df[time_col] = df[time_col].dt.tz_localize(None)\n", |
206 | 212 | " else:\n", |
207 | 213 | " delta = freq\n", |
| 214 | + " tz = None\n", |
208 | 215 | " times_by_id = df.groupby(id_col, observed=True)[time_col].agg(['min', 'max'])\n", |
209 | 216 | " starts = _determine_bound(start, freq, times_by_id, 'min')\n", |
210 | 217 | " ends = _determine_bound(end, freq, times_by_id, 'max') + delta\n", |
|
228 | 235 | " times += offset.base\n", |
229 | 236 | " idx = pd.MultiIndex.from_arrays([uids, times], names=[id_col, time_col])\n", |
230 | 237 | " res = df.set_index([id_col, time_col]).reindex(idx).reset_index()\n", |
| 238 | + " if tz is not None:\n", |
| 239 | + " res[time_col] = res[time_col].dt.tz_localize(tz, ambiguous='infer')\n", |
231 | 240 | " extra_cols = df.columns.drop([id_col, time_col]).tolist()\n", |
232 | 241 | " if extra_cols:\n", |
233 | 242 | " check_col = extra_cols[0]\n", |
|
252 | 261 | "text/markdown": [ |
253 | 262 | "---\n", |
254 | 263 | "\n", |
255 | | - "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/preprocessing.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 264 | + "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/preprocessing.py#L58){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
256 | 265 | "\n", |
257 | 266 | "### fill_gaps\n", |
258 | 267 | "\n", |
|
278 | 287 | "text/plain": [ |
279 | 288 | "---\n", |
280 | 289 | "\n", |
281 | | - "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/preprocessing.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 290 | + "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/preprocessing.py#L58){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
282 | 291 | "\n", |
283 | 292 | "### fill_gaps\n", |
284 | 293 | "\n", |
|
1598 | 1607 | " # inferred frequency is the expected\n", |
1599 | 1608 | " first_serie = filled[filled['unique_id'] == 1]\n", |
1600 | 1609 | " if isinstance(freq, str):\n", |
1601 | | - " inferred_freq = pd.infer_freq(first_serie['ds'])\n", |
| 1610 | + " inferred_freq = pd.infer_freq(first_serie['ds'].dt.tz_localize(None))\n", |
1602 | 1611 | " assert inferred_freq == pd.tseries.frequencies.to_offset(freq)\n", |
1603 | 1612 | " else:\n", |
1604 | 1613 | " assert all(first_serie['ds'].diff().value_counts().index == [freq])\n", |
|
1628 | 1637 | " assert max_dates[0] == expected_end\n", |
1629 | 1638 | "\n", |
1630 | 1639 | "n_periods = 100\n", |
1631 | | - "freqs = ['YE', 'YS', 'ME', 'MS', 'W', 'W-TUE', 'D', 's', 'ms', 1, 2, '20D', '30s', '2YE', '3YS', '30min', 'B', '1h', 'QS-OCT', 'QE']\n", |
| 1640 | + "freqs = ['YE', 'YS', 'ME', 'MS', 'W', 'W-TUE', 'D', 's', 'ms', 1, 2, '20D', '30s', '2YE', '3YS', '30min', 'B', '1h', 'QS-NOV', 'QE']\n", |
1632 | 1641 | "try:\n", |
1633 | 1642 | " pd.tseries.frequencies.to_offset('YE')\n", |
1634 | 1643 | "except ValueError:\n", |
|
1640 | 1649 | " for f in freqs if isinstance(f, str)\n", |
1641 | 1650 | " ]\n", |
1642 | 1651 | "for freq in freqs:\n", |
1643 | | - " if isinstance(freq, (pd.offsets.BaseOffset, str)): \n", |
| 1652 | + " if isinstance(freq, (pd.offsets.BaseOffset, str)):\n", |
1644 | 1653 | " dates = pd.date_range('1900-01-01', periods=n_periods, freq=freq)\n", |
| 1654 | + " dates = dates.tz_localize('Europe/Berlin')\n", |
1645 | 1655 | " offset = pd.tseries.frequencies.to_offset(freq)\n", |
1646 | 1656 | " else:\n", |
1647 | 1657 | " dates = np.arange(0, freq * n_periods, freq, dtype=np.int64)\n", |
|
0 commit comments