@@ -231,22 +231,22 @@ def _bayesian_plot(
231231 # pre-intervention period
232232
233233 # Get treated unit name - default to first unit if None
234- primary_unit_name = (
234+ treated_unit = (
235235 treated_unit if treated_unit is not None else self .treated_units [0 ]
236236 )
237237
238- if primary_unit_name not in self .treated_units :
238+ if treated_unit not in self .treated_units :
239239 raise ValueError (
240- f"treated_unit '{ primary_unit_name } ' not found. Available units: { self .treated_units } "
240+ f"treated_unit '{ treated_unit } ' not found. Available units: { self .treated_units } "
241241 )
242242
243243 # For multi-unit, select primary unit for main plot
244244 if len (self .treated_units ) > 1 :
245245 pre_pred_plot = self .pre_pred ["posterior_predictive" ].mu .sel (
246- treated_units = primary_unit_name
246+ treated_units = treated_unit
247247 )
248248 post_pred_plot = self .post_pred ["posterior_predictive" ].mu .sel (
249- treated_units = primary_unit_name
249+ treated_units = treated_unit
250250 )
251251 else :
252252 pre_pred_plot = self .pre_pred ["posterior_predictive" ].mu
@@ -264,7 +264,7 @@ def _bayesian_plot(
264264 # Plot observations for primary treated unit
265265 (h ,) = ax [0 ].plot (
266266 self .datapre .index ,
267- self .datapre_treated .sel (treated_units = primary_unit_name ),
267+ self .datapre_treated .sel (treated_units = treated_unit ),
268268 "k." ,
269269 label = "Observations" ,
270270 )
@@ -283,42 +283,40 @@ def _bayesian_plot(
283283
284284 ax [0 ].plot (
285285 self .datapost .index ,
286- self .datapost_treated .sel (treated_units = primary_unit_name ),
286+ self .datapost_treated .sel (treated_units = treated_unit ),
287287 "k." ,
288288 )
289289 # Shaded causal effect for primary treated unit
290290 h = ax [0 ].fill_between (
291291 self .datapost .index ,
292292 y1 = post_pred_plot .mean (dim = ["chain" , "draw" ]).values ,
293- y2 = self .datapost_treated .sel (treated_units = primary_unit_name ).values ,
293+ y2 = self .datapost_treated .sel (treated_units = treated_unit ).values ,
294294 color = "C2" ,
295295 alpha = 0.25 ,
296296 label = "Causal impact" ,
297297 )
298298 handles .append (h )
299299 labels .append ("Causal impact" )
300300
301- ax [0 ].set (title = f"{ self ._get_score_title (round_to )} \n Unit " )
301+ ax [0 ].set (title = f"{ self ._get_score_title (round_to )} " )
302302
303303 # MIDDLE PLOT -----------------------------------------------
304304 plot_xY (
305305 self .datapre .index ,
306- self .pre_impact .sel (treated_units = primary_unit_name ),
306+ self .pre_impact .sel (treated_units = treated_unit ),
307307 ax = ax [1 ],
308308 plot_hdi_kwargs = {"color" : "C0" },
309309 )
310310 plot_xY (
311311 self .datapost .index ,
312- self .post_impact .sel (treated_units = primary_unit_name ),
312+ self .post_impact .sel (treated_units = treated_unit ),
313313 ax = ax [1 ],
314314 plot_hdi_kwargs = {"color" : "C1" },
315315 )
316316 ax [1 ].axhline (y = 0 , c = "k" )
317317 ax [1 ].fill_between (
318318 self .datapost .index ,
319- y1 = self .post_impact .mean (["chain" , "draw" ]).sel (
320- treated_units = primary_unit_name
321- ),
319+ y1 = self .post_impact .mean (["chain" , "draw" ]).sel (treated_units = treated_unit ),
322320 color = "C0" ,
323321 alpha = 0.25 ,
324322 label = "Causal impact" ,
@@ -329,7 +327,7 @@ def _bayesian_plot(
329327 ax [2 ].set (title = "Cumulative Causal Impact" )
330328 plot_xY (
331329 self .datapost .index ,
332- self .post_impact_cumulative .sel (treated_units = primary_unit_name ),
330+ self .post_impact_cumulative .sel (treated_units = treated_unit ),
333331 ax = ax [2 ],
334332 plot_hdi_kwargs = {"color" : "C1" },
335333 )
@@ -385,25 +383,25 @@ def _ols_plot(
385383 counterfactual_label = "Counterfactual"
386384
387385 # Get treated unit name - default to first unit if None
388- primary_unit_name = (
386+ treated_unit = (
389387 treated_unit if treated_unit is not None else self .treated_units [0 ]
390388 )
391389
392- if primary_unit_name not in self .treated_units :
390+ if treated_unit not in self .treated_units :
393391 raise ValueError (
394- f"treated_unit '{ primary_unit_name } ' not found. Available units: { self .treated_units } "
392+ f"treated_unit '{ treated_unit } ' not found. Available units: { self .treated_units } "
395393 )
396394
397395 fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
398396
399397 ax [0 ].plot (
400398 self .datapre_treated ["obs_ind" ],
401- self .datapre_treated .sel (treated_units = primary_unit_name ),
399+ self .datapre_treated .sel (treated_units = treated_unit ),
402400 "k." ,
403401 )
404402 ax [0 ].plot (
405403 self .datapost_treated ["obs_ind" ],
406- self .datapost_treated .sel (treated_units = primary_unit_name ),
404+ self .datapost_treated .sel (treated_units = treated_unit ),
407405 "k." ,
408406 )
409407
@@ -415,7 +413,7 @@ def _ols_plot(
415413 ls = ":" ,
416414 c = "k" ,
417415 )
418- ax [0 ].set (title = f"{ self ._get_score_title (round_to )} \n Unit " )
416+ ax [0 ].set (title = f"{ self ._get_score_title (round_to )} " )
419417 # Shaded causal effect - handle different prediction formats
420418 try :
421419 # For OLS, predictions might be simple arrays
@@ -435,9 +433,7 @@ def _ols_plot(
435433 ax [0 ].fill_between (
436434 self .datapost .index ,
437435 y1 = post_pred_values ,
438- y2 = np .squeeze (
439- self .datapost_treated .sel (treated_units = primary_unit_name ).data
440- ),
436+ y2 = np .squeeze (self .datapost_treated .sel (treated_units = treated_unit ).data ),
441437 color = "C0" ,
442438 alpha = 0.25 ,
443439 label = "Causal impact" ,
@@ -521,13 +517,13 @@ def get_plot_data_bayesian(
521517 post_data = self .datapost .copy ()
522518
523519 # Get treated unit name - default to first unit if None
524- primary_unit_name = (
520+ treated_unit = (
525521 treated_unit if treated_unit is not None else self .treated_units [0 ]
526522 )
527523
528- if primary_unit_name not in self .treated_units :
524+ if treated_unit not in self .treated_units :
529525 raise ValueError (
530- f"treated_unit '{ primary_unit_name } ' not found. Available units: { self .treated_units } "
526+ f"treated_unit '{ treated_unit } ' not found. Available units: { self .treated_units } "
531527 )
532528
533529 # Extract predictions - handle multi-unit case
@@ -541,10 +537,10 @@ def get_plot_data_bayesian(
541537 if len (self .treated_units ) > 1 :
542538 # Multi-unit case: extract primary unit
543539 pre_data ["prediction" ] = pre_pred_vals .sel (
544- treated_units = primary_unit_name
540+ treated_units = treated_unit
545541 ).values
546542 post_data ["prediction" ] = post_pred_vals .sel (
547- treated_units = primary_unit_name
543+ treated_units = treated_unit
548544 ).values
549545 else :
550546 # Single unit case
@@ -555,13 +551,13 @@ def get_plot_data_bayesian(
555551 if len (self .treated_units ) > 1 :
556552 pre_hdi = get_hdi_to_df (
557553 self .pre_pred ["posterior_predictive" ].mu .sel (
558- treated_units = primary_unit_name
554+ treated_units = treated_unit
559555 ),
560556 hdi_prob = hdi_prob ,
561557 )
562558 post_hdi = get_hdi_to_df (
563559 self .post_pred ["posterior_predictive" ].mu .sel (
564- treated_units = primary_unit_name
560+ treated_units = treated_unit
565561 ),
566562 hdi_prob = hdi_prob ,
567563 )
@@ -583,21 +579,21 @@ def get_plot_data_bayesian(
583579 # Impact data - always use primary unit for main dataframe
584580 pre_data ["impact" ] = (
585581 self .pre_impact .mean (dim = ["chain" , "draw" ])
586- .sel (treated_units = primary_unit_name )
582+ .sel (treated_units = treated_unit )
587583 .values
588584 )
589585 post_data ["impact" ] = (
590586 self .post_impact .mean (dim = ["chain" , "draw" ])
591- .sel (treated_units = primary_unit_name )
587+ .sel (treated_units = treated_unit )
592588 .values
593589 )
594590 # Impact HDI intervals - use primary unit
595591 if len (self .treated_units ) > 1 :
596592 pre_impact_hdi = get_hdi_to_df (
597- self .pre_impact .sel (treated_units = primary_unit_name ), hdi_prob = hdi_prob
593+ self .pre_impact .sel (treated_units = treated_unit ), hdi_prob = hdi_prob
598594 )
599595 post_impact_hdi = get_hdi_to_df (
600- self .post_impact .sel (treated_units = primary_unit_name ), hdi_prob = hdi_prob
596+ self .post_impact .sel (treated_units = treated_unit ), hdi_prob = hdi_prob
601597 )
602598 else :
603599 pre_impact_hdi = get_hdi_to_df (self .pre_impact , hdi_prob = hdi_prob )
0 commit comments