Commit 6676610

add gputree support (#291)
* add gputree support
* formatting
* Improve gputree support and docs
1 parent e688c4a commit 6676610

File tree

4 files changed: +41 additions, -18 deletions


README.md

Lines changed: 14 additions & 12 deletions
```diff
@@ -213,18 +213,20 @@ Some of the calculations for the dashboard such as calculating SHAP (interaction
 and permutation importances can be slow for large datasets and complicated models.
 There are a few tricks to make this less painful:
 
-1. Switching off the interactions tab (`shap_interaction=False`) and disabling
-permutation importances (`no_permutations=True`). Especially SHAP interaction
-values can be very slow to calculate, and often are not needed for analysis.
-For permutation importances you can set the `n_jobs` parameter to speed up
-the calculation in parallel.
-2. Calculate approximate shap values. You can pass approximate=True as a shap parameter by
-passing `shap_kwargs=dict(approximate=True)` to the explainer initialization.
-4. Storing the explainer. The calculated properties are only calculated once
-for each instance, however each time when you instantiate a new explainer
-instance they will have to be recalculated. You can store them with
-`explainer.dump("explainer.joblib")` and load with e.g.
-`ClassifierExplainer.from_file("explainer.joblib")`. All calculated properties
+1. Switching off the interactions tab (`shap_interaction=False`) and disabling
+permutation importances (`no_permutations=True`). Especially SHAP interaction
+values can be very slow to calculate, and often are not needed for analysis.
+For permutation importances you can set the `n_jobs` parameter to speed up
+the calculation in parallel.
+2. Calculate approximate shap values. You can pass approximate=True as a shap parameter by
+passing `shap_kwargs=dict(approximate=True)` to the explainer initialization.
+3. Use GPU Tree SHAP by passing `shap='gputree'` when your model supports it.
+This requires an NVIDIA GPU and a CUDA-enabled SHAP build (see the SHAP docs).
+4. Storing the explainer. The calculated properties are only calculated once
+for each instance, however each time when you instantiate a new explainer
+instance they will have to be recalculated. You can store them with
+`explainer.dump("explainer.joblib")` and load with e.g.
+`ClassifierExplainer.from_file("explainer.joblib")`. All calculated properties
 are stored along with the explainer.
 5. Using a smaller (test) dataset, or using smaller decision trees.
 TreeShap computational complexity is `O(TLD^2)`, where `T` is the
```
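The `O(TLD^2)` complexity mentioned in the README text above can be made concrete with a quick back-of-the-envelope sketch (a hypothetical cost model for comparison only, not a benchmark): in this model, halving tree depth `D` cuts the estimated cost by 4x, while halving the number of trees `T` only cuts it by 2x.

```python
def treeshap_cost(n_trees: int, n_leaves: int, max_depth: int) -> int:
    """Relative TreeShap cost per explained row: O(T * L * D^2)."""
    return n_trees * n_leaves * max_depth**2

# Hypothetical ensemble: 100 trees, 64 leaves each, depth 8
base = treeshap_cost(100, 64, 8)

# Halving depth is quadratic in the cost model...
assert treeshap_cost(100, 64, 4) * 4 == base
# ...while halving the number of trees is only linear.
assert treeshap_cost(50, 64, 8) * 2 == base
```

In practice the leaf count `L` itself shrinks with `D` (at most `2^D` leaves per tree), so shallower trees tend to help even more than this fixed-`L` comparison suggests.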

RELEASE_NOTES.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -9,6 +9,9 @@
 - Preserve categorical dtypes during permutation importance shuffles and PDP grid generation to prevent dtype-related model errors (e.g., LightGBM).
 - Align categorical/boolean dtypes for user-provided `X_row` inputs and add dtype alignment tests.
 
+### Improvements
+- Add support for GPU Tree SHAP explainers via `shap='gputree'` (requires CUDA-enabled SHAP).
+
 ## Version 0.5.4:
 
 ### Breaking Changes
```
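The dtype-preservation note above concerns shuffles that silently upcast categorical columns to `object`, which breaks models such as LightGBM. A minimal sketch of a dtype-preserving shuffle (an illustrative helper, not the library's actual implementation):

```python
import numpy as np
import pandas as pd

def shuffle_preserving_dtype(X: pd.DataFrame, col: str, seed: int = 0) -> pd.DataFrame:
    """Shuffle one column (as for permutation importance), keeping its dtype intact."""
    out = X.copy()
    rng = np.random.default_rng(seed)
    shuffled = out[col].to_numpy()[rng.permutation(len(out))]
    # Re-cast to the original dtype so categorical columns stay categorical
    out[col] = pd.Series(shuffled, index=out.index).astype(X[col].dtype)
    return out

df = pd.DataFrame({"fruit": pd.Categorical(["apple", "pear", "plum", "pear"])})
shuffled = shuffle_preserving_dtype(df, "fruit")
assert str(shuffled["fruit"].dtype) == "category"
```

Without the final `astype`, assigning the shuffled numpy array back would leave an `object` column, which is exactly the class of error the release note describes.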

explainerdashboard/dashboard_components/overview_components.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -231,8 +231,8 @@ def __init__(
         self.importance_type = "shap"
         if self.description is None:
             self.description = """
-        Shows the features sorted from most important to least important. Can
-        be either sorted by absolute SHAP value (average absolute impact of
+        Shows the features sorted from most important to least important. Can
+        be either sorted by absolute SHAP value (average absolute impact of
         the feature on final prediction) or by permutation importance (how much
         does the model get worse when you shuffle this feature, rendering it
         useless?).
@@ -647,7 +647,7 @@ def __init__(
         of observations and how these observations would change with this
         feature (gridlines). The average effect is shown in grey. The effect
         of changing the feature for a single {self.explainer.index_name} is
-        shown in blue. You can adjust how many observations to sample for the
+        shown in blue. You can adjust how many observations to sample for the
         average, how many gridlines to show, and how many points along the
         x-axis to calculate model predictions for (gridpoints).
         """
```

explainerdashboard/explainers.py

Lines changed: 21 additions & 3 deletions
```diff
@@ -330,9 +330,9 @@ def __init__(
                 "sklearn-compatible NeuralNet wrapper are supported for now! "
                 "See https://github.com/skorch-dev/skorch"
             )
-        assert shap in ["tree", "linear", "deep", "kernel", "skorch"], (
-            "ERROR! Only shap='guess', 'tree', 'linear', ' kernel' or 'skorch' are "
-            " supported for now!"
+        assert shap in ["tree", "linear", "deep", "kernel", "skorch", "gputree"], (
+            "ERROR! Only shap='guess', 'tree', 'linear', ' kernel', 'skorch' "
+            "or 'gputree' are supported for now!"
         )
         self.shap = shap
         if self.shap in {"kernel", "skorch", "linear"}:
@@ -1276,6 +1276,24 @@ def model_predict(data_asarray):
                 if self.X_background is not None
                 else shap.sample(self.X, 50),
             )
+        elif self.shap == "gputree":
+            print(
+                "Generating self.shap_explainer = shap.GPUTreeExplainer(model, X). "
+                "Make sure you have a CUDA-enabled GPU and a CUDA-built SHAP "
+                "installed. See https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/explainers/GPUTree.html#"  # noqa: E501
+            )
+            X_data = self.X_background if self.X_background is not None else self.X
+            if hasattr(shap, "explainers") and hasattr(shap.explainers, "GPUTree"):
+                explainer_cls = shap.explainers.GPUTree
+            elif hasattr(shap, "GPUTreeExplainer"):
+                explainer_cls = shap.GPUTreeExplainer
+            else:
+                raise ValueError(
+                    "shap does not expose GPUTreeExplainer. "
+                    "Please install a CUDA-enabled SHAP build that includes "
+                    "GPUTree support."
+                )
+            self._shap_explainer = explainer_cls(self.model, X_data)
         return self._shap_explainer
 
     @insert_pos_label
```
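The `hasattr` chain added in explainers.py above is a version-compatibility shim: newer SHAP releases expose `shap.explainers.GPUTree`, while some older CUDA builds expose `shap.GPUTreeExplainer` at the top level. The lookup order can be exercised without a GPU by substituting stub modules (the stubs below are hypothetical stand-ins for real shap builds):

```python
from types import SimpleNamespace

def resolve_gputree(shap_module):
    """Return the GPUTree explainer class from whichever API the build exposes."""
    sub = getattr(shap_module, "explainers", None)
    if sub is not None and hasattr(sub, "GPUTree"):
        return sub.GPUTree                   # modern API: shap.explainers.GPUTree
    if hasattr(shap_module, "GPUTreeExplainer"):
        return shap_module.GPUTreeExplainer  # older top-level API
    raise ValueError(
        "shap does not expose GPUTreeExplainer; install a CUDA-enabled build."
    )

# Stub "shap" modules standing in for the two API shapes (no CUDA needed):
new_build = SimpleNamespace(explainers=SimpleNamespace(GPUTree="new-api-class"))
old_build = SimpleNamespace(GPUTreeExplainer="old-api-class")

assert resolve_gputree(new_build) == "new-api-class"
assert resolve_gputree(old_build) == "old-api-class"
```

Resolving the class through `hasattr` rather than importing it directly lets the non-GPU code paths keep working on SHAP builds compiled without GPUTree support, with a clear error only when `shap='gputree'` is actually requested.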
