Skip to content

Commit 44da1d7

Browse files
committed
🐛 fix wrong aggregation when shap_values dim > 2
1 parent 8d09bce commit 44da1d7

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

powershap/shap_wrappers/shap_explainer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,7 @@ def explain(
171171
Shap_values = np.abs(Shap_values)
172172

173173
if len(np.shape(Shap_values)) > 2:
174-
# Shap_values = np.max(Shap_values, axis=0)
175-
Shap_values = np.max(Shap_values, axis=0).T
174+
Shap_values = np.max(Shap_values, axis=-1)
176175

177176
# TODO: consider to convert to even float16?
178177
Shap_values = np.mean(Shap_values, axis=0).astype("float32")

0 commit comments

Comments
 (0)