Skip to content

Commit f694cd6

Browse files
committed
🖍️ improving docs
1 parent c966a88 commit f694cd6

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

powershap/shap_wrappers/shap_explainer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,15 @@ def explain(
171171
Shap_values = np.abs(Shap_values)
172172

173173
if len(np.shape(Shap_values)) > 2:
174+
# SHAPE: (n_samples, n_features, n_outputs)
175+
assert len(np.shape(Shap_values)) == 3, "Shap values should be 3D"
176+
# in case of multi-output, we take the max of the outputs as the shap value
174177
Shap_values = np.max(Shap_values, axis=-1)
178+
# new shape = (n_samples, n_features)
175179

176180
# TODO: consider to convert to even float16?
177181
Shap_values = np.mean(Shap_values, axis=0).astype("float32")
182+
# new shape = (n_features,)
178183

179184
shaps += [Shap_values]
180185

0 commit comments

Comments
 (0)