@@ -94,30 +94,16 @@ def get_hdi_to_df(
9494 The probability mass covered by the HDI (passed to ``az.hdi`` as ``hdi_prob``); default is 0.94
9595 """
9696 hdi_result = az .hdi (x , hdi_prob = hdi_prob )
97- hdi_df = hdi_result .to_dataframe ().unstack (level = "hdi" )
9897
99- # Handle MultiIndex columns from unstack operation
100- # After unstack, we may have MultiIndex like: [('mu', 'lower'), ('mu', 'higher'), ('coord', 'lower'), ('coord', 'higher')]
101- # We need to extract only the data variable columns (first level), not coordinate columns
102- if isinstance (hdi_df .columns , pd .MultiIndex ):
103- # Get the name of the data variable (should be at level 0)
104- # For xarray DataArrays, the variable name is typically at index 0
105- data_var_names = hdi_df .columns .get_level_values (0 ).unique ()
98+ # Get the data variable name (typically 'mu' or 'x')
99+ # We select only the data variable column to exclude coordinates like 'treated_units'
100+ data_var = list (hdi_result .data_vars )[0 ]
106101
107- # Filter to include only actual data variables (excluding coordinate names that became columns)
108- # The data variable is typically the one that was originally in the DataArray/Dataset
109- # For simple cases, it's often just the first unique value
110- if len (data_var_names ) > 1 :
111- # Find the numeric data variable (not string coordinates)
112- for var_name in data_var_names :
113- if (
114- hdi_df [(var_name , hdi_df .columns .get_level_values (1 )[0 ])].dtype
115- != "object"
116- ):
117- hdi_df = hdi_df [var_name ]
118- break
119- else :
120- # Only one variable, select it
121- hdi_df = hdi_df [data_var_names [0 ]]
102+ # Convert to DataFrame, select only the data variable column, then unstack
103+ # This prevents coordinate values (like 'treated_agg') from appearing as columns
104+ hdi_df = hdi_result [data_var ].to_dataframe ()[[data_var ]].unstack (level = "hdi" )
105+
106+ # Drop the top level of the column MultiIndex so that only the 'lower' and 'higher' columns remain
107+ hdi_df .columns = hdi_df .columns .droplevel (0 )
122108
123109 return hdi_df
0 commit comments