Skip to content

Commit 675172e

Browse files
authored
add ensure_sorted (#93)
1 parent 6c6f683 commit 675172e

File tree

3 files changed

+64
-14
lines changed

3 files changed

+64
-14
lines changed

nbs/processing.ipynb

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
"import re\n",
3434
"import reprlib\n",
3535
"import warnings\n",
36-
"from typing import Any, Dict, Generator, List, Optional, Tuple, Union\n",
36+
"from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union\n",
3737
"\n",
3838
"import numpy as np\n",
3939
"import pandas as pd\n",
@@ -242,7 +242,7 @@
242242
"source": [
243243
"for engine in engines:\n",
244244
" series = generate_series(2, engine=engine)\n",
245-
" x = np.random.rand(series.shape[0]) \n",
245+
" x = np.random.rand(series.shape[0])\n",
246246
" series = assign_columns(series, 'x', x)\n",
247247
" series = assign_columns(series, ['y', 'z'], np.vstack([x, x]).T)\n",
248248
" series = assign_columns(series, 'ones', 1)\n",
@@ -1607,7 +1607,38 @@
16071607
{
16081608
"cell_type": "code",
16091609
"execution_count": null,
1610-
"id": "d3e059e1-8c6f-41b1-b3c0-4ca40adf09e4",
1610+
"id": "033b0a8a-98b1-486b-9414-1f5ef698f80f",
1611+
"metadata": {},
1612+
"outputs": [],
1613+
"source": [
1614+
"#| export\n",
1615+
"def ensure_sorted(df: DataFrame, id_col: str, time_col: str) -> DataFrame:\n",
1616+
" sort_idxs = maybe_compute_sort_indices(df=df, id_col=id_col, time_col=time_col)\n",
1617+
" if sort_idxs is not None:\n",
1618+
" df = take_rows(df=df, idxs=sort_idxs)\n",
1619+
" return df"
1620+
]
1621+
},
1622+
{
1623+
"cell_type": "code",
1624+
"execution_count": null,
1625+
"id": "4de12264-0bd1-4eed-935b-7b7fb1cbebc0",
1626+
"metadata": {},
1627+
"outputs": [],
1628+
"source": [
1629+
"#| exporti\n",
1630+
"class _ProcessedDF(NamedTuple):\n",
1631+
" uids: Series\n",
1632+
" times: np.ndarray\n",
1633+
" data: np.ndarray\n",
1634+
" indptr: np.ndarray\n",
1635+
" sort_idxs: Optional[np.ndarray]"
1636+
]
1637+
},
1638+
{
1639+
"cell_type": "code",
1640+
"execution_count": null,
1641+
"id": "62293bd2-b921-40b2-b1af-25f0b8e55006",
16111642
"metadata": {},
16121643
"outputs": [],
16131644
"source": [
@@ -1617,7 +1648,7 @@
16171648
" id_col: str,\n",
16181649
" time_col: str,\n",
16191650
" target_col: Optional[str],\n",
1620-
") -> Tuple[Series, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:\n",
1651+
") -> _ProcessedDF:\n",
16211652
" \"\"\"Extract components from dataframe\n",
16221653
" \n",
16231654
" Parameters\n",
@@ -1660,7 +1691,7 @@
16601691
" data = data[sort_idxs]\n",
16611692
" last_idxs = sort_idxs[last_idxs]\n",
16621693
" times = df[time_col].to_numpy()[last_idxs]\n",
1663-
" return uids, times, data, indptr, sort_idxs "
1694+
" return _ProcessedDF(uids, times, data, indptr, sort_idxs)"
16641695
]
16651696
},
16661697
{

utilsforecast/_modidx.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@
8484
'utilsforecast/processing.py'),
8585
'utilsforecast.processing.DataFrameProcessor.process': ( 'processing.html#dataframeprocessor.process',
8686
'utilsforecast/processing.py'),
87+
'utilsforecast.processing._ProcessedDF': ( 'processing.html#_processeddf',
88+
'utilsforecast/processing.py'),
8789
'utilsforecast.processing._ensure_month_ends': ( 'processing.html#_ensure_month_ends',
8890
'utilsforecast/processing.py'),
8991
'utilsforecast.processing._multiply_pl_freq': ( 'processing.html#_multiply_pl_freq',
@@ -111,6 +113,8 @@
111113
'utilsforecast/processing.py'),
112114
'utilsforecast.processing.drop_index_if_pandas': ( 'processing.html#drop_index_if_pandas',
113115
'utilsforecast/processing.py'),
116+
'utilsforecast.processing.ensure_sorted': ( 'processing.html#ensure_sorted',
117+
'utilsforecast/processing.py'),
114118
'utilsforecast.processing.fill_null': ( 'processing.html#fill_null',
115119
'utilsforecast/processing.py'),
116120
'utilsforecast.processing.filter_with_mask': ( 'processing.html#filter_with_mask',

utilsforecast/processing.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
'filter_with_mask', 'is_nan', 'is_none', 'is_nan_or_none', 'match_if_categorical', 'vertical_concat',
66
'horizontal_concat', 'copy_if_pandas', 'join', 'drop_index_if_pandas', 'rename', 'sort', 'offset_times',
77
'offset_dates', 'time_ranges', 'repeat', 'cv_times', 'group_by', 'group_by_agg', 'is_in', 'between',
8-
'fill_null', 'cast', 'value_cols_to_numpy', 'make_future_dataframe', 'anti_join', 'process_df',
9-
'DataFrameProcessor', 'backtest_splits', 'add_insample_levels']
8+
'fill_null', 'cast', 'value_cols_to_numpy', 'make_future_dataframe', 'anti_join', 'ensure_sorted',
9+
'process_df', 'DataFrameProcessor', 'backtest_splits', 'add_insample_levels']
1010

1111
# %% ../nbs/processing.ipynb 2
1212
import re
1313
import reprlib
1414
import warnings
15-
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
15+
from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
1616

1717
import numpy as np
1818
import pandas as pd
@@ -626,12 +626,27 @@ def anti_join(df1: DataFrame, df2: DataFrame, on: Union[str, List[str]]) -> Data
626626
return out
627627

628628
# %% ../nbs/processing.ipynb 74
629+
def ensure_sorted(df: DataFrame, id_col: str, time_col: str) -> DataFrame:
    """Reorder the rows of `df` when `maybe_compute_sort_indices` asks for it.

    Parameters
    ----------
    df : DataFrame
        Input dataframe.
    id_col : str
        Column name forwarded to `maybe_compute_sort_indices` as `id_col`.
    time_col : str
        Column name forwarded to `maybe_compute_sort_indices` as `time_col`.

    Returns
    -------
    DataFrame
        `df` itself when no sort indices are returned, otherwise the result
        of `take_rows` applied with those indices.
    """
    order = maybe_compute_sort_indices(df=df, id_col=id_col, time_col=time_col)
    if order is None:
        # NOTE(review): presumably None means the rows are already ordered by
        # id and time -- confirm against maybe_compute_sort_indices.
        return df
    return take_rows(df=df, idxs=order)
634+
635+
# %% ../nbs/processing.ipynb 75
636+
class _ProcessedDF(NamedTuple):
    """Components returned by `process_df`.

    Being a NamedTuple, it can still be unpacked positionally as
    ``uids, times, data, indptr, sort_idxs``.
    """

    # NOTE(review): presumably the unique series identifiers -- confirm in process_df.
    uids: Series
    # Values of the time column picked at `last_idxs` in process_df
    # (presumably the last timestamp of each series -- confirm).
    times: np.ndarray
    # Array built by process_df; it is indexed with `sort_idxs` there when
    # reordering is applied.
    data: np.ndarray
    # NOTE(review): presumably offsets delimiting each series inside `data` -- confirm.
    indptr: np.ndarray
    # Optional row ordering; NOTE(review): None presumably means no reordering
    # was needed -- confirm in process_df.
    sort_idxs: Optional[np.ndarray]
642+
643+
# %% ../nbs/processing.ipynb 76
629644
def process_df(
630645
df: DataFrame,
631646
id_col: str,
632647
time_col: str,
633648
target_col: Optional[str],
634-
) -> Tuple[Series, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
649+
) -> _ProcessedDF:
635650
"""Extract components from dataframe
636651
637652
Parameters
@@ -674,9 +689,9 @@ def process_df(
674689
data = data[sort_idxs]
675690
last_idxs = sort_idxs[last_idxs]
676691
times = df[time_col].to_numpy()[last_idxs]
677-
return uids, times, data, indptr, sort_idxs
692+
return _ProcessedDF(uids, times, data, indptr, sort_idxs)
678693

679-
# %% ../nbs/processing.ipynb 76
694+
# %% ../nbs/processing.ipynb 78
680695
class DataFrameProcessor:
681696
def __init__(
682697
self,
@@ -693,7 +708,7 @@ def process(
693708
) -> Tuple[Series, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
694709
return process_df(df, self.id_col, self.time_col, self.target_col)
695710

696-
# %% ../nbs/processing.ipynb 81
711+
# %% ../nbs/processing.ipynb 83
697712
def _single_split(
698713
df: DataFrame,
699714
i_window: int,
@@ -758,7 +773,7 @@ def _single_split(
758773
)
759774
return cutoffs, train_mask, valid_mask
760775

761-
# %% ../nbs/processing.ipynb 82
776+
# %% ../nbs/processing.ipynb 84
762777
def backtest_splits(
763778
df: DataFrame,
764779
n_windows: int,
@@ -790,7 +805,7 @@ def backtest_splits(
790805
valid = filter_with_mask(df, valid_mask)
791806
yield cutoffs, train, valid
792807

793-
# %% ../nbs/processing.ipynb 86
808+
# %% ../nbs/processing.ipynb 88
794809
def add_insample_levels(
795810
df: DataFrame,
796811
models: List[str],

0 commit comments

Comments
 (0)