Skip to content

Commit 4d55ed7

Browse files
authored
allow sample weight in shuffle features (#662)
* allow sample weight in shuffle features * add additional tag with test that fails for sample_weights
1 parent feddb06 commit 4d55ed7

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

feature_engine/selection/shuffle_features.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sklearn.base import is_classifier
66
from sklearn.metrics import get_scorer
77
from sklearn.model_selection import check_cv, cross_validate
8-
from sklearn.utils.validation import check_random_state
8+
from sklearn.utils.validation import check_random_state, _check_sample_weight
99

1010
from feature_engine._docstrings.fit_attributes import (
1111
_feature_names_in_docstring,
@@ -185,16 +185,25 @@ def __init__(
185185
self.cv = cv
186186
self.random_state = random_state
187187

188-
def fit(self, X: pd.DataFrame, y: pd.Series):
188+
def fit(
189+
self,
190+
X: pd.DataFrame,
191+
y: pd.Series,
192+
sample_weight: Union[np.array, pd.Series, List] = None,
193+
):
189194
"""
190195
Find the important features.
191196
192197
Parameters
193198
----------
194199
X: pandas dataframe of shape = [n_samples, n_features]
195200
The input dataframe.
201+
196202
y: array-like of shape (n_samples)
197203
Target variable. Required to train the estimator.
204+
205+
sample_weight : array-like of shape (n_samples,), default=None
206+
Sample weights. If None, then samples are equally weighted.
198207
"""
199208

200209
X, y = check_X_y(X, y)
@@ -203,6 +212,9 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
203212
X = X.reset_index(drop=True)
204213
y = y.reset_index(drop=True)
205214

215+
if sample_weight is not None:
216+
sample_weight = _check_sample_weight(sample_weight, X)
217+
206218
# If required exclude variables that are not in the input dataframe
207219
self._confirm_variables(X)
208220

@@ -220,6 +232,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
220232
cv=self.cv,
221233
return_estimator=True,
222234
scoring=self.scoring,
235+
fit_params={"sample_weight": sample_weight},
223236
)
224237

225238
# store initial model performance

feature_engine/tags.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def _return_tags():
1414
# The test aims to check that the check_X_y function from sklearn is
1515
# working, but we do not use that check, because we work with dfs.
1616
"check_transformer_data_not_an_array": "Ok to fail",
17+
"check_sample_weights_not_an_array": "Ok to fail",
1718
# TODO: we probably need the test below!!
1819
"check_methods_sample_order_invariance": "Test does not work on dataframes",
1920
# TODO: we probably need the test below!!

tests/test_selection/test_shuffle_features.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,21 @@ def test_automatic_variable_selection(df_test):
134134
]
135135
# test transform output
136136
pd.testing.assert_frame_equal(sel.transform(X), Xtransformed)
137+
138+
139+
def test_sample_weights():
140+
X = pd.DataFrame(
141+
dict(
142+
x1=[1000, 2000, 1000, 1000, 2000, 3000],
143+
x2=[1000, 2000, 1000, 1000, 2000, 3000],
144+
)
145+
)
146+
y = pd.Series([1, 0, 0, 1, 1, 0])
147+
148+
sbs = SelectByShuffling(
149+
RandomForestClassifier(random_state=42), cv=2, random_state=42
150+
)
151+
152+
sample_weight = [1000, 2000, 1000, 1000, 2000, 3000]
153+
sbs.fit_transform(X, y, sample_weight=sample_weight)
154+
assert sbs.initial_model_performance_ == 0.125

0 commit comments

Comments
 (0)