Skip to content

Commit aa9920a

Browse files
committed
code simplification relating to _get_score_title
1 parent b79743f commit aa9920a

File tree

2 files changed

+30
-43
lines changed

2 files changed

+30
-43
lines changed

causalpy/experiments/synthetic_control.py

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def _bayesian_plot(
293293
handles.append(h)
294294
labels.append("Causal impact")
295295

296-
ax[0].set(title=f"{self._get_score_title(round_to)}")
296+
ax[0].set(title=f"{self._get_score_title(treated_unit, round_to)}")
297297

298298
# MIDDLE PLOT -----------------------------------------------
299299
plot_xY(
@@ -408,7 +408,7 @@ def _ols_plot(
408408
ls=":",
409409
c="k",
410410
)
411-
ax[0].set(title=f"{self._get_score_title(round_to)}")
411+
ax[0].set(title=f"{self._get_score_title(treated_unit, round_to)}")
412412
# Shaded causal effect - handle different prediction formats
413413
try:
414414
# For OLS, predictions might be simple arrays
@@ -581,28 +581,13 @@ def get_plot_data_bayesian(
581581

582582
return self.plot_data
583583

584-
def _get_score_title(self, round_to=None):
585-
"""Generate appropriate score title based on model type and number of treated units"""
584+
def _get_score_title(self, treated_unit: str, round_to=None):
585+
"""Generate appropriate score title for the specified treated unit"""
586586
if isinstance(self.model, PyMCModel):
587-
if isinstance(self.score, pd.Series):
588-
# Now consistently has unit-specific keys for all cases
589-
if len(self.treated_units) > 1:
590-
mean_r2 = self.score.filter(regex=r".*_r2$").mean()
591-
mean_r2_std = self.score.filter(regex=r".*_r2_std$").mean()
592-
return f"""
593-
Pre-intervention Bayesian $R^2$ (avg): {round_num(mean_r2, round_to)}
594-
(avg std = {round_num(mean_r2_std, round_to)})
595-
"""
596-
else:
597-
# Single treated unit - use unit-specific keys
598-
unit_name = self.treated_units[0]
599-
return f"""
600-
Pre-intervention Bayesian $R^2$: {round_num(self.score[f"{unit_name}_r2"], round_to)}
601-
(std = {round_num(self.score[f"{unit_name}_r2_std"], round_to)})
602-
"""
603-
else:
604-
# Fallback for non-Series score (shouldn't happen with WeightedSumFitter)
605-
return f"Pre-intervention score: {round_num(self.score, round_to)}"
587+
# Bayesian model - get unit-specific R² scores
588+
r2_val = round_num(self.score[f"{treated_unit}_r2"], round_to)
589+
r2_std_val = round_num(self.score[f"{treated_unit}_r2_std"], round_to)
590+
return f"Pre-intervention Bayesian $R^2$: {r2_val} (std = {r2_std_val})"
606591
else:
607-
# OLS model - score is typically a simple float
592+
# OLS model - simple float score
608593
return f"$R^2$ on pre-intervention data = {round_num(self.score, round_to)}"

docs/source/notebooks/multi_cell_geolift.ipynb

Lines changed: 21 additions & 19 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)