@@ -293,7 +293,7 @@ def _bayesian_plot(
293293 handles .append (h )
294294 labels .append ("Causal impact" )
295295
296- ax [0 ].set (title = f"{ self ._get_score_title (round_to )} " )
296+ ax [0 ].set (title = f"{ self ._get_score_title (treated_unit , round_to )} " )
297297
298298 # MIDDLE PLOT -----------------------------------------------
299299 plot_xY (
@@ -408,7 +408,7 @@ def _ols_plot(
408408 ls = ":" ,
409409 c = "k" ,
410410 )
411- ax [0 ].set (title = f"{ self ._get_score_title (round_to )} " )
411+ ax [0 ].set (title = f"{ self ._get_score_title (treated_unit , round_to )} " )
412412 # Shaded causal effect - handle different prediction formats
413413 try :
414414 # For OLS, predictions might be simple arrays
@@ -581,28 +581,13 @@ def get_plot_data_bayesian(
581581
582582 return self .plot_data
583583
584- def _get_score_title (self , round_to = None ):
585- """Generate appropriate score title based on model type and number of treated units """
584+ def _get_score_title (self , treated_unit : str , round_to = None ):
585+ """Generate appropriate score title for the specified treated unit """
586586 if isinstance (self .model , PyMCModel ):
587- if isinstance (self .score , pd .Series ):
588- # Now consistently has unit-specific keys for all cases
589- if len (self .treated_units ) > 1 :
590- mean_r2 = self .score .filter (regex = r".*_r2$" ).mean ()
591- mean_r2_std = self .score .filter (regex = r".*_r2_std$" ).mean ()
592- return f"""
593- Pre-intervention Bayesian $R^2$ (avg): { round_num (mean_r2 , round_to )}
594- (avg std = { round_num (mean_r2_std , round_to )} )
595- """
596- else :
597- # Single treated unit - use unit-specific keys
598- unit_name = self .treated_units [0 ]
599- return f"""
600- Pre-intervention Bayesian $R^2$: { round_num (self .score [f"{ unit_name } _r2" ], round_to )}
601- (std = { round_num (self .score [f"{ unit_name } _r2_std" ], round_to )} )
602- """
603- else :
604- # Fallback for non-Series score (shouldn't happen with WeightedSumFitter)
605- return f"Pre-intervention score: { round_num (self .score , round_to )} "
587+ # Bayesian model - get unit-specific R² scores
588+ r2_val = round_num (self .score [f"{ treated_unit } _r2" ], round_to )
589+ r2_std_val = round_num (self .score [f"{ treated_unit } _r2_std" ], round_to )
590+ return f"Pre-intervention Bayesian $R^2$: { r2_val } (std = { r2_std_val } )"
606591 else :
607- # OLS model - score is typically a simple float
592+ # OLS model - simple float score
608593 return f"$R^2$ on pre-intervention data = { round_num (self .score , round_to )} "
0 commit comments