Skip to content

Commit 90ed570

Browse files
jarverha
authored and committed
added support for fitkwargs in the fit (issue 45)
1 parent 45c8876 commit 90ed570

File tree

1 file changed

+5
-19
lines changed

1 file changed

+5
-19
lines changed

powershap/shap_wrappers/shap_explainer.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,6 @@ def supports_model(model) -> bool:
227227
supported_models = [CatBoostRegressor, CatBoostClassifier]
228228
return isinstance(model, tuple(supported_models))
229229

230-
# def validate_data(self, validate_data: Callable, X, y, **kwargs):
231-
# kwargs["force_all_finite"] = False # catboost allows NaNs and infs in X
232-
# kwargs["dtype"] = None # allow non-numeric data
233-
# return validate_data(self, validate_data, X, y, **kwargs)
234-
235230
def validate_data(self, _estimator, X, y, **kwargs):
236231
kwargs["ensure_all_finite"] = False # catboost allows NaNs and infs in X
237232
kwargs["dtype"] = None # allow non-numeric data
@@ -241,7 +236,7 @@ def validate_data(self, _estimator, X, y, **kwargs):
241236
def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -> np.array:
242237
# Fit the model
243238
PowerShap_model = self.model.copy().set_params(random_seed=random_seed)
244-
PowerShap_model.fit(X_train, Y_train, eval_set=(X_val, Y_val))
239+
PowerShap_model.fit(X_train, Y_train, eval_set=(X_val, Y_val), **kwargs)
245240
# Calculate the shap values
246241
C_explainer = shap.TreeExplainer(PowerShap_model)
247242
return C_explainer.shap_values(X_val)
@@ -264,10 +259,6 @@ def supports_model(model) -> bool:
264259
supported_models = [LGBMClassifier, LGBMRegressor]
265260
return isinstance(model, tuple(supported_models))
266261

267-
# def _validate_data(self, validate_data: Callable, X, y, **kwargs):
268-
# kwargs["force_all_finite"] = False # lgbm allows NaNs and infs in X
269-
# return super()._validate_data(validate_data, X, y, **kwargs)
270-
271262
def validate_data(self, _estimator, X, y, **kwargs):
272263
kwargs["ensure_all_finite"] = False # lgbm allows NaNs and infs in X
273264
return validate_data(_estimator, X, y, **kwargs)
@@ -277,7 +268,7 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
277268
# Why we need to use deepcopy and delete LGBM __deepcopy__
278269
# https://github.com/microsoft/LightGBM/issues/4085
279270
PowerShap_model = copy(self.model).set_params(random_seed=random_seed)
280-
PowerShap_model.fit(X_train, Y_train, eval_set=(X_val, Y_val))
271+
PowerShap_model.fit(X_train, Y_train, eval_set=(X_val, Y_val), **kwargs)
281272
# Calculate the shap values
282273
C_explainer = shap.TreeExplainer(PowerShap_model)
283274
return C_explainer.shap_values(X_val)
@@ -300,11 +291,6 @@ def supports_model(model) -> bool:
300291

301292
supported_models = [XGBClassifier, XGBRegressor]
302293
return isinstance(model, tuple(supported_models))
303-
304-
# def validate_data(self, validate_data: Callable, X, y, **kwargs):
305-
# kwargs["force_all_finite"] = False # xgboost allows NaNs and infs in X
306-
# kwargs["dtype"] = None # allow non-numeric data
307-
# return super().validate_data(validate_data, X, y, **kwargs)
308294

309295
def validate_data(self, _estimator, X, y, **kwargs):
310296
kwargs["ensure_all_finite"] = False # xgboost allows NaNs and infs in X
@@ -314,7 +300,7 @@ def validate_data(self, _estimator, X, y, **kwargs):
314300
def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -> np.array:
315301
# Fit the model
316302
PowerShap_model = copy(self.model).set_params(random_state=random_seed)
317-
PowerShap_model.fit(X_train, Y_train, eval_set=[(X_val, Y_val)])
303+
PowerShap_model.fit(X_train, Y_train, eval_set=[(X_val, Y_val)], **kwargs)
318304
# Calculate the shap values
319305
C_explainer = shap.TreeExplainer(PowerShap_model)
320306
return C_explainer.shap_values(X_val)
@@ -349,7 +335,7 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
349335

350336
# Fit the model
351337
PowerShap_model = clone(self.model).set_params(random_state=random_seed)
352-
PowerShap_model.fit(X_train, Y_train)
338+
PowerShap_model.fit(X_train, Y_train, **kwargs)
353339
# Calculate the shap values
354340
C_explainer = shap.TreeExplainer(PowerShap_model)
355341
return C_explainer.shap_values(X_val)
@@ -375,7 +361,7 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
375361
PowerShap_model = clone(self.model).set_params(random_state=random_seed)
376362
except Exception:
377363
PowerShap_model = clone(self.model)
378-
PowerShap_model.fit(X_train, Y_train)
364+
PowerShap_model.fit(X_train, Y_train, **kwargs)
379365

380366
# Calculate the shap values
381367
C_explainer = shap.explainers.Linear(PowerShap_model, X_train)

0 commit comments

Comments (0)