Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions explainerdashboard/explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,12 @@ def __init__(
"sklearn-compatible NeuralNet wrapper are supported for now! "
"See https://github.com/skorch-dev/skorch"
)
assert shap in ["tree", "linear", "deep", "kernel", "skorch"], (
"ERROR! Only shap='guess', 'tree', 'linear', ' kernel' or 'skorch' are "
" supported for now!"
assert shap in ["tree", "linear", "deep", "kernel", "skorch", "gputree"], (
"ERROR! Only shap='guess', 'tree', 'linear', ' kernel', 'skorch' "
"or 'gputree' are supported for now!"
)
self.shap = shap
if self.shap in {"kernel", "skorch", "linear"}:
if self.shap in {"kernel", "skorch", "linear", "gputree"}:
print(
f"WARNING: For shap='{self.shap}', shap interaction values can unfortunately "
"not be calculated!"
Expand Down Expand Up @@ -1123,6 +1123,13 @@ def model_predict(data_asarray):
if self.X_background is not None
else shap.sample(self.X, 50),
)
elif self.shap == "gputree":
print(
"Generating self.shap_explainer = shap.explainer.GPUTree(model, X)."
"Make sure you have a cuda enabled GPU and followed installation"
"instructions at https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/explainers/GPUTree.html#" # noqa: E501
)
self._shap_explainer = shap.explainers.GPUTree(self.model, self.X)
return self._shap_explainer

@insert_pos_label
Expand Down