@@ -120,9 +120,9 @@ def __init__(
120120 },
121121 )
122122 self .pre_y = xr .DataArray (
123- self .pre_y [:, 0 ],
124- dims = ["obs_ind" ],
125- coords = {"obs_ind" : self .datapre .index },
123+ self .pre_y , # Keep 2D shape
124+ dims = ["obs_ind" , "treated_units" ],
125+ coords = {"obs_ind" : self .datapre .index , "treated_units" : [ "unit_0" ] },
126126 )
127127 self .post_X = xr .DataArray (
128128 self .post_X ,
@@ -133,17 +133,22 @@ def __init__(
133133 },
134134 )
135135 self .post_y = xr .DataArray (
136- self .post_y [:, 0 ],
137- dims = ["obs_ind" ],
138- coords = {"obs_ind" : self .datapost .index },
136+ self .post_y , # Keep 2D shape
137+ dims = ["obs_ind" , "treated_units" ],
138+ coords = {"obs_ind" : self .datapost .index , "treated_units" : [ "unit_0" ] },
139139 )
140140
141141 # fit the model to the observed (pre-intervention) data
142142 if isinstance (self .model , PyMCModel ):
143- COORDS = {"coeffs" : self .labels , "obs_ind" : np .arange (self .pre_X .shape [0 ])}
143+ COORDS = {
144+ "coeffs" : self .labels ,
145+ "obs_ind" : np .arange (self .pre_X .shape [0 ]),
146+ "treated_units" : ["unit_0" ],
147+ }
144148 self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
145149 elif isinstance (self .model , RegressorMixin ):
146- self .model .fit (X = self .pre_X , y = self .pre_y )
150+ # For OLS models, use 1D y data
151+ self .model .fit (X = self .pre_X , y = self .pre_y .isel (treated_units = 0 ))
147152 else :
148153 raise ValueError ("Model type not recognized" )
149154
@@ -155,8 +160,21 @@ def __init__(
155160
156161 # calculate the counterfactual
157162 self .post_pred = self .model .predict (X = self .post_X )
158- self .pre_impact = self .model .calculate_impact (self .pre_y , self .pre_pred )
159- self .post_impact = self .model .calculate_impact (self .post_y , self .post_pred )
163+
164+ # calculate impact - use appropriate y data format for each model type
165+ if isinstance (self .model , PyMCModel ):
166+ # PyMC models work with 2D data
167+ self .pre_impact = self .model .calculate_impact (self .pre_y , self .pre_pred )
168+ self .post_impact = self .model .calculate_impact (self .post_y , self .post_pred )
169+ elif isinstance (self .model , RegressorMixin ):
170+ # SKL models work with 1D data
171+ self .pre_impact = self .model .calculate_impact (
172+ self .pre_y .isel (treated_units = 0 ), self .pre_pred
173+ )
174+ self .post_impact = self .model .calculate_impact (
175+ self .post_y .isel (treated_units = 0 ), self .post_pred
176+ )
177+
160178 self .post_impact_cumulative = self .model .calculate_cumulative_impact (
161179 self .post_impact
162180 )
@@ -202,35 +220,53 @@ def _bayesian_plot(
202220 # pre-intervention period
203221 h_line , h_patch = plot_xY (
204222 self .datapre .index ,
205- self .pre_pred ["posterior_predictive" ].mu ,
223+ self .pre_pred ["posterior_predictive" ].mu . isel ( treated_units = 0 ) ,
206224 ax = ax [0 ],
207225 plot_hdi_kwargs = {"color" : "C0" },
208226 )
209227 handles = [(h_line , h_patch )]
210228 labels = ["Pre-intervention period" ]
211229
212- (h ,) = ax [0 ].plot (self .datapre .index , self .pre_y , "k." , label = "Observations" )
230+ (h ,) = ax [0 ].plot (
231+ self .datapre .index ,
232+ self .pre_y .isel (treated_units = 0 )
233+ if hasattr (self .pre_y , "isel" )
234+ else self .pre_y [:, 0 ],
235+ "k." ,
236+ label = "Observations" ,
237+ )
213238 handles .append (h )
214239 labels .append ("Observations" )
215240
216241 # post intervention period
217242 h_line , h_patch = plot_xY (
218243 self .datapost .index ,
219- self .post_pred ["posterior_predictive" ].mu ,
244+ self .post_pred ["posterior_predictive" ].mu . isel ( treated_units = 0 ) ,
220245 ax = ax [0 ],
221246 plot_hdi_kwargs = {"color" : "C1" },
222247 )
223248 handles .append ((h_line , h_patch ))
224249 labels .append (counterfactual_label )
225250
226- ax [0 ].plot (self .datapost .index , self .post_y , "k." )
251+ ax [0 ].plot (
252+ self .datapost .index ,
253+ self .post_y .isel (treated_units = 0 )
254+ if hasattr (self .post_y , "isel" )
255+ else self .post_y [:, 0 ],
256+ "k." ,
257+ )
227258 # Shaded causal effect
259+ post_pred_mu = (
260+ az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
261+ .isel (treated_units = 0 )
262+ .mean ("sample" )
263+ ) # Add .mean("sample") to get 1D array
228264 h = ax [0 ].fill_between (
229265 self .datapost .index ,
230- y1 = az . extract (
231- self .post_pred , group = "posterior_predictive" , var_names = "mu"
232- ). mean ( "sample" ),
233- y2 = np . squeeze ( self .post_y ) ,
266+ y1 = post_pred_mu ,
267+ y2 = self .post_y . isel ( treated_units = 0 )
268+ if hasattr ( self . post_y , "isel" )
269+ else self .post_y [:, 0 ] ,
234270 color = "C0" ,
235271 alpha = 0.25 ,
236272 )
@@ -239,28 +275,28 @@ def _bayesian_plot(
239275
240276 ax [0 ].set (
241277 title = f"""
242- Pre-intervention Bayesian $R^2$: { round_num (self .score . r2 , round_to )}
243- (std = { round_num (self .score . r2_std , round_to )} )
278+ Pre-intervention Bayesian $R^2$: { round_num (self .score [ "unit_0_r2" ] , round_to )}
279+ (std = { round_num (self .score [ "unit_0_r2_std" ] , round_to )} )
244280 """
245281 )
246282
247283 # MIDDLE PLOT -----------------------------------------------
248284 plot_xY (
249285 self .datapre .index ,
250- self .pre_impact ,
286+ self .pre_impact . isel ( treated_units = 0 ) ,
251287 ax = ax [1 ],
252288 plot_hdi_kwargs = {"color" : "C0" },
253289 )
254290 plot_xY (
255291 self .datapost .index ,
256- self .post_impact ,
292+ self .post_impact . isel ( treated_units = 0 ) ,
257293 ax = ax [1 ],
258294 plot_hdi_kwargs = {"color" : "C1" },
259295 )
260296 ax [1 ].axhline (y = 0 , c = "k" )
261297 ax [1 ].fill_between (
262298 self .datapost .index ,
263- y1 = self .post_impact .mean (["chain" , "draw" ]),
299+ y1 = self .post_impact .mean (["chain" , "draw" ]). isel ( treated_units = 0 ) ,
264300 color = "C0" ,
265301 alpha = 0.25 ,
266302 label = "Causal impact" ,
@@ -271,7 +307,7 @@ def _bayesian_plot(
271307 ax [2 ].set (title = "Cumulative Causal Impact" )
272308 plot_xY (
273309 self .datapost .index ,
274- self .post_impact_cumulative ,
310+ self .post_impact_cumulative . isel ( treated_units = 0 ) ,
275311 ax = ax [2 ],
276312 plot_hdi_kwargs = {"color" : "C1" },
277313 )
@@ -387,27 +423,45 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
387423 pre_data ["prediction" ] = (
388424 az .extract (self .pre_pred , group = "posterior_predictive" , var_names = "mu" )
389425 .mean ("sample" )
426+ .isel (treated_units = 0 )
390427 .values
391428 )
392429 post_data ["prediction" ] = (
393430 az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
394431 .mean ("sample" )
432+ .isel (treated_units = 0 )
395433 .values
396434 )
397- pre_data [[ pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
435+ hdi_pre_pred = get_hdi_to_df (
398436 self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
399- ). set_index ( pre_data . index )
400- post_data [[ pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
437+ )
438+ hdi_post_pred = get_hdi_to_df (
401439 self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
440+ )
441+ # Select the single unit from the MultiIndex results
442+ pre_data [[pred_lower_col , pred_upper_col ]] = hdi_pre_pred .xs (
443+ "unit_0" , level = "treated_units"
444+ ).set_index (pre_data .index )
445+ post_data [[pred_lower_col , pred_upper_col ]] = hdi_post_pred .xs (
446+ "unit_0" , level = "treated_units"
402447 ).set_index (post_data .index )
403448
404- pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
405- post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
406- pre_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
407- self .pre_impact , hdi_prob = hdi_prob
449+ pre_data ["impact" ] = (
450+ self .pre_impact .mean (dim = ["chain" , "draw" ]).isel (treated_units = 0 ).values
451+ )
452+ post_data ["impact" ] = (
453+ self .post_impact .mean (dim = ["chain" , "draw" ])
454+ .isel (treated_units = 0 )
455+ .values
456+ )
457+ hdi_pre_impact = get_hdi_to_df (self .pre_impact , hdi_prob = hdi_prob )
458+ hdi_post_impact = get_hdi_to_df (self .post_impact , hdi_prob = hdi_prob )
459+ # Select the single unit from the MultiIndex results
460+ pre_data [[impact_lower_col , impact_upper_col ]] = hdi_pre_impact .xs (
461+ "unit_0" , level = "treated_units"
408462 ).set_index (pre_data .index )
409- post_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
410- self . post_impact , hdi_prob = hdi_prob
463+ post_data [[impact_lower_col , impact_upper_col ]] = hdi_post_impact . xs (
464+ "unit_0" , level = "treated_units"
411465 ).set_index (post_data .index )
412466
413467 self .plot_data = pd .concat ([pre_data , post_data ])
0 commit comments