Skip to content

Commit 2d93972

Browse files
jarverhajarverha
authored andcommitted
inhomogeneous fix
1 parent 9e07c2e commit 2d93972

File tree

1 file changed

+28
-24
lines changed

1 file changed

+28
-24
lines changed

powershap/utils.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -61,30 +61,34 @@ def powerSHAP_statistical_analysis(
6161
required_iterations.append(0)
6262
effect_size.append(0)
6363
power_list.append(0)
64-
try:
65-
processed_shaps_df = pd.DataFrame(
66-
data=np.hstack(
67-
[
68-
np.reshape(shaps_df.mean().values, (-1, 1)),
69-
np.reshape(np.array(p_values), (len(p_values), 1)),
70-
np.reshape(np.array(effect_size), (len(effect_size), 1)),
71-
np.reshape(np.array(power_list), (len(power_list), 1)),
72-
np.reshape(np.array(required_iterations), (len(required_iterations), 1)),
73-
]
74-
),
75-
columns=[
76-
"impact",
77-
"p_value",
78-
"effect_size",
79-
"power_" + str(power_alpha) + "_alpha",
80-
str(power_req_iterations) + "_power_its_req",
81-
],
82-
index=shaps_df.mean().index,
83-
)
84-
except ValueError as e:
85-
# If a ValueError occurs, print the error and append a placeholder.
86-
print(f"failed with error: {e}")
87-
print(f"required iterations is = {required_iterations}")
64+
65+
#This code is required because if statsmodels does not converge, due to numpy issues (see closed issue for inhomogeneous shape) the code does not work otherwise
66+
flattened_required_iterations = []
67+
for item in required_iterations:
68+
if isinstance(item, np.ndarray):
69+
flattened_required_iterations.extend(item.tolist())
70+
else:
71+
flattened_required_iterations.append(item)
72+
73+
processed_shaps_df = pd.DataFrame(
74+
data=np.hstack(
75+
[
76+
np.reshape(shaps_df.mean().values, (-1, 1)),
77+
np.reshape(np.array(p_values), (len(p_values), 1)),
78+
np.reshape(np.array(effect_size), (len(effect_size), 1)),
79+
np.reshape(np.array(power_list), (len(power_list), 1)),
80+
np.reshape(np.array(flattened_required_iterations), (len(flattened_required_iterations), 1)),
81+
]
82+
),
83+
columns=[
84+
"impact",
85+
"p_value",
86+
"effect_size",
87+
"power_" + str(power_alpha) + "_alpha",
88+
str(power_req_iterations) + "_power_its_req",
89+
],
90+
index=shaps_df.mean().index,
91+
)
8892

8993
processed_shaps_df = processed_shaps_df.reindex(
9094
processed_shaps_df.impact.abs().sort_values(ascending=False).index

0 commit comments

Comments
 (0)