File tree Expand file tree Collapse file tree 2 files changed +10
-2
lines changed Expand file tree Collapse file tree 2 files changed +10
-2
lines changed Original file line number Diff line number Diff line change 1919 rm -rf .coverage
2020 rm -rf .ruff_cache
2121 rm -rf catboost_info
22+
23+ .PHONY : test
24+ test :
25+ poetry run pytest --cov=powershap tests/
Original file line number Diff line number Diff line change @@ -171,11 +171,15 @@ def explain(
171171 Shap_values = np .abs (Shap_values )
172172
173173 if len (np .shape (Shap_values )) > 2 :
174- # Shap_values = np.max(Shap_values, axis=0)
175- Shap_values = np .max (Shap_values , axis = 0 ).T
174+ # SHAPE: (n_samples, n_features, n_outputs)
175+ assert len (np .shape (Shap_values )) == 3 , "Shap values should be 3D"
176+ # in case of multi-output, we take the max of the outputs as the shap value
177+ Shap_values = np .max (Shap_values , axis = - 1 )
178+ # new shape = (n_samples, n_features)
176179
177180 # TODO: consider to convert to even float16?
178181 Shap_values = np .mean (Shap_values , axis = 0 ).astype ("float32" )
182+ # new shape = (n_features,)
179183
180184 shaps += [Shap_values ]
181185
You can’t perform that action at this time.
0 commit comments