@@ -227,11 +227,6 @@ def supports_model(model) -> bool:
227227 supported_models = [CatBoostRegressor , CatBoostClassifier ]
228228 return isinstance (model , tuple (supported_models ))
229229
230- # def validate_data(self, validate_data: Callable, X, y, **kwargs):
231- # kwargs["force_all_finite"] = False # catboost allows NaNs and infs in X
232- # kwargs["dtype"] = None # allow non-numeric data
233- # return validate_data(self, validate_data, X, y, **kwargs)
234-
235230 def validate_data (self , _estimator , X , y , ** kwargs ):
236231 kwargs ["ensure_all_finite" ] = False # catboost allows NaNs and infs in X
237232 kwargs ["dtype" ] = None # allow non-numeric data
@@ -241,7 +236,7 @@ def validate_data(self, _estimator, X, y, **kwargs):
241236 def _fit_get_shap (self , X_train , Y_train , X_val , Y_val , random_seed , ** kwargs ) -> np .array :
242237 # Fit the model
243238 PowerShap_model = self .model .copy ().set_params (random_seed = random_seed )
244- PowerShap_model .fit (X_train , Y_train , eval_set = (X_val , Y_val ))
239+ PowerShap_model .fit (X_train , Y_train , eval_set = (X_val , Y_val ), ** kwargs )
245240 # Calculate the shap values
246241 C_explainer = shap .TreeExplainer (PowerShap_model )
247242 return C_explainer .shap_values (X_val )
@@ -264,10 +259,6 @@ def supports_model(model) -> bool:
264259 supported_models = [LGBMClassifier , LGBMRegressor ]
265260 return isinstance (model , tuple (supported_models ))
266261
267- # def _validate_data(self, validate_data: Callable, X, y, **kwargs):
268- # kwargs["force_all_finite"] = False # lgbm allows NaNs and infs in X
269- # return super()._validate_data(validate_data, X, y, **kwargs)
270-
271262 def validate_data (self , _estimator , X , y , ** kwargs ):
272263 kwargs ["ensure_all_finite" ] = False # lgbm allows NaNs and infs in X
273264 return validate_data (_estimator , X , y , ** kwargs )
@@ -277,7 +268,7 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
277268 # Why we need to use deepcopy and delete LGBM __deepcopy__
278269 # https://github.com/microsoft/LightGBM/issues/4085
279270 PowerShap_model = copy (self .model ).set_params (random_seed = random_seed )
280- PowerShap_model .fit (X_train , Y_train , eval_set = (X_val , Y_val ))
271+ PowerShap_model .fit (X_train , Y_train , eval_set = (X_val , Y_val ), ** kwargs )
281272 # Calculate the shap values
282273 C_explainer = shap .TreeExplainer (PowerShap_model )
283274 return C_explainer .shap_values (X_val )
@@ -300,11 +291,6 @@ def supports_model(model) -> bool:
300291
301292 supported_models = [XGBClassifier , XGBRegressor ]
302293 return isinstance (model , tuple (supported_models ))
303-
304- # def validate_data(self, validate_data: Callable, X, y, **kwargs):
305- # kwargs["force_all_finite"] = False # xgboost allows NaNs and infs in X
306- # kwargs["dtype"] = None # allow non-numeric data
307- # return super().validate_data(validate_data, X, y, **kwargs)
308294
309295 def validate_data (self , _estimator , X , y , ** kwargs ):
310296 kwargs ["ensure_all_finite" ] = False # xgboost allows NaNs and infs in X
@@ -314,7 +300,7 @@ def validate_data(self, _estimator, X, y, **kwargs):
314300 def _fit_get_shap (self , X_train , Y_train , X_val , Y_val , random_seed , ** kwargs ) -> np .array :
315301 # Fit the model
316302 PowerShap_model = copy (self .model ).set_params (random_state = random_seed )
317- PowerShap_model .fit (X_train , Y_train , eval_set = [(X_val , Y_val )])
303+ PowerShap_model .fit (X_train , Y_train , eval_set = [(X_val , Y_val )], ** kwargs )
318304 # Calculate the shap values
319305 C_explainer = shap .TreeExplainer (PowerShap_model )
320306 return C_explainer .shap_values (X_val )
@@ -349,7 +335,7 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
349335
350336 # Fit the model
351337 PowerShap_model = clone (self .model ).set_params (random_state = random_seed )
352- PowerShap_model .fit (X_train , Y_train )
338+ PowerShap_model .fit (X_train , Y_train , ** kwargs )
353339 # Calculate the shap values
354340 C_explainer = shap .TreeExplainer (PowerShap_model )
355341 return C_explainer .shap_values (X_val )
@@ -375,7 +361,7 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
375361 PowerShap_model = clone (self .model ).set_params (random_state = random_seed )
376362 except Exception :
377363 PowerShap_model = clone (self .model )
378- PowerShap_model .fit (X_train , Y_train )
364+ PowerShap_model .fit (X_train , Y_train , ** kwargs )
379365
380366 # Calculate the shap values
381367 C_explainer = shap .explainers .Linear (PowerShap_model , X_train )
0 commit comments