We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c966a88 commit f694cd6Copy full SHA for f694cd6
powershap/shap_wrappers/shap_explainer.py
@@ -171,10 +171,15 @@ def explain(
171
Shap_values = np.abs(Shap_values)
172
173
if len(np.shape(Shap_values)) > 2:
174
+ # SHAPE: (n_samples, n_features, n_outputs)
175
+ assert len(np.shape(Shap_values)) == 3, "Shap values should be 3D"
176
+ # in case of multi-output, we take the max of the outputs as the shap value
177
Shap_values = np.max(Shap_values, axis=-1)
178
+ # new shape = (n_samples, n_features)
179
180
# TODO: consider to convert to even float16?
181
Shap_values = np.mean(Shap_values, axis=0).astype("float32")
182
+ # new shape = (n_features,)
183
184
shaps += [Shap_values]
185
0 commit comments