-
-
Notifications
You must be signed in to change notification settings - Fork 331
New feature: Lag or windows features grouped by #727
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4d653a9
4e9d849
7f40391
b476748
02c59bd
0dd92cc
47de2d6
dd43c27
7459811
12aa825
c3bee66
67725dc
9cb01ea
72ce43c
9d999b0
b7b8bc9
ba375a4
90f08f4
66baa75
92f996d
152c037
5343e50
ef1eaa8
09db782
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
import datetime | ||
from typing import List, Union | ||
from typing import Dict, List, Union | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
@@ -475,7 +475,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series = None): | |
threshold_cat = self.threshold | ||
|
||
# Compute the PSI by looping over the features | ||
self.psi_values_ = {} | ||
self.psi_values_: Dict = {} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We resolved this in a different PR. Could we remove this change from here, please? |
||
self.features_to_drop_ = [] | ||
|
||
# Compute PSI for numerical features | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
) | ||
from feature_engine._docstrings.init_parameters.all_trasnformers import ( | ||
_drop_original_docstring, | ||
_group_by_docstring, | ||
_missing_values_docstring, | ||
_variables_numerical_docstring, | ||
) | ||
|
@@ -32,6 +33,7 @@ | |
n_features_in_=_n_features_in_docstring, | ||
fit=_fit_not_learn_docstring, | ||
fit_transform=_fit_transform_docstring, | ||
group_by=_group_by_docstring, | ||
) | ||
class LagFeatures(BaseForecastTransformer): | ||
""" | ||
|
@@ -74,6 +76,8 @@ class LagFeatures(BaseForecastTransformer): | |
|
||
{drop_original} | ||
|
||
{group_by} | ||
|
||
Attributes | ||
---------- | ||
variables_: | ||
|
@@ -127,6 +131,7 @@ def __init__( | |
sort_index: bool = True, | ||
missing_values: str = "raise", | ||
drop_original: bool = False, | ||
group_by: Union[None, int, str, List[Union[str, int]]] = None, | ||
) -> None: | ||
|
||
if not ( | ||
|
@@ -151,7 +156,7 @@ def __init__( | |
"sort_index takes values True and False." f"Got {sort_index} instead." | ||
) | ||
|
||
super().__init__(variables, missing_values, drop_original) | ||
super().__init__(variables, missing_values, drop_original, group_by) | ||
|
||
self.periods = periods | ||
self.freq = freq | ||
|
@@ -180,35 +185,57 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame: | |
if isinstance(self.freq, list): | ||
df_ls = [] | ||
for fr in self.freq: | ||
tmp = X[self.variables_].shift( | ||
freq=fr, | ||
axis=0, | ||
) | ||
if self.group_by: | ||
tmp = self._agg_freq_lags( | ||
grouped_df=X.groupby(self.group_by), | ||
freq=fr, | ||
) | ||
else: | ||
tmp = X[self.variables_].shift( | ||
freq=fr, | ||
axis=0, | ||
) | ||
df_ls.append(tmp) | ||
tmp = pd.concat(df_ls, axis=1) | ||
|
||
else: | ||
tmp = X[self.variables_].shift( | ||
freq=self.freq, | ||
axis=0, | ||
) | ||
if self.group_by: | ||
tmp = self._agg_freq_lags( | ||
grouped_df=X.groupby(self.group_by), | ||
freq=self.freq, | ||
) | ||
else: | ||
tmp = X[self.variables_].shift( | ||
freq=self.freq, | ||
axis=0, | ||
) | ||
|
||
else: | ||
if isinstance(self.periods, list): | ||
df_ls = [] | ||
for pr in self.periods: | ||
tmp = X[self.variables_].shift( | ||
periods=pr, | ||
axis=0, | ||
) | ||
if self.group_by: | ||
tmp = X.groupby(self.group_by)[self.variables_].shift( | ||
periods=pr, | ||
) | ||
else: | ||
tmp = X[self.variables_].shift( | ||
periods=pr, | ||
axis=0, | ||
) | ||
df_ls.append(tmp) | ||
tmp = pd.concat(df_ls, axis=1) | ||
|
||
else: | ||
tmp = X[self.variables_].shift( | ||
periods=self.periods, | ||
axis=0, | ||
) | ||
if self.group_by: | ||
tmp = X.groupby(self.group_by)[self.variables_].shift( | ||
periods=self.periods, | ||
) | ||
else: | ||
tmp = X[self.variables_].shift( | ||
periods=self.periods, | ||
axis=0, | ||
) | ||
|
||
tmp.columns = self._get_new_features_name() | ||
|
||
|
@@ -243,3 +270,30 @@ def _get_new_features_name(self) -> List: | |
] | ||
|
||
return feature_names | ||
|
||
def _agg_freq_lags( | ||
self, | ||
grouped_df: pd.core.groupby.generic.DataFrameGroupBy, | ||
freq: Union[str, List[str]], | ||
) -> Union[pd.Series, pd.DataFrame]: | ||
"""_summary_ | ||
|
||
Parameters | ||
---------- | ||
grouped_df : pd.core.groupby.generic.DataFrameGroupBy | ||
dataframe of groups | ||
freq : Union[str, List[str]] | ||
Offset to use from the tseries module or time rule. | ||
|
||
Returns | ||
------- | ||
Union[pd.Series, pd.DataFrame] | ||
lag feature or dataframe of lag features | ||
""" | ||
tmp_data = [] | ||
for _, group in grouped_df: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we need to loop over the groups to apply the lags? pandas applies the lags per group automatically. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I tried many ways to simplify this approach, but it only works when using |
||
original_idx = group.index | ||
tmp = group[self.variables_].shift(freq=freq).reindex(original_idx) | ||
tmp_data.append(tmp) | ||
tmp = pd.concat(tmp_data).sort_index() | ||
return tmp |
Uh oh!
There was an error while loading. Please reload this page.