@@ -120,9 +120,9 @@ def __init__(
120120            },
121121        )
122122        self .pre_y  =  xr .DataArray (
123-             self .pre_y [:,  0 ], 
124-             dims = ["obs_ind" ],
125-             coords = {"obs_ind" : self .datapre .index },
123+             self .pre_y ,   # Keep 2D shape 
124+             dims = ["obs_ind" ,  "treated_units" ],
125+             coords = {"obs_ind" : self .datapre .index ,  "treated_units" : [ "unit_0" ] },
126126        )
127127        self .post_X  =  xr .DataArray (
128128            self .post_X ,
@@ -133,17 +133,22 @@ def __init__(
133133            },
134134        )
135135        self .post_y  =  xr .DataArray (
136-             self .post_y [:,  0 ], 
137-             dims = ["obs_ind" ],
138-             coords = {"obs_ind" : self .datapost .index },
136+             self .post_y ,   # Keep 2D shape 
137+             dims = ["obs_ind" ,  "treated_units" ],
138+             coords = {"obs_ind" : self .datapost .index ,  "treated_units" : [ "unit_0" ] },
139139        )
140140
141141        # fit the model to the observed (pre-intervention) data 
142142        if  isinstance (self .model , PyMCModel ):
143-             COORDS  =  {"coeffs" : self .labels , "obs_ind" : np .arange (self .pre_X .shape [0 ])}
143+             COORDS  =  {
144+                 "coeffs" : self .labels ,
145+                 "obs_ind" : np .arange (self .pre_X .shape [0 ]),
146+                 "treated_units" : ["unit_0" ],
147+             }
144148            self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
145149        elif  isinstance (self .model , RegressorMixin ):
146-             self .model .fit (X = self .pre_X , y = self .pre_y )
150+             # For OLS models, use 1D y data 
151+             self .model .fit (X = self .pre_X , y = self .pre_y .isel (treated_units = 0 ))
147152        else :
148153            raise  ValueError ("Model type not recognized" )
149154
@@ -155,8 +160,21 @@ def __init__(
155160
156161        # calculate the counterfactual 
157162        self .post_pred  =  self .model .predict (X = self .post_X )
158-         self .pre_impact  =  self .model .calculate_impact (self .pre_y , self .pre_pred )
159-         self .post_impact  =  self .model .calculate_impact (self .post_y , self .post_pred )
163+ 
164+         # calculate impact - use appropriate y data format for each model type 
165+         if  isinstance (self .model , PyMCModel ):
166+             # PyMC models work with 2D data 
167+             self .pre_impact  =  self .model .calculate_impact (self .pre_y , self .pre_pred )
168+             self .post_impact  =  self .model .calculate_impact (self .post_y , self .post_pred )
169+         elif  isinstance (self .model , RegressorMixin ):
170+             # SKL models work with 1D data 
171+             self .pre_impact  =  self .model .calculate_impact (
172+                 self .pre_y .isel (treated_units = 0 ), self .pre_pred 
173+             )
174+             self .post_impact  =  self .model .calculate_impact (
175+                 self .post_y .isel (treated_units = 0 ), self .post_pred 
176+             )
177+ 
160178        self .post_impact_cumulative  =  self .model .calculate_cumulative_impact (
161179            self .post_impact 
162180        )
@@ -202,35 +220,53 @@ def _bayesian_plot(
202220        # pre-intervention period 
203221        h_line , h_patch  =  plot_xY (
204222            self .datapre .index ,
205-             self .pre_pred ["posterior_predictive" ].mu ,
223+             self .pre_pred ["posterior_predictive" ].mu . isel ( treated_units = 0 ) ,
206224            ax = ax [0 ],
207225            plot_hdi_kwargs = {"color" : "C0" },
208226        )
209227        handles  =  [(h_line , h_patch )]
210228        labels  =  ["Pre-intervention period" ]
211229
212-         (h ,) =  ax [0 ].plot (self .datapre .index , self .pre_y , "k." , label = "Observations" )
230+         (h ,) =  ax [0 ].plot (
231+             self .datapre .index ,
232+             self .pre_y .isel (treated_units = 0 )
233+             if  hasattr (self .pre_y , "isel" )
234+             else  self .pre_y [:, 0 ],
235+             "k." ,
236+             label = "Observations" ,
237+         )
213238        handles .append (h )
214239        labels .append ("Observations" )
215240
216241        # post intervention period 
217242        h_line , h_patch  =  plot_xY (
218243            self .datapost .index ,
219-             self .post_pred ["posterior_predictive" ].mu ,
244+             self .post_pred ["posterior_predictive" ].mu . isel ( treated_units = 0 ) ,
220245            ax = ax [0 ],
221246            plot_hdi_kwargs = {"color" : "C1" },
222247        )
223248        handles .append ((h_line , h_patch ))
224249        labels .append (counterfactual_label )
225250
226-         ax [0 ].plot (self .datapost .index , self .post_y , "k." )
251+         ax [0 ].plot (
252+             self .datapost .index ,
253+             self .post_y .isel (treated_units = 0 )
254+             if  hasattr (self .post_y , "isel" )
255+             else  self .post_y [:, 0 ],
256+             "k." ,
257+         )
227258        # Shaded causal effect 
259+         post_pred_mu  =  (
260+             az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
261+             .isel (treated_units = 0 )
262+             .mean ("sample" )
263+         )  # Add .mean("sample") to get 1D array 
228264        h  =  ax [0 ].fill_between (
229265            self .datapost .index ,
230-             y1 = az . extract ( 
231-                  self .post_pred ,  group = "posterior_predictive" ,  var_names = "mu" 
232-             ). mean ( "sample" ), 
233-             y2 = np . squeeze ( self .post_y ) ,
266+             y1 = post_pred_mu , 
267+             y2 = self .post_y . isel ( treated_units = 0 ) 
268+             if   hasattr ( self . post_y ,  "isel" ) 
269+             else   self .post_y [:,  0 ] ,
234270            color = "C0" ,
235271            alpha = 0.25 ,
236272        )
@@ -239,28 +275,28 @@ def _bayesian_plot(
239275
240276        ax [0 ].set (
241277            title = f""" 
242-             Pre-intervention Bayesian $R^2$: { round_num (self .score . r2 , round_to )}  
243-             (std = { round_num (self .score . r2_std , round_to )}  ) 
278+             Pre-intervention Bayesian $R^2$: { round_num (self .score [ "unit_0_r2" ] , round_to )}  
279+             (std = { round_num (self .score [ "unit_0_r2_std" ] , round_to )}  ) 
244280            """ 
245281        )
246282
247283        # MIDDLE PLOT ----------------------------------------------- 
248284        plot_xY (
249285            self .datapre .index ,
250-             self .pre_impact ,
286+             self .pre_impact . isel ( treated_units = 0 ) ,
251287            ax = ax [1 ],
252288            plot_hdi_kwargs = {"color" : "C0" },
253289        )
254290        plot_xY (
255291            self .datapost .index ,
256-             self .post_impact ,
292+             self .post_impact . isel ( treated_units = 0 ) ,
257293            ax = ax [1 ],
258294            plot_hdi_kwargs = {"color" : "C1" },
259295        )
260296        ax [1 ].axhline (y = 0 , c = "k" )
261297        ax [1 ].fill_between (
262298            self .datapost .index ,
263-             y1 = self .post_impact .mean (["chain" , "draw" ]),
299+             y1 = self .post_impact .mean (["chain" , "draw" ]). isel ( treated_units = 0 ) ,
264300            color = "C0" ,
265301            alpha = 0.25 ,
266302            label = "Causal impact" ,
@@ -271,7 +307,7 @@ def _bayesian_plot(
271307        ax [2 ].set (title = "Cumulative Causal Impact" )
272308        plot_xY (
273309            self .datapost .index ,
274-             self .post_impact_cumulative ,
310+             self .post_impact_cumulative . isel ( treated_units = 0 ) ,
275311            ax = ax [2 ],
276312            plot_hdi_kwargs = {"color" : "C1" },
277313        )
@@ -387,27 +423,45 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
387423            pre_data ["prediction" ] =  (
388424                az .extract (self .pre_pred , group = "posterior_predictive" , var_names = "mu" )
389425                .mean ("sample" )
426+                 .isel (treated_units = 0 )
390427                .values 
391428            )
392429            post_data ["prediction" ] =  (
393430                az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
394431                .mean ("sample" )
432+                 .isel (treated_units = 0 )
395433                .values 
396434            )
397-             pre_data [[ pred_lower_col ,  pred_upper_col ]]  =  get_hdi_to_df (
435+             hdi_pre_pred  =  get_hdi_to_df (
398436                self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob 
399-             ). set_index ( pre_data . index ) 
400-             post_data [[ pred_lower_col ,  pred_upper_col ]]  =  get_hdi_to_df (
437+             )
438+             hdi_post_pred  =  get_hdi_to_df (
401439                self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob 
440+             )
441+             # Select the single unit from the MultiIndex results 
442+             pre_data [[pred_lower_col , pred_upper_col ]] =  hdi_pre_pred .xs (
443+                 "unit_0" , level = "treated_units" 
444+             ).set_index (pre_data .index )
445+             post_data [[pred_lower_col , pred_upper_col ]] =  hdi_post_pred .xs (
446+                 "unit_0" , level = "treated_units" 
402447            ).set_index (post_data .index )
403448
404-             pre_data ["impact" ] =  self .pre_impact .mean (dim = ["chain" , "draw" ]).values 
405-             post_data ["impact" ] =  self .post_impact .mean (dim = ["chain" , "draw" ]).values 
406-             pre_data [[impact_lower_col , impact_upper_col ]] =  get_hdi_to_df (
407-                 self .pre_impact , hdi_prob = hdi_prob 
449+             pre_data ["impact" ] =  (
450+                 self .pre_impact .mean (dim = ["chain" , "draw" ]).isel (treated_units = 0 ).values 
451+             )
452+             post_data ["impact" ] =  (
453+                 self .post_impact .mean (dim = ["chain" , "draw" ])
454+                 .isel (treated_units = 0 )
455+                 .values 
456+             )
457+             hdi_pre_impact  =  get_hdi_to_df (self .pre_impact , hdi_prob = hdi_prob )
458+             hdi_post_impact  =  get_hdi_to_df (self .post_impact , hdi_prob = hdi_prob )
459+             # Select the single unit from the MultiIndex results 
460+             pre_data [[impact_lower_col , impact_upper_col ]] =  hdi_pre_impact .xs (
461+                 "unit_0" , level = "treated_units" 
408462            ).set_index (pre_data .index )
409-             post_data [[impact_lower_col , impact_upper_col ]] =  get_hdi_to_df (
410-                 self . post_impact ,  hdi_prob = hdi_prob 
463+             post_data [[impact_lower_col , impact_upper_col ]] =  hdi_post_impact . xs (
464+                 "unit_0" ,  level = "treated_units" 
411465            ).set_index (post_data .index )
412466
413467            self .plot_data  =  pd .concat ([pre_data , post_data ])
0 commit comments