@@ -123,59 +123,55 @@ def algorithm(self) -> None:
123123 if isinstance (self .model , PyMCModel ):
124124 # PyMC models expect xarray DataArrays
125125 self .predictions = self .model .predict (X = self .data .X )
126- # Add period coordinate to predictions - key insight for unified operations!
126+ # Add period coordinate to predictions - InferenceData handles multiple data arrays
127127 self .predictions = self .predictions .assign_coords (
128128 period = ("obs_ind" , self .data .period .data )
129129 )
130130 else :
131131 # Sklearn models expect numpy arrays
132132 pred_array = self .model .predict (X = self .data .X .values )
133- # Create xarray DataArray with period coordinate for unified operations
133+ # Create xarray DataArray with period coordinate
134134 self .predictions = xr .DataArray (
135135 pred_array ,
136136 dims = ["obs_ind" ],
137137 coords = {
138138 "obs_ind" : self .data .obs_ind ,
139139 "period" : ("obs_ind" , self .data .period .data ),
140140 },
141- )
141+ ). set_xindex ( "period" )
142142
143- # 4. Use native xarray operations on unified predictions with period coordinate
144- # No more manual indexing - leverage xarray's .where() operations!
143+ # 4. Calculate unified impact with period coordinate - no more splitting!
145144 if isinstance (self .model , PyMCModel ):
146- # For PyMC models, use .where() on the posterior_predictive dataset
147- pp = self .predictions .posterior_predictive
148- pre_pp = pp .where (pp .period == "pre" , drop = True )
149- post_pp = pp .where (pp .period == "post" , drop = True )
150-
151- # Create new InferenceData objects for pre/post with the filtered data
152- import arviz as az
153-
154- self .pre_pred = az .InferenceData (posterior_predictive = pre_pp )
155- self .post_pred = az .InferenceData (posterior_predictive = post_pp )
156-
157- self .pre_impact = self .model .calculate_impact (self .pre_y , self .pre_pred )
158- self .post_impact = self .model .calculate_impact (self .post_y , self .post_pred )
145+ # Calculate impact for the entire time series at once
146+ self .impact = self .model .calculate_impact (self .data .y , self .predictions )
147+ # Assign period coordinate to unified impact and set index
148+ self .impact = self .impact .assign_coords (
149+ period = ("obs_ind" , self .data .period .data )
150+ ).set_xindex ("period" )
159151 else :
160- # For sklearn models, same clean .where() approach
161- self .pre_pred = self .predictions .where (
162- self .predictions .period == "pre" , drop = True
163- )
164- self .post_pred = self .predictions .where (
165- self .predictions .period == "post" , drop = True
166- )
152+ # For sklearn: calculate unified impact as DataArray
153+ observed_values = self .data .y .isel (treated_units = 0 ).values
154+ predicted_values = self .predictions .values
155+ impact_values = observed_values - predicted_values
167156
168- self .pre_impact = self .model .calculate_impact (
169- self .pre_y .isel (treated_units = 0 ), self .pre_pred
170- )
171- self .post_impact = self .model .calculate_impact (
172- self .post_y .isel (treated_units = 0 ), self .post_pred
173- )
157+ self .impact = xr .DataArray (
158+ impact_values ,
159+ dims = ["obs_ind" ],
160+ coords = {
161+ "obs_ind" : self .data .obs_ind ,
162+ "period" : ("obs_ind" , self .data .period .data ),
163+ },
164+ ).set_xindex ("period" )
174165
175- # 4b. Calculate cumulative impact
176- self .post_impact_cumulative = self .model .calculate_cumulative_impact (
177- self .post_impact
178- )
166+ # 5. Calculate cumulative impact (only on post-intervention period)
167+ post_impact = self .impact .sel (period = "post" )
168+ if isinstance (self .model , PyMCModel ):
169+ self .post_impact_cumulative = self .model .calculate_cumulative_impact (
170+ post_impact
171+ )
172+ else :
173+ # For sklearn: simple cumulative sum
174+ self .post_impact_cumulative = post_impact .cumsum ()
179175
180176 def _build_data (self , data : pd .DataFrame ) -> xr .Dataset :
181177 """Build the experiment dataset as unified time series with period coordinate."""
@@ -198,6 +194,7 @@ def _build_data(self, data: pd.DataFrame) -> xr.Dataset:
198194 coords = {
199195 "obs_ind" : data .index ,
200196 "coeffs" : self .labels ,
197+ "period" : ("obs_ind" , period_coord ),
201198 },
202199 )
203200
@@ -207,35 +204,47 @@ def _build_data(self, data: pd.DataFrame) -> xr.Dataset:
207204 coords = {
208205 "obs_ind" : data .index ,
209206 "treated_units" : ["unit_0" ],
207+ "period" : ("obs_ind" , period_coord ),
210208 },
211209 )
212210
213- # Create dataset and add period as a coordinate
211+ # Create dataset and use set_xindex to make period selectable with .sel()
214212 dataset = xr .Dataset ({"X" : X_array , "y" : y_array })
215- dataset = dataset .assign_coords ( period = ( "obs_ind" , period_coord ) )
213+ dataset = dataset .set_xindex ( " period" )
216214
217215 return dataset
218216
219217 # Properties for pre/post intervention data access
@property
def pre_X(self) -> xr.DataArray:
    """Return the feature array restricted to the pre-intervention period.

    Selection is label-based on the ``period`` coordinate, so ``period``
    has to be an indexed coordinate of ``self.data`` for this to work.
    """
    features = self.data.X
    return features.sel(period="pre")
224222
@property
def pre_y(self) -> xr.DataArray:
    """Return the outcome array restricted to the pre-intervention period.

    Selection is label-based on the ``period`` coordinate, so ``period``
    has to be an indexed coordinate of ``self.data`` for this to work.
    """
    outcomes = self.data.y
    return outcomes.sel(period="pre")
229227
@property
def post_X(self) -> xr.DataArray:
    """Return the feature array restricted to the post-intervention period.

    Selection is label-based on the ``period`` coordinate, so ``period``
    has to be an indexed coordinate of ``self.data`` for this to work.
    """
    features = self.data.X
    return features.sel(period="post")
234232
@property
def post_y(self) -> xr.DataArray:
    """Return the outcome array restricted to the post-intervention period.

    Selection is label-based on the ``period`` coordinate, so ``period``
    has to be an indexed coordinate of ``self.data`` for this to work.
    """
    outcomes = self.data.y
    return outcomes.sel(period="post")
237+
238+ # Simple backward-compatible properties for impact only (still used in plotting)
@property
def pre_impact(self) -> xr.DataArray:
    """Return the pre-intervention slice of the unified impact.

    Kept as a read-only view for backward compatibility: the impact is
    stored once on ``self.impact`` for the whole series, and this property
    selects the entries whose ``period`` coordinate is ``"pre"``.
    """
    return self.impact.sel(period="pre")
243+
@property
def post_impact(self) -> xr.DataArray:
    """Return the post-intervention slice of the unified impact.

    Kept as a read-only view for backward compatibility: the impact is
    stored once on ``self.impact`` for the whole series, and this property
    selects the entries whose ``period`` coordinate is ``"post"``.
    """
    return self.impact.sel(period="post")
239248
240249 def input_validation (self , data , treatment_time ):
241250 """Validate the input data and model formula for correctness"""
@@ -266,19 +275,29 @@ def _bayesian_plot(
266275 self , round_to = None , ** kwargs
267276 ) -> tuple [plt .Figure , List [plt .Axes ]]:
268277 """
269- Plot the results
278+ Plot the results using unified predictions with period coordinates.
270279
271280 :param round_to:
272281 Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
273282 """
274283 counterfactual_label = "Counterfactual"
275284
276285 fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
286+
287+ # Extract pre/post predictions - InferenceData doesn't support .sel() with period
288+ # but .where() works fine with coordinates
289+ pre_pred = self .predictions ["posterior_predictive" ].where (
290+ self .predictions ["posterior_predictive" ].period == "pre" , drop = True
291+ )
292+ post_pred = self .predictions ["posterior_predictive" ].where (
293+ self .predictions ["posterior_predictive" ].period == "post" , drop = True
294+ )
295+
277296 # TOP PLOT --------------------------------------------------
278297 # pre-intervention period
279298 h_line , h_patch = plot_xY (
280299 self .pre_X .obs_ind ,
281- self . pre_pred [ "posterior_predictive" ] .mu .isel (treated_units = 0 ),
300+ pre_pred .mu .isel (treated_units = 0 ),
282301 ax = ax [0 ],
283302 plot_hdi_kwargs = {"color" : "C0" },
284303 )
@@ -287,9 +306,7 @@ def _bayesian_plot(
287306
288307 (h ,) = ax [0 ].plot (
289308 self .pre_X .obs_ind ,
290- self .pre_y .isel (treated_units = 0 )
291- if hasattr (self .pre_y , "isel" )
292- else self .pre_y [:, 0 ],
309+ self .pre_y .isel (treated_units = 0 ),
293310 "k." ,
294311 label = "Observations" ,
295312 )
@@ -299,7 +316,7 @@ def _bayesian_plot(
299316 # post intervention period
300317 h_line , h_patch = plot_xY (
301318 self .post_X .obs_ind ,
302- self . post_pred [ "posterior_predictive" ] .mu .isel (treated_units = 0 ),
319+ post_pred .mu .isel (treated_units = 0 ),
303320 ax = ax [0 ],
304321 plot_hdi_kwargs = {"color" : "C1" },
305322 )
@@ -308,23 +325,16 @@ def _bayesian_plot(
308325
309326 ax [0 ].plot (
310327 self .post_X .obs_ind ,
311- self .post_y .isel (treated_units = 0 )
312- if hasattr (self .post_y , "isel" )
313- else self .post_y [:, 0 ],
328+ self .post_y .isel (treated_units = 0 ),
314329 "k." ,
315330 )
316- # Shaded causal effect
317- post_pred_mu = (
318- self .post_pred ["posterior_predictive" ]
319- .mu .mean (dim = ["chain" , "draw" ])
320- .isel (treated_units = 0 )
321- )
331+
332+ # Shaded causal effect - use direct calculation
333+ post_pred_mu = post_pred .mu .mean (dim = ["chain" , "draw" ]).isel (treated_units = 0 )
322334 h = ax [0 ].fill_between (
323335 self .post_X .obs_ind ,
324336 y1 = post_pred_mu ,
325- y2 = self .post_y .isel (treated_units = 0 )
326- if hasattr (self .post_y , "isel" )
327- else self .post_y [:, 0 ],
337+ y2 = self .post_y .isel (treated_units = 0 ),
328338 color = "C0" ,
329339 alpha = 0.25 ,
330340 )
@@ -390,7 +400,7 @@ def _bayesian_plot(
390400
391401 def _ols_plot (self , round_to = None , ** kwargs ) -> tuple [plt .Figure , List [plt .Axes ]]:
392402 """
393- Plot the results
403+ Plot the results using unified predictions with period coordinates.
394404
395405 :param round_to:
396406 Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
@@ -399,13 +409,37 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
399409
400410 fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
401411
412+ # Extract pre/post predictions - handle PyMC vs sklearn differently
413+ if isinstance (self .model , PyMCModel ):
414+ # For PyMC models, predictions is InferenceData - use .where() with coordinates
415+ pre_pred = (
416+ self .predictions ["posterior_predictive" ]
417+ .where (
418+ self .predictions ["posterior_predictive" ].period == "pre" , drop = True
419+ )
420+ .mu .mean (dim = ["chain" , "draw" ])
421+ .isel (treated_units = 0 )
422+ )
423+ post_pred = (
424+ self .predictions ["posterior_predictive" ]
425+ .where (
426+ self .predictions ["posterior_predictive" ].period == "post" , drop = True
427+ )
428+ .mu .mean (dim = ["chain" , "draw" ])
429+ .isel (treated_units = 0 )
430+ )
431+ else :
432+ # For sklearn models, predictions is DataArray - use .sel() with indexed coordinates
433+ pre_pred = self .predictions .sel (period = "pre" )
434+ post_pred = self .predictions .sel (period = "post" )
435+
402436 ax [0 ].plot (self .pre_X .obs_ind , self .pre_y , "k." )
403437 ax [0 ].plot (self .post_X .obs_ind , self .post_y , "k." )
404438
405- ax [0 ].plot (self .pre_X .obs_ind , self . pre_pred , c = "k" , label = "model fit" )
439+ ax [0 ].plot (self .pre_X .obs_ind , pre_pred , c = "k" , label = "model fit" )
406440 ax [0 ].plot (
407441 self .post_X .obs_ind ,
408- self . post_pred ,
442+ post_pred ,
409443 label = counterfactual_label ,
410444 ls = ":" ,
411445 c = "k" ,
@@ -431,7 +465,7 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
431465 # Shaded causal effect
432466 ax [0 ].fill_between (
433467 self .post_X .obs_ind ,
434- y1 = np .squeeze (self . post_pred ),
468+ y1 = np .squeeze (post_pred ),
435469 y2 = np .squeeze (self .post_y ),
436470 color = "C0" ,
437471 alpha = 0.25 ,
@@ -482,20 +516,18 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
482516 pred_mu = self .predictions ["posterior_predictive" ].mu .isel (treated_units = 0 )
483517 plot_data ["prediction" ] = pred_mu .mean (dim = ["chain" , "draw" ]).values
484518
485- # Calculate impact directly from unified data
486- observed = self . data . y . isel ( treated_units = 0 )
487- predicted = pred_mu . mean (dim = ["chain" , "draw" ])
488- plot_data [ "impact" ] = ( observed - predicted ). values
519+ # Extract impact directly from unified impact - no more calculation needed!
520+ plot_data [ "impact" ] = (
521+ self . impact . mean (dim = ["chain" , "draw" ]). isel ( treated_units = 0 ). values
522+ )
489523
490524 # Calculate HDI bounds directly using arviz
491525 import arviz as az
492526
493527 pred_hdi = az .hdi (pred_mu , hdi_prob = hdi_prob )
494- impact_data = observed - pred_mu
495- impact_hdi = az .hdi (impact_data , hdi_prob = hdi_prob )
528+ impact_hdi = az .hdi (self .impact .isel (treated_units = 0 ), hdi_prob = hdi_prob )
496529
497530 # Extract HDI bounds from xarray Dataset results
498- # Use the actual variable name that arviz creates (usually the first data variable)
499531 pred_var_name = list (pred_hdi .data_vars .keys ())[0 ]
500532 impact_var_name = list (impact_hdi .data_vars .keys ())[0 ]
501533
@@ -520,11 +552,9 @@ def get_plot_data_ols(self) -> pd.DataFrame:
520552 index = self .data .y .obs_ind .values ,
521553 )
522554
523- # With unified predictions, extract values directly (no more reconstruction needed!)
555+ # Extract directly from unified data structures - ultimate simplification!
524556 plot_data ["prediction" ] = self .predictions .values
525- plot_data ["impact" ] = (
526- self .data .y .isel (treated_units = 0 ) - self .predictions
527- ).values
557+ plot_data ["impact" ] = self .impact .values
528558
529559 self .plot_data = plot_data
530560 return self .plot_data
0 commit comments