55 'filter_with_mask' , 'is_nan' , 'is_none' , 'is_nan_or_none' , 'match_if_categorical' , 'vertical_concat' ,
66 'horizontal_concat' , 'copy_if_pandas' , 'join' , 'drop_index_if_pandas' , 'rename' , 'sort' , 'offset_times' ,
77 'offset_dates' , 'time_ranges' , 'repeat' , 'cv_times' , 'group_by' , 'group_by_agg' , 'is_in' , 'between' ,
8- 'fill_null' , 'cast' , 'value_cols_to_numpy' , 'make_future_dataframe' , 'anti_join' , 'process_df ' ,
9- 'DataFrameProcessor' , 'backtest_splits' , 'add_insample_levels' ]
8+ 'fill_null' , 'cast' , 'value_cols_to_numpy' , 'make_future_dataframe' , 'anti_join' , 'ensure_sorted ' ,
9+ 'process_df' , ' DataFrameProcessor' , 'backtest_splits' , 'add_insample_levels' ]
1010
1111# %% ../nbs/processing.ipynb 2
1212import re
1313import reprlib
1414import warnings
15- from typing import Any , Dict , Generator , List , Optional , Tuple , Union
15+ from typing import Any , Dict , Generator , List , NamedTuple , Optional , Tuple , Union
1616
1717import numpy as np
1818import pandas as pd
@@ -626,12 +626,27 @@ def anti_join(df1: DataFrame, df2: DataFrame, on: Union[str, List[str]]) -> Data
626626 return out
627627
628628# %% ../nbs/processing.ipynb 74
629+ def ensure_sorted (df : DataFrame , id_col : str , time_col : str ) -> DataFrame :
630+ sort_idxs = maybe_compute_sort_indices (df = df , id_col = id_col , time_col = time_col )
631+ if sort_idxs is not None :
632+ df = take_rows (df = df , idxs = sort_idxs )
633+ return df
634+
635+ # %% ../nbs/processing.ipynb 75
636+ class _ProcessedDF (NamedTuple ):
637+ uids : Series
638+ times : np .ndarray
639+ data : np .ndarray
640+ indptr : np .ndarray
641+ sort_idxs : Optional [np .ndarray ]
642+
643+ # %% ../nbs/processing.ipynb 76
629644def process_df (
630645 df : DataFrame ,
631646 id_col : str ,
632647 time_col : str ,
633648 target_col : Optional [str ],
634- ) -> Tuple [ Series , np . ndarray , np . ndarray , np . ndarray , Optional [ np . ndarray ]] :
649+ ) -> _ProcessedDF :
635650 """Extract components from dataframe
636651
637652 Parameters
@@ -674,9 +689,9 @@ def process_df(
674689 data = data [sort_idxs ]
675690 last_idxs = sort_idxs [last_idxs ]
676691 times = df [time_col ].to_numpy ()[last_idxs ]
677- return uids , times , data , indptr , sort_idxs
692+ return _ProcessedDF ( uids , times , data , indptr , sort_idxs )
678693
679- # %% ../nbs/processing.ipynb 76
694+ # %% ../nbs/processing.ipynb 78
680695class DataFrameProcessor :
681696 def __init__ (
682697 self ,
@@ -693,7 +708,7 @@ def process(
693708 ) -> Tuple [Series , np .ndarray , np .ndarray , np .ndarray , Optional [np .ndarray ]]:
694709 return process_df (df , self .id_col , self .time_col , self .target_col )
695710
696- # %% ../nbs/processing.ipynb 81
711+ # %% ../nbs/processing.ipynb 83
697712def _single_split (
698713 df : DataFrame ,
699714 i_window : int ,
@@ -758,7 +773,7 @@ def _single_split(
758773 )
759774 return cutoffs , train_mask , valid_mask
760775
761- # %% ../nbs/processing.ipynb 82
776+ # %% ../nbs/processing.ipynb 84
762777def backtest_splits (
763778 df : DataFrame ,
764779 n_windows : int ,
@@ -790,7 +805,7 @@ def backtest_splits(
790805 valid = filter_with_mask (df , valid_mask )
791806 yield cutoffs , train , valid
792807
793- # %% ../nbs/processing.ipynb 86
808+ # %% ../nbs/processing.ipynb 88
794809def add_insample_levels (
795810 df : DataFrame ,
796811 models : List [str ],
0 commit comments