Skip to content

Commit 017c738

Browse files
authored
Merge pull request #359 from rg2410/issue-358
Sample weight sliced to work with cross validation, issue #358
2 parents 4839d19 + 729e557 commit 017c738

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

eli5/sklearn/permutation_importance.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,12 @@ def _cv_scores_importances(self, X, y, groups=None, **fit_params):
214214
cv = check_cv(self.cv, y, is_classifier(self.estimator))
215215
feature_importances = [] # type: List
216216
base_scores = [] # type: List[float]
217+
weights = fit_params.pop('sample_weight', None)
218+
fold_fit_params = fit_params.copy()
217219
for train, test in cv.split(X, y, groups):
218-
est = clone(self.estimator).fit(X[train], y[train], **fit_params)
220+
if weights is not None:
221+
fold_fit_params['sample_weight'] = weights[train]
222+
est = clone(self.estimator).fit(X[train], y[train], **fold_fit_params)
219223
score_func = partial(self.scorer_, est)
220224
_base_score, _importances = self._get_score_importances(
221225
score_func, X[test], y[test])

tests/test_sklearn_permutation_importance.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from sklearn.base import is_classifier, is_regressor
55
from sklearn.svm import SVR, SVC
6-
from sklearn.ensemble import RandomForestRegressor
6+
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
77
from sklearn.model_selection import train_test_split, cross_val_score
88
from sklearn.pipeline import make_pipeline
99
from sklearn.feature_selection import SelectFromModel
@@ -165,6 +165,7 @@ def test_explain_weights(iris_train):
165165
for _expl in res:
166166
assert "petal width (cm)" in _expl
167167

168+
168169
def test_pandas_xgboost_support(iris_train):
169170
xgboost = pytest.importorskip('xgboost')
170171
pd = pytest.importorskip('pandas')
@@ -175,3 +176,17 @@ def test_pandas_xgboost_support(iris_train):
175176
est.fit(X, y)
176177
# we expect no exception to be raised here when using xgboost with pd.DataFrame
177178
perm = PermutationImportance(est).fit(X, y)
179+
180+
181+
def test_cv_sample_weight(iris_train):
    """Check that ``sample_weight`` can be passed to ``fit`` when ``cv`` is set.

    Fits ``PermutationImportance`` with ``cv=5`` twice on the same data and
    with the same ``random_state``: once with a vector of unit sample
    weights and once with no weights at all. The first fit must not raise,
    and both fits must yield the same feature importances.
    """
    X, y, feature_names, target_names = iris_train
    weights_ones = np.ones(len(y))
    model = RandomForestClassifier(random_state=42)

    # we expect no exception to be raised when passing weights with a CV
    perm_weights = PermutationImportance(model, cv=5, random_state=42).fit(
        X, y, sample_weight=weights_ones)
    perm = PermutationImportance(model, cv=5, random_state=42).fit(X, y)

    # Passing a vector of weights filled with ones should be the same as
    # passing no weights. Compare with a floating-point tolerance rather than
    # exact equality, so the check does not depend on bit-identical
    # arithmetic and reports the mismatching values when it fails.
    np.testing.assert_allclose(
        perm_weights.feature_importances_, perm.feature_importances_)

0 commit comments

Comments
 (0)