@@ -240,17 +240,12 @@ def _bayesian_plot(
240240 f"treated_unit '{ treated_unit } ' not found. Available units: { self .treated_units } "
241241 )
242242
243- # For multi-unit, select primary unit for main plot
244- if len (self .treated_units ) > 1 :
245- pre_pred_plot = self .pre_pred ["posterior_predictive" ].mu .sel (
246- treated_units = treated_unit
247- )
248- post_pred_plot = self .post_pred ["posterior_predictive" ].mu .sel (
249- treated_units = treated_unit
250- )
251- else :
252- pre_pred_plot = self .pre_pred ["posterior_predictive" ].mu
253- post_pred_plot = self .post_pred ["posterior_predictive" ].mu
243+ pre_pred_plot = self .pre_pred ["posterior_predictive" ].mu .sel (
244+ treated_units = treated_unit
245+ )
246+ post_pred_plot = self .post_pred ["posterior_predictive" ].mu .sel (
247+ treated_units = treated_unit
248+ )
254249
255250 h_line , h_patch = plot_xY (
256251 self .datapre .index ,
@@ -419,6 +414,7 @@ def _ols_plot(
419414 # For OLS, predictions might be simple arrays
420415 post_pred_values = np .squeeze (self .post_pred )
421416 except (TypeError , AttributeError ):
417+ # TODO: WILL THIS PATH EVERY BIT HIT?
422418 # For PyMC predictions (InferenceData)
423419 post_pred_values = (
424420 az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
@@ -534,40 +530,19 @@ def get_plot_data_bayesian(
534530 self .post_pred , group = "posterior_predictive" , var_names = "mu"
535531 ).mean ("sample" )
536532
537- if len (self .treated_units ) > 1 :
538- # Multi-unit case: extract primary unit
539- pre_data ["prediction" ] = pre_pred_vals .sel (
540- treated_units = treated_unit
541- ).values
542- post_data ["prediction" ] = post_pred_vals .sel (
543- treated_units = treated_unit
544- ).values
545- else :
546- # Single unit case
547- pre_data ["prediction" ] = pre_pred_vals .values
548- post_data ["prediction" ] = post_pred_vals .values
533+ # Extract predictions for the specified treated unit (always has treated_units dimension)
534+ pre_data ["prediction" ] = pre_pred_vals .sel (treated_units = treated_unit ).values
535+ post_data ["prediction" ] = post_pred_vals .sel (treated_units = treated_unit ).values
549536
550- # HDI intervals for predictions
551- if len (self .treated_units ) > 1 :
552- pre_hdi = get_hdi_to_df (
553- self .pre_pred ["posterior_predictive" ].mu .sel (
554- treated_units = treated_unit
555- ),
556- hdi_prob = hdi_prob ,
557- )
558- post_hdi = get_hdi_to_df (
559- self .post_pred ["posterior_predictive" ].mu .sel (
560- treated_units = treated_unit
561- ),
562- hdi_prob = hdi_prob ,
563- )
564- else :
565- pre_hdi = get_hdi_to_df (
566- self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
567- )
568- post_hdi = get_hdi_to_df (
569- self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
570- )
537+ # HDI intervals for predictions (always use treated_units dimension)
538+ pre_hdi = get_hdi_to_df (
539+ self .pre_pred ["posterior_predictive" ].mu .sel (treated_units = treated_unit ),
540+ hdi_prob = hdi_prob ,
541+ )
542+ post_hdi = get_hdi_to_df (
543+ self .post_pred ["posterior_predictive" ].mu .sel (treated_units = treated_unit ),
544+ hdi_prob = hdi_prob ,
545+ )
571546
572547 # Extract only the lower and upper columns and ensure proper indexing
573548 pre_lower_upper = pre_hdi .iloc [:, [0 , - 1 ]].values # Get first and last columns
@@ -587,17 +562,13 @@ def get_plot_data_bayesian(
587562 .sel (treated_units = treated_unit )
588563 .values
589564 )
590- # Impact HDI intervals - use primary unit
591- if len (self .treated_units ) > 1 :
592- pre_impact_hdi = get_hdi_to_df (
593- self .pre_impact .sel (treated_units = treated_unit ), hdi_prob = hdi_prob
594- )
595- post_impact_hdi = get_hdi_to_df (
596- self .post_impact .sel (treated_units = treated_unit ), hdi_prob = hdi_prob
597- )
598- else :
599- pre_impact_hdi = get_hdi_to_df (self .pre_impact , hdi_prob = hdi_prob )
600- post_impact_hdi = get_hdi_to_df (self .post_impact , hdi_prob = hdi_prob )
565+ # Impact HDI intervals (always use treated_units dimension)
566+ pre_impact_hdi = get_hdi_to_df (
567+ self .pre_impact .sel (treated_units = treated_unit ), hdi_prob = hdi_prob
568+ )
569+ post_impact_hdi = get_hdi_to_df (
570+ self .post_impact .sel (treated_units = treated_unit ), hdi_prob = hdi_prob
571+ )
601572
602573 # Extract only the lower and upper columns for impact HDI
603574 pre_impact_lower_upper = pre_impact_hdi .iloc [:, [0 , - 1 ]].values
@@ -614,7 +585,7 @@ def _get_score_title(self, round_to=None):
614585 """Generate appropriate score title based on model type and number of treated units"""
615586 if isinstance (self .model , PyMCModel ):
616587 if isinstance (self .score , pd .Series ):
617- # Check if it's multi-unit format ( has unit-specific keys)
588+ # Now consistently has unit-specific keys for all cases
618589 if len (self .treated_units ) > 1 :
619590 mean_r2 = self .score .filter (regex = r".*_r2$" ).mean ()
620591 mean_r2_std = self .score .filter (regex = r".*_r2_std$" ).mean ()
@@ -623,10 +594,11 @@ def _get_score_title(self, round_to=None):
623594 (avg std = { round_num (mean_r2_std , round_to )} )
624595 """
625596 else :
626- # Single treated unit - Series has 'r2' and 'r2_std' keys
597+ # Single treated unit - use unit-specific keys
598+ unit_name = self .treated_units [0 ]
627599 return f"""
628- Pre-intervention Bayesian $R^2$: { round_num (self .score ["r2 " ], round_to )}
629- (std = { round_num (self .score ["r2_std " ], round_to )} )
600+ Pre-intervention Bayesian $R^2$: { round_num (self .score [f" { unit_name } _r2 " ], round_to )}
601+ (std = { round_num (self .score [f" { unit_name } _r2_std " ], round_to )} )
630602 """
631603 else :
632604 # Fallback for non-Series score (shouldn't happen with WeightedSumFitter)
0 commit comments