55from sklearn .base import is_classifier
66from sklearn .metrics import get_scorer
77from sklearn .model_selection import check_cv , cross_validate
8- from sklearn .utils .validation import check_random_state
8+ from sklearn .utils .validation import check_random_state , _check_sample_weight
99
1010from feature_engine ._docstrings .fit_attributes import (
1111 _feature_names_in_docstring ,
@@ -185,16 +185,25 @@ def __init__(
185185 self .cv = cv
186186 self .random_state = random_state
187187
188- def fit (self , X : pd .DataFrame , y : pd .Series ):
188+ def fit (
189+ self ,
190+ X : pd .DataFrame ,
191+ y : pd .Series ,
192+ sample_weight : Union [np .array , pd .Series , List ] = None ,
193+ ):
189194 """
190195 Find the important features.
191196
192197 Parameters
193198 ----------
194199 X: pandas dataframe of shape = [n_samples, n_features]
195200 The input dataframe.
201+
196202 y: array-like of shape (n_samples)
197203 Target variable. Required to train the estimator.
204+
205+ sample_weight : array-like of shape (n_samples,), default=None
206+ Sample weights. If None, then samples are equally weighted.
198207 """
199208
200209 X , y = check_X_y (X , y )
@@ -203,6 +212,9 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
203212 X = X .reset_index (drop = True )
204213 y = y .reset_index (drop = True )
205214
215+ if sample_weight is not None :
216+ sample_weight = _check_sample_weight (sample_weight , X )
217+
206218 # If required exclude variables that are not in the input dataframe
207219 self ._confirm_variables (X )
208220
@@ -220,6 +232,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
220232 cv = self .cv ,
221233 return_estimator = True ,
222234 scoring = self .scoring ,
235+ fit_params = {"sample_weight" : sample_weight },
223236 )
224237
225238 # store initial model performance
0 commit comments