Skip to content

Commit c8208d8

Browse files
Merge pull request #52 from predict-idlab/bug_transpose
Fix wrong aggregation of shap values
2 parents 8d09bce + f694cd6 commit c8208d8

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,7 @@ clean:
1919
rm -rf .coverage
2020
rm -rf .ruff_cache
2121
rm -rf catboost_info
22+
23+
.PHONY: test
24+
test:
25+
poetry run pytest --cov=powershap tests/

powershap/shap_wrappers/shap_explainer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,15 @@ def explain(
171171
Shap_values = np.abs(Shap_values)
172172

173173
if len(np.shape(Shap_values)) > 2:
174-
# Shap_values = np.max(Shap_values, axis=0)
175-
Shap_values = np.max(Shap_values, axis=0).T
174+
# SHAPE: (n_samples, n_features, n_outputs)
175+
assert len(np.shape(Shap_values)) == 3, "Shap values should be 3D"
176+
# in case of multi-output, we take the max of the outputs as the shap value
177+
Shap_values = np.max(Shap_values, axis=-1)
178+
# new shape = (n_samples, n_features)
176179

177180
# TODO: consider to convert to even float16?
178181
Shap_values = np.mean(Shap_values, axis=0).astype("float32")
182+
# new shape = (n_features,)
179183

180184
shaps += [Shap_values]
181185

0 commit comments

Comments
 (0)