Skip to content

Commit aa962e6

Browse files
jarverhajarverha
authored andcommitted
bugfixes for pipeline explainer
1 parent f0b39d0 commit aa962e6

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

powershap/shap_wrappers/shap_explainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,8 @@ def __init__(self, model: Any):
382382
383383
"""
384384
assert self.supports_model(model)
385-
self.shap_explainer = ShapExplainerFactory.get_explainer(model=ShapExplainer(model.steps[-1][1]))
385+
self.shap_explainer = ShapExplainerFactory.get_explainer(model=model.steps[-1][1])
386+
self.model = model
386387

387388
@staticmethod
388389
def supports_model(model) -> bool:
@@ -419,7 +420,7 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
419420

420421
def validate_data(self, _estimator, X, y, **kwargs):
421422
# The assumption here is that the used model is the limiting factor for validation of the data
422-
self.shap_explainer.validate_data(_estimator, X, y, **kwargs)
423+
return self.shap_explainer.validate_data(_estimator, X, y, **kwargs)
423424

424425
def _get_more_tags(self):
425426
return self.shap_explainer._get_more_tags()

0 commit comments

Comments
 (0)