Skip to content

Commit 49fd37b

Browse files
authored
handle month ends for polars in offset_times (#42)
1 parent 534283c commit 49fd37b

File tree

2 files changed

+64
-15
lines changed

2 files changed

+64
-15
lines changed

nbs/processing.ipynb

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,11 +749,51 @@
749749
" elif isinstance(times, pl_Series) and isinstance(freq, str):\n",
750750
" total_offset = _multiply_pl_freq(freq, n)\n",
751751
" out = times.dt.offset_by(total_offset)\n",
752+
" if 'mo' in freq:\n",
753+
" next_days = times.dt.offset_by('1d')\n",
754+
" month_ends = (next_days.dt.month() != times.dt.month()).all()\n",
755+
" if month_ends:\n",
756+
" out = out.dt.month_end()\n",
752757
" else:\n",
753758
" raise ValueError(f\"Can't process the following combination {(type(times), type(freq))}\")\n",
754759
" return out"
755760
]
756761
},
762+
{
763+
"cell_type": "code",
764+
"execution_count": null,
765+
"id": "03a9f253-3753-4c11-8a7a-410ea924469a",
766+
"metadata": {},
767+
"outputs": [],
768+
"source": [
769+
"pd.testing.assert_index_equal(\n",
770+
" offset_times(pd.to_datetime(['2020-01-31', '2020-02-29', '2020-03-31']), 'M', 1),\n",
771+
" pd.Index(pd.to_datetime(['2020-02-29', '2020-03-31', '2020-04-30'])),\n",
772+
")\n",
773+
"pd.testing.assert_index_equal(\n",
774+
" offset_times(pd.to_datetime(['2020-01-01', '2020-02-01', '2020-03-01']), 'MS', 1),\n",
775+
" pd.Index(pd.to_datetime(['2020-02-01', '2020-03-01', '2020-04-01'])),\n",
776+
")"
777+
]
778+
},
779+
{
780+
"cell_type": "code",
781+
"execution_count": null,
782+
"id": "63ad919f-fea3-4f61-a766-6f952da8bf75",
783+
"metadata": {},
784+
"outputs": [],
785+
"source": [
786+
"#| polars\n",
787+
"pl.testing.assert_series_equal(\n",
788+
" offset_times(pl_Series([dt(2020, 1, 31), dt(2020, 2, 28), dt(2020, 3, 31)]), '1mo_saturating', 1),\n",
789+
" pl_Series([dt(2020, 2, 29), dt(2020, 3, 28), dt(2020, 4, 30)]),\n",
790+
")\n",
791+
"pl.testing.assert_series_equal(\n",
792+
" offset_times(pl_Series([dt(2020, 1, 31), dt(2020, 2, 29), dt(2020, 3, 31)]), '1mo_saturating', 1),\n",
793+
" pl_Series([dt(2020, 2, 29), dt(2020, 3, 31), dt(2020, 4, 30)]),\n",
794+
")"
795+
]
796+
},
757797
{
758798
"cell_type": "code",
759799
"execution_count": null,
@@ -1017,6 +1057,8 @@
10171057
" valid_idxs = np.repeat(cutoff_idxs + 1, h) + np.tile(np.arange(h), cutoff_idxs.size)\n",
10181058
" out_times.append(times[valid_idxs])\n",
10191059
" out_cutoffs.append(np.repeat(times[cutoff_idxs], h))\n",
1060+
" if isinstance(uids, pl_Series):\n",
1061+
" use_series = pl_Series(use_series)\n",
10201062
" out_ids.append(repeat(filter_with_mask(uids, use_series), h))\n",
10211063
" return df_constructor(\n",
10221064
" {\n",

utilsforecast/processing.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -328,13 +328,18 @@ def offset_times(
328328
elif isinstance(times, pl_Series) and isinstance(freq, str):
329329
total_offset = _multiply_pl_freq(freq, n)
330330
out = times.dt.offset_by(total_offset)
331+
if "mo" in freq:
332+
next_days = times.dt.offset_by("1d")
333+
month_ends = (next_days.dt.month() != times.dt.month()).all()
334+
if month_ends:
335+
out = out.dt.month_end()
331336
else:
332337
raise ValueError(
333338
f"Can't process the following combination {(type(times), type(freq))}"
334339
)
335340
return out
336341

337-
# %% ../nbs/processing.ipynb 38
342+
# %% ../nbs/processing.ipynb 40
338343
def offset_dates(
339344
dates: Union[Series, pd.Index],
340345
freq: Union[int, str, BaseOffset],
@@ -345,7 +350,7 @@ def offset_dates(
345350
)
346351
return offset_times(dates, freq, n)
347352

348-
# %% ../nbs/processing.ipynb 39
353+
# %% ../nbs/processing.ipynb 41
349354
def time_ranges(
350355
starts: Union[Series, pd.Index],
351356
freq: Union[int, str, BaseOffset],
@@ -384,7 +389,7 @@ def time_ranges(
384389
out = out.alias(starts.name)
385390
return out
386391

387-
# %% ../nbs/processing.ipynb 42
392+
# %% ../nbs/processing.ipynb 44
388393
def repeat(
389394
s: Union[Series, pd.Index, np.ndarray], n: Union[int, np.ndarray, Series]
390395
) -> Union[Series, pd.Index, np.ndarray]:
@@ -403,7 +408,7 @@ def repeat(
403408
out = out.reset_index(drop=True)
404409
return out
405410

406-
# %% ../nbs/processing.ipynb 45
411+
# %% ../nbs/processing.ipynb 47
407412
def cv_times(
408413
times: np.ndarray,
409414
uids: Union[Series, pd.Index],
@@ -435,6 +440,8 @@ def cv_times(
435440
)
436441
out_times.append(times[valid_idxs])
437442
out_cutoffs.append(np.repeat(times[cutoff_idxs], h))
443+
if isinstance(uids, pl_Series):
444+
use_series = pl_Series(use_series)
438445
out_ids.append(repeat(filter_with_mask(uids, use_series), h))
439446
return df_constructor(
440447
{
@@ -444,7 +451,7 @@ def cv_times(
444451
}
445452
)
446453

447-
# %% ../nbs/processing.ipynb 47
454+
# %% ../nbs/processing.ipynb 49
448455
def group_by(df: Union[Series, DataFrame], by, maintain_order=False):
449456
if isinstance(df, (pd.Series, pd.DataFrame)):
450457
out = df.groupby(by, observed=True, sort=not maintain_order)
@@ -457,7 +464,7 @@ def group_by(df: Union[Series, DataFrame], by, maintain_order=False):
457464
out = df.groupby(by, maintain_order=maintain_order)
458465
return out
459466

460-
# %% ../nbs/processing.ipynb 48
467+
# %% ../nbs/processing.ipynb 50
461468
def group_by_agg(df: DataFrame, by, aggs, maintain_order=False) -> DataFrame:
462469
if isinstance(df, pd.DataFrame):
463470
out = group_by(df, by, maintain_order).agg(aggs).reset_index()
@@ -467,39 +474,39 @@ def group_by_agg(df: DataFrame, by, aggs, maintain_order=False) -> DataFrame:
467474
)
468475
return out
469476

470-
# %% ../nbs/processing.ipynb 51
477+
# %% ../nbs/processing.ipynb 53
471478
def is_in(s: Series, collection) -> Series:
472479
if isinstance(s, pl_Series):
473480
out = s.is_in(collection)
474481
else:
475482
out = s.isin(collection)
476483
return out
477484

478-
# %% ../nbs/processing.ipynb 54
485+
# %% ../nbs/processing.ipynb 56
479486
def between(s: Series, lower: Series, upper: Series) -> Series:
480487
if isinstance(s, pd.Series):
481488
out = s.between(lower, upper)
482489
else:
483490
out = s.is_between(lower, upper)
484491
return out
485492

486-
# %% ../nbs/processing.ipynb 57
493+
# %% ../nbs/processing.ipynb 59
487494
def fill_null(df: DataFrame, mapping: Dict[str, Any]) -> DataFrame:
488495
if isinstance(df, pd.DataFrame):
489496
out = df.fillna(mapping)
490497
else:
491498
out = df.with_columns(*[pl.col(col).fill_null(v) for col, v in mapping.items()])
492499
return out
493500

494-
# %% ../nbs/processing.ipynb 60
501+
# %% ../nbs/processing.ipynb 62
495502
def cast(s: Series, dtype: type) -> Series:
496503
if isinstance(s, pd.Series):
497504
s = s.astype(dtype)
498505
else:
499506
s = s.cast(dtype)
500507
return s
501508

502-
# %% ../nbs/processing.ipynb 63
509+
# %% ../nbs/processing.ipynb 65
503510
def value_cols_to_numpy(
504511
df: DataFrame, id_col: str, time_col: str, target_col: str
505512
) -> np.ndarray:
@@ -510,7 +517,7 @@ def value_cols_to_numpy(
510517
data = data.astype(np.float32)
511518
return data
512519

513-
# %% ../nbs/processing.ipynb 64
520+
# %% ../nbs/processing.ipynb 66
514521
def process_df(
515522
df: DataFrame, id_col: str, time_col: str, target_col: str
516523
) -> Tuple[Series, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
@@ -558,7 +565,7 @@ def process_df(
558565
times = df[time_col].to_numpy()[last_idxs]
559566
return uids, times, data, indptr, sort_idxs
560567

561-
# %% ../nbs/processing.ipynb 66
568+
# %% ../nbs/processing.ipynb 68
562569
class DataFrameProcessor:
563570
def __init__(
564571
self,
@@ -575,7 +582,7 @@ def process(
575582
) -> Tuple[Series, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
576583
return process_df(df, self.id_col, self.time_col, self.target_col)
577584

578-
# %% ../nbs/processing.ipynb 70
585+
# %% ../nbs/processing.ipynb 72
579586
def _single_split(
580587
df: DataFrame,
581588
i_window: int,
@@ -635,7 +642,7 @@ def _single_split(
635642
)
636643
return cutoffs, train_mask, valid_mask
637644

638-
# %% ../nbs/processing.ipynb 71
645+
# %% ../nbs/processing.ipynb 73
639646
def backtest_splits(
640647
df: DataFrame,
641648
n_windows: int,

0 commit comments

Comments (0)