import numpy as np
from sklearn.base import is_classifier, is_regressor
from sklearn.svm import SVR, SVC
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.feature_selection import SelectFromModel
@@ -165,6 +165,7 @@ def test_explain_weights(iris_train):
165165 for _expl in res :
166166 assert "petal width (cm)" in _expl
167167
168+
168169def test_pandas_xgboost_support (iris_train ):
169170 xgboost = pytest .importorskip ('xgboost' )
170171 pd = pytest .importorskip ('pandas' )
@@ -175,3 +176,17 @@ def test_pandas_xgboost_support(iris_train):
175176 est .fit (X , y )
176177 # we expect no exception to be raised here when using xgboost with pd.DataFrame
177178 perm = PermutationImportance (est ).fit (X , y )
179+
180+
def test_cv_sample_weight(iris_train):
    """Fitting with cv= and sample_weight must not raise, and an all-ones
    weight vector must give the same importances as passing no weights.
    """
    X, y, feature_names, target_names = iris_train
    weights_ones = np.ones(len(y))
    model = RandomForestClassifier(random_state=42)

    # we expect no exception to be raised when passing weights with a CV
    perm_weights = PermutationImportance(model, cv=5, random_state=42).fit(
        X, y, sample_weight=weights_ones)
    perm = PermutationImportance(model, cv=5, random_state=42).fit(X, y)

    # passing a vector of weights filled with one should be the same as
    # passing no weights (identical RNG seed -> identical permutations)
    assert np.array_equal(perm.feature_importances_,
                          perm_weights.feature_importances_)
0 commit comments