@@ -105,18 +105,27 @@ def algorithm(self) -> None:
105105 if isinstance (self .model , PyMCModel ):
106106 COORDS = {
107107 "coeffs" : self .labels ,
108- "obs_ind" : np .arange (self .pre_X .shape [0 ]),
108+ "obs_ind" : np .arange (self .data . X . sel ( period = "pre" ) .shape [0 ]),
109109 "treated_units" : ["unit_0" ],
110110 }
111- self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
111+ self .model .fit (
112+ X = self .data .X .sel (period = "pre" ),
113+ y = self .data .y .sel (period = "pre" ),
114+ coords = COORDS ,
115+ )
112116 elif isinstance (self .model , RegressorMixin ):
113117 # For OLS models, use 1D y data
114- self .model .fit (X = self .pre_X , y = self .pre_y .isel (treated_units = 0 ))
118+ self .model .fit (
119+ X = self .data .X .sel (period = "pre" ),
120+ y = self .data .y .sel (period = "pre" ).isel (treated_units = 0 ),
121+ )
115122 else :
116123 raise ValueError ("Model type not recognized" )
117124
118125 # 2. Score the goodness of fit to the pre-intervention data
119- self .score = self .model .score (X = self .pre_X , y = self .pre_y )
126+ self .score = self .model .score (
127+ X = self .data .X .sel (period = "pre" ), y = self .data .y .sel (period = "pre" )
128+ )
120129
121130 # 3. Generate predictions for the full dataset using unified approach
122131 # This creates predictions aligned with our complete time series
@@ -187,53 +196,26 @@ def _build_data(self, data: pd.DataFrame) -> xr.Dataset:
187196 # Create period coordinate based on treatment time
188197 period_coord = xr .where (data .index < self .treatment_time , "pre" , "post" )
189198
190- # Return complete time series as a single xarray Dataset
191- X_array = xr .DataArray (
192- np .asarray (X_full ),
193- dims = ["obs_ind" , "coeffs" ],
194- coords = {
195- "obs_ind" : data .index ,
196- "coeffs" : self .labels ,
197- "period" : ("obs_ind" , period_coord ),
198- },
199- )
200-
201- y_array = xr .DataArray (
202- np .asarray (y_full ),
203- dims = ["obs_ind" , "treated_units" ],
204- coords = {
205- "obs_ind" : data .index ,
206- "treated_units" : ["unit_0" ],
207- "period" : ("obs_ind" , period_coord ),
208- },
209- )
210-
211- # Create dataset and use set_xindex to make period selectable with .sel()
212- dataset = xr .Dataset ({"X" : X_array , "y" : y_array })
213- dataset = dataset .set_xindex ("period" )
214-
215- return dataset
216-
217- # Properties for pre/post intervention data access
218- @property
219- def pre_X (self ) -> xr .DataArray :
220- """Pre-intervention features."""
221- return self .data .X .sel (period = "pre" )
222-
223- @property
224- def pre_y (self ) -> xr .DataArray :
225- """Pre-intervention outcomes."""
226- return self .data .y .sel (period = "pre" )
227-
228- @property
229- def post_X (self ) -> xr .DataArray :
230- """Post-intervention features."""
231- return self .data .X .sel (period = "post" )
232-
233- @property
234- def post_y (self ) -> xr .DataArray :
235- """Post-intervention outcomes."""
236- return self .data .y .sel (period = "post" )
199+ # Return as an xarray.Dataset
200+ common_coords = {
201+ "obs_ind" : data .index ,
202+ "period" : ("obs_ind" , period_coord ),
203+ }
204+
205+ return xr .Dataset (
206+ {
207+ "X" : xr .DataArray (
208+ np .asarray (X_full ),
209+ dims = ["obs_ind" , "coeffs" ],
210+ coords = {** common_coords , "coeffs" : self .labels },
211+ ),
212+ "y" : xr .DataArray (
213+ np .asarray (y_full ),
214+ dims = ["obs_ind" , "treated_units" ],
215+ coords = {** common_coords , "treated_units" : ["unit_0" ]},
216+ ),
217+ }
218+ ).set_xindex ("period" )
237219
238220 def input_validation (self , data , treatment_time ):
239221 """Validate the input data and model formula for correctness"""
@@ -285,7 +267,7 @@ def _bayesian_plot(
285267 # TOP PLOT --------------------------------------------------
286268 # pre-intervention period
287269 h_line , h_patch = plot_xY (
288- self .pre_X .obs_ind ,
270+ self .data . X . sel ( period = "pre" ) .obs_ind ,
289271 pre_pred .mu .isel (treated_units = 0 ),
290272 ax = ax [0 ],
291273 plot_hdi_kwargs = {"color" : "C0" },
@@ -294,8 +276,8 @@ def _bayesian_plot(
294276 labels = ["Pre-intervention period" ]
295277
296278 (h ,) = ax [0 ].plot (
297- self .pre_X .obs_ind ,
298- self .pre_y .isel (treated_units = 0 ),
279+ self .data . X . sel ( period = "pre" ) .obs_ind ,
280+ self .data . y . sel ( period = "pre" ) .isel (treated_units = 0 ),
299281 "k." ,
300282 label = "Observations" ,
301283 )
@@ -304,7 +286,7 @@ def _bayesian_plot(
304286
305287 # post intervention period
306288 h_line , h_patch = plot_xY (
307- self .post_X .obs_ind ,
289+ self .data . X . sel ( period = "post" ) .obs_ind ,
308290 post_pred .mu .isel (treated_units = 0 ),
309291 ax = ax [0 ],
310292 plot_hdi_kwargs = {"color" : "C1" },
@@ -313,17 +295,17 @@ def _bayesian_plot(
313295 labels .append (counterfactual_label )
314296
315297 ax [0 ].plot (
316- self .post_X .obs_ind ,
317- self .post_y .isel (treated_units = 0 ),
298+ self .data . X . sel ( period = "post" ) .obs_ind ,
299+ self .data . y . sel ( period = "post" ) .isel (treated_units = 0 ),
318300 "k." ,
319301 )
320302
321303 # Shaded causal effect - use direct calculation
322304 post_pred_mu = post_pred .mu .mean (dim = ["chain" , "draw" ]).isel (treated_units = 0 )
323305 h = ax [0 ].fill_between (
324- self .post_X .obs_ind ,
306+ self .data . X . sel ( period = "post" ) .obs_ind ,
325307 y1 = post_pred_mu ,
326- y2 = self .post_y .isel (treated_units = 0 ),
308+ y2 = self .data . y . sel ( period = "post" ) .isel (treated_units = 0 ),
327309 color = "C0" ,
328310 alpha = 0.25 ,
329311 )
@@ -339,20 +321,20 @@ def _bayesian_plot(
339321
340322 # MIDDLE PLOT -----------------------------------------------
341323 plot_xY (
342- self .pre_X .obs_ind ,
324+ self .data . X . sel ( period = "pre" ) .obs_ind ,
343325 self .impact .sel (period = "pre" ).isel (treated_units = 0 ),
344326 ax = ax [1 ],
345327 plot_hdi_kwargs = {"color" : "C0" },
346328 )
347329 plot_xY (
348- self .post_X .obs_ind ,
330+ self .data . X . sel ( period = "post" ) .obs_ind ,
349331 self .impact .sel (period = "post" ).isel (treated_units = 0 ),
350332 ax = ax [1 ],
351333 plot_hdi_kwargs = {"color" : "C1" },
352334 )
353335 ax [1 ].axhline (y = 0 , c = "k" )
354336 ax [1 ].fill_between (
355- self .post_X .obs_ind ,
337+ self .data . X . sel ( period = "post" ) .obs_ind ,
356338 y1 = self .impact .sel (period = "post" )
357339 .mean (["chain" , "draw" ])
358340 .isel (treated_units = 0 ),
@@ -365,7 +347,7 @@ def _bayesian_plot(
365347 # BOTTOM PLOT -----------------------------------------------
366348 ax [2 ].set (title = "Cumulative Causal Impact" )
367349 plot_xY (
368- self .post_X .obs_ind ,
350+ self .data . X . sel ( period = "post" ) .obs_ind ,
369351 self .post_impact_cumulative .isel (treated_units = 0 ),
370352 ax = ax [2 ],
371353 plot_hdi_kwargs = {"color" : "C1" },
@@ -424,12 +406,18 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
424406 pre_pred = self .predictions .sel (period = "pre" )
425407 post_pred = self .predictions .sel (period = "post" )
426408
427- ax [0 ].plot (self .pre_X .obs_ind , self .pre_y , "k." )
428- ax [0 ].plot (self .post_X .obs_ind , self .post_y , "k." )
409+ ax [0 ].plot (
410+ self .data .X .sel (period = "pre" ).obs_ind , self .data .y .sel (period = "pre" ), "k."
411+ )
412+ ax [0 ].plot (
413+ self .data .X .sel (period = "post" ).obs_ind , self .data .y .sel (period = "post" ), "k."
414+ )
429415
430- ax [0 ].plot (self .pre_X .obs_ind , pre_pred , c = "k" , label = "model fit" )
431416 ax [0 ].plot (
432- self .post_X .obs_ind ,
417+ self .data .X .sel (period = "pre" ).obs_ind , pre_pred , c = "k" , label = "model fit"
418+ )
419+ ax [0 ].plot (
420+ self .data .X .sel (period = "post" ).obs_ind ,
433421 post_pred ,
434422 label = counterfactual_label ,
435423 ls = ":" ,
@@ -439,31 +427,35 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
439427 title = f"$R^2$ on pre-intervention data = { round_num (self .score , round_to )} "
440428 )
441429
442- ax [1 ].plot (self .pre_X .obs_ind , self .impact .sel (period = "pre" ), "k." )
443430 ax [1 ].plot (
444- self .post_X .obs_ind ,
431+ self .data .X .sel (period = "pre" ).obs_ind , self .impact .sel (period = "pre" ), "k."
432+ )
433+ ax [1 ].plot (
434+ self .data .X .sel (period = "post" ).obs_ind ,
445435 self .impact .sel (period = "post" ),
446436 "k." ,
447437 label = counterfactual_label ,
448438 )
449439 ax [1 ].axhline (y = 0 , c = "k" )
450440 ax [1 ].set (title = "Causal Impact" )
451441
452- ax [2 ].plot (self .post_X .obs_ind , self .post_impact_cumulative , c = "k" )
442+ ax [2 ].plot (
443+ self .data .X .sel (period = "post" ).obs_ind , self .post_impact_cumulative , c = "k"
444+ )
453445 ax [2 ].axhline (y = 0 , c = "k" )
454446 ax [2 ].set (title = "Cumulative Causal Impact" )
455447
456448 # Shaded causal effect
457449 ax [0 ].fill_between (
458- self .post_X .obs_ind ,
450+ self .data . X . sel ( period = "post" ) .obs_ind ,
459451 y1 = np .squeeze (post_pred ),
460- y2 = np .squeeze (self .post_y ),
452+ y2 = np .squeeze (self .data . y . sel ( period = "post" ) ),
461453 color = "C0" ,
462454 alpha = 0.25 ,
463455 label = "Causal impact" ,
464456 )
465457 ax [1 ].fill_between (
466- self .post_X .obs_ind ,
458+ self .data . X . sel ( period = "post" ) .obs_ind ,
467459 y1 = np .squeeze (self .impact .sel (period = "post" )),
468460 color = "C0" ,
469461 alpha = 0.25 ,
0 commit comments