Skip to content

Commit fa5f5af

Browse files
committed
massive simplification of the fix
1 parent 89636f8 commit fa5f5af

File tree

1 file changed

+9
-23
lines changed

1 file changed

+9
-23
lines changed

causalpy/plot_utils.py

Lines changed: 9 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -94,30 +94,16 @@ def get_hdi_to_df(
94 94
The size of the HDI, default is 0.94
95 95
"""
96 96
hdi_result = az.hdi(x, hdi_prob=hdi_prob)
97-
hdi_df = hdi_result.to_dataframe().unstack(level="hdi")
98 97

99-
# Handle MultiIndex columns from unstack operation
100-
# After unstack, we may have MultiIndex like: [('mu', 'lower'), ('mu', 'higher'), ('coord', 'lower'), ('coord', 'higher')]
101-
# We need to extract only the data variable columns (first level), not coordinate columns
102-
if isinstance(hdi_df.columns, pd.MultiIndex):
103-
# Get the name of the data variable (should be at level 0)
104-
# For xarray DataArrays, the variable name is typically at index 0
105-
data_var_names = hdi_df.columns.get_level_values(0).unique()
98+
# Get the data variable name (typically 'mu' or 'x')
99+
# We select only the data variable column to exclude coordinates like 'treated_units'
100+
data_var = list(hdi_result.data_vars)[0]
106 101

107-
# Filter to include only actual data variables (excluding coordinate names that became columns)
108-
# The data variable is typically the one that was originally in the DataArray/Dataset
109-
# For simple cases, it's often just the first unique value
110-
if len(data_var_names) > 1:
111-
# Find the numeric data variable (not string coordinates)
112-
for var_name in data_var_names:
113-
if (
114-
hdi_df[(var_name, hdi_df.columns.get_level_values(1)[0])].dtype
115-
!= "object"
116-
):
117-
hdi_df = hdi_df[var_name]
118-
break
119-
else:
120-
# Only one variable, select it
121-
hdi_df = hdi_df[data_var_names[0]]
102+
# Convert to DataFrame, select only the data variable column, then unstack
103+
# This prevents coordinate values (like 'treated_agg') from appearing as columns
104+
hdi_df = hdi_result[data_var].to_dataframe()[[data_var]].unstack(level="hdi")
105+
106+
# Remove the top level of column MultiIndex to get just 'lower' and 'higher'
107+
hdi_df.columns = hdi_df.columns.droplevel(0)
122 108

123 109
return hdi_df

0 commit comments

Comments
 (0)