Skip to content

Commit 3ee430e

Browse files
committed
Further unify score (R²) key naming across experiments: use the format unit_{n}_r2
1 parent d0c520f commit 3ee430e

File tree

7 files changed

+37
-43
lines changed

7 files changed

+37
-43
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ def _bayesian_plot(
239239

240240
ax[0].set(
241241
title=f"""
242-
Pre-intervention Bayesian $R^2$: {round_num(self.score["unit_r2"], round_to)}
243-
(std = {round_num(self.score["unit_r2_std"], round_to)})
242+
Pre-intervention Bayesian $R^2$: {round_num(self.score["unit_0_r2"], round_to)}
243+
(std = {round_num(self.score["unit_0_r2_std"], round_to)})
244244
"""
245245
)
246246

causalpy/experiments/regression_discontinuity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]
256256
labels = ["Posterior mean"]
257257

258258
# create strings to compose title
259-
title_info = f"{round_num(self.score['unit_r2'], round_to)} (std = {round_num(self.score['unit_r2_std'], round_to)})"
259+
title_info = f"{round_num(self.score['unit_0_r2'], round_to)} (std = {round_num(self.score['unit_0_r2_std'], round_to)})"
260260
r2 = f"Bayesian $R^2$ on all data = {title_info}"
261261
percentiles = self.discontinuity_at_threshold.quantile([0.03, 1 - 0.03]).values
262262
ci = (

causalpy/experiments/regression_kink.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]
227227
labels = ["Posterior mean"]
228228

229229
# create strings to compose title
230-
title_info = f"{round_num(self.score['unit_r2'], round_to)} (std = {round_num(self.score['unit_r2_std'], round_to)})"
230+
title_info = f"{round_num(self.score['unit_0_r2'], round_to)} (std = {round_num(self.score['unit_0_r2_std'], round_to)})"
231231
r2 = f"Bayesian $R^2$ on all data = {title_info}"
232232
percentiles = self.gradient_change.quantile([0.03, 1 - 0.03]).values
233233
ci = (

causalpy/experiments/synthetic_control.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -564,9 +564,10 @@ def get_plot_data_bayesian(
564564
def _get_score_title(self, treated_unit: str, round_to=None):
565565
"""Generate appropriate score title for the specified treated unit"""
566566
if isinstance(self.model, PyMCModel):
567-
# Bayesian model - get unit-specific R² scores
568-
r2_val = round_num(self.score[f"{treated_unit}_r2"], round_to)
569-
r2_std_val = round_num(self.score[f"{treated_unit}_r2_std"], round_to)
567+
# Bayesian model - get unit-specific R² scores using unified format
568+
unit_index = self.treated_units.index(treated_unit)
569+
r2_val = round_num(self.score[f"unit_{unit_index}_r2"], round_to)
570+
r2_std_val = round_num(self.score[f"unit_{unit_index}_r2_std"], round_to)
570571
return f"Pre-intervention Bayesian $R^2$: {r2_val} (std = {r2_std_val})"
571572
else:
572573
# OLS model - simple float score

causalpy/pymc_models.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ class PyMCModel(pm.Model):
6868
>>> model.fit(X, y)
6969
Inference data...
7070
>>> model.score(X, y) # doctest: +ELLIPSIS
71-
unit_r2 ...
72-
unit_r2_std ...
71+
unit_0_r2 ...
72+
unit_0_r2_std ...
7373
dtype: float64
7474
>>> X_new = rng.normal(loc=0, scale=1, size=(20, 2))
7575
>>> model.predict(X_new)
@@ -203,30 +203,25 @@ def score(self, X: xr.DataArray, y: xr.DataArray) -> pd.Series:
203203
mu = self.predict(X)
204204
mu_data = az.extract(mu, group="posterior_predictive", var_names="mu")
205205

206-
# Always use the multiple treated unit convention for consistency
206+
# Always use unified labeling system: unit_0_r2, unit_1_r2, etc.
207207
scores = {}
208208

209209
if "treated_units" in mu_data.dims:
210210
# Multiple treated units - score each unit separately
211211
treated_units = mu_data.coords["treated_units"].values
212-
for unit in treated_units:
212+
for i, unit in enumerate(treated_units):
213213
unit_mu = mu_data.sel(treated_units=unit).T # (sample, obs_ind)
214214
unit_y = y.sel(treated_units=unit).data
215215
unit_score = r2_score(unit_y, unit_mu.data)
216-
scores[f"{unit}_r2"] = unit_score["r2"]
217-
scores[f"{unit}_r2_std"] = unit_score["r2_std"]
216+
scores[f"unit_{i}_r2"] = unit_score["r2"]
217+
scores[f"unit_{i}_r2_std"] = unit_score["r2_std"]
218218
else:
219-
# Single treated unit - determine unit name and use same format
220-
if hasattr(y, "coords") and "treated_units" in y.coords:
221-
unit_name = y.coords["treated_units"].values[0]
222-
else:
223-
unit_name = "unit" # Fallback for backwards compatibility
224-
219+
# Single treated unit - use unit_0 for consistency
225220
mu_data = mu_data.T
226221
y_data = y.data.squeeze() if y.data.ndim > 1 else y.data
227222
unit_score = r2_score(y_data, mu_data.data)
228-
scores[f"{unit_name}_r2"] = unit_score["r2"]
229-
scores[f"{unit_name}_r2_std"] = unit_score["r2_std"]
223+
scores["unit_0_r2"] = unit_score["r2"]
224+
scores["unit_0_r2_std"] = unit_score["r2_std"]
230225

231226
return pd.Series(scores)
232227

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -863,10 +863,10 @@ def test_multi_unit_scoring(self, multi_unit_sc_data):
863863
# Score should be a pandas Series with separate entries for each unit
864864
assert isinstance(sc.score, pd.Series)
865865

866-
# Check that we have r2 and r2_std for each treated unit
867-
for unit in treated_units:
868-
assert f"{unit}_r2" in sc.score.index
869-
assert f"{unit}_r2_std" in sc.score.index
866+
# Check that we have r2 and r2_std for each treated unit using unified format
867+
for i, unit in enumerate(treated_units):
868+
assert f"unit_{i}_r2" in sc.score.index
869+
assert f"unit_{i}_r2_std" in sc.score.index
870870

871871
@pytest.mark.integration
872872
def test_multi_unit_summary(self, multi_unit_sc_data, capsys):

causalpy/tests/test_pymc_models.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ def test_fit_predict(self, coords, rng) -> None:
120120
).shape == (20, 2 * 2)
121121
assert isinstance(score, pd.Series)
122122
assert score.shape == (2,)
123-
# Test that the score follows the new standardized format
124-
assert "unit_r2" in score.index
125-
assert "unit_r2_std" in score.index
123+
# Test that the score follows the new unified format
124+
assert "unit_0_r2" in score.index
125+
assert "unit_0_r2_std" in score.index
126126
assert isinstance(predictions, az.InferenceData)
127127

128128

@@ -423,15 +423,15 @@ def test_scoring_multi_unit(self, synthetic_control_data):
423423
# Score should be a pandas Series with separate r2 and r2_std for each treated unit
424424
assert isinstance(score, pd.Series)
425425

426-
# Check that we have r2 and r2_std for each treated unit
427-
for unit in treated_units:
428-
assert f"{unit}_r2" in score.index
429-
assert f"{unit}_r2_std" in score.index
426+
# Check that we have r2 and r2_std for each treated unit using unified format
427+
for i, unit in enumerate(treated_units):
428+
assert f"unit_{i}_r2" in score.index
429+
assert f"unit_{i}_r2_std" in score.index
430430

431431
# R2 should be reasonable (between 0 and 1 typically, though can be negative)
432-
assert score[f"{unit}_r2"] >= -1 # R2 can be negative for very bad fits
432+
assert score[f"unit_{i}_r2"] >= -1 # R2 can be negative for very bad fits
433433
assert (
434-
score[f"{unit}_r2_std"] >= 0
434+
score[f"unit_{i}_r2_std"] >= 0
435435
) # Standard deviation should be non-negative
436436

437437
def test_scoring_single_unit(self, single_treated_data):
@@ -444,16 +444,14 @@ def test_scoring_single_unit(self, single_treated_data):
444444
# Test scoring
445445
score = wsf.score(X, y)
446446

447-
# Now consistently uses treated unit name prefix even for single unit
447+
# Now consistently uses unified unit indexing even for single unit
448448
assert isinstance(score, pd.Series)
449-
assert "treated_0_r2" in score.index
450-
assert "treated_0_r2_std" in score.index
449+
assert "unit_0_r2" in score.index
450+
assert "unit_0_r2_std" in score.index
451451

452452
# R2 should be reasonable
453-
assert score["treated_0_r2"] >= -1 # R2 can be negative for very bad fits
454-
assert (
455-
score["treated_0_r2_std"] >= 0
456-
) # Standard deviation should be non-negative
453+
assert score["unit_0_r2"] >= -1 # R2 can be negative for very bad fits
454+
assert score["unit_0_r2_std"] >= 0 # Standard deviation should be non-negative
457455

458456
def test_r2_scores_differ_across_units(self, rng):
459457
"""Test that R² scores are different for different treated units.
@@ -523,8 +521,8 @@ def test_r2_scores_differ_across_units(self, rng):
523521
wsf.fit(X, y, coords=coords)
524522
scores = wsf.score(X, y)
525523

526-
# Extract R² values for each treated unit
527-
r2_values = [scores[f"{unit}_r2"] for unit in treated_units]
524+
# Extract R² values for each treated unit using unified format
525+
r2_values = [scores[f"unit_{i}_r2"] for i in range(len(treated_units))]
528526

529527
# Test that not all R² values are the same
530528
# Use a tolerance to avoid issues with floating point precision

0 commit comments

Comments (0)