Skip to content

Commit b79743f

Browse files
committed
code simplifications by always having a treated_units dimension
1 parent ee8a92b commit b79743f

File tree

5 files changed

+69
-95
lines changed

5 files changed

+69
-95
lines changed

causalpy/experiments/synthetic_control.py

Lines changed: 31 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -240,17 +240,12 @@ def _bayesian_plot(
240240
f"treated_unit '{treated_unit}' not found. Available units: {self.treated_units}"
241241
)
242242

243-
# For multi-unit, select primary unit for main plot
244-
if len(self.treated_units) > 1:
245-
pre_pred_plot = self.pre_pred["posterior_predictive"].mu.sel(
246-
treated_units=treated_unit
247-
)
248-
post_pred_plot = self.post_pred["posterior_predictive"].mu.sel(
249-
treated_units=treated_unit
250-
)
251-
else:
252-
pre_pred_plot = self.pre_pred["posterior_predictive"].mu
253-
post_pred_plot = self.post_pred["posterior_predictive"].mu
243+
pre_pred_plot = self.pre_pred["posterior_predictive"].mu.sel(
244+
treated_units=treated_unit
245+
)
246+
post_pred_plot = self.post_pred["posterior_predictive"].mu.sel(
247+
treated_units=treated_unit
248+
)
254249

255250
h_line, h_patch = plot_xY(
256251
self.datapre.index,
@@ -419,6 +414,7 @@ def _ols_plot(
419414
# For OLS, predictions might be simple arrays
420415
post_pred_values = np.squeeze(self.post_pred)
421416
except (TypeError, AttributeError):
417+
# TODO: WILL THIS PATH EVERY BIT HIT?
422418
# For PyMC predictions (InferenceData)
423419
post_pred_values = (
424420
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
@@ -534,40 +530,19 @@ def get_plot_data_bayesian(
534530
self.post_pred, group="posterior_predictive", var_names="mu"
535531
).mean("sample")
536532

537-
if len(self.treated_units) > 1:
538-
# Multi-unit case: extract primary unit
539-
pre_data["prediction"] = pre_pred_vals.sel(
540-
treated_units=treated_unit
541-
).values
542-
post_data["prediction"] = post_pred_vals.sel(
543-
treated_units=treated_unit
544-
).values
545-
else:
546-
# Single unit case
547-
pre_data["prediction"] = pre_pred_vals.values
548-
post_data["prediction"] = post_pred_vals.values
533+
# Extract predictions for the specified treated unit (always has treated_units dimension)
534+
pre_data["prediction"] = pre_pred_vals.sel(treated_units=treated_unit).values
535+
post_data["prediction"] = post_pred_vals.sel(treated_units=treated_unit).values
549536

550-
# HDI intervals for predictions
551-
if len(self.treated_units) > 1:
552-
pre_hdi = get_hdi_to_df(
553-
self.pre_pred["posterior_predictive"].mu.sel(
554-
treated_units=treated_unit
555-
),
556-
hdi_prob=hdi_prob,
557-
)
558-
post_hdi = get_hdi_to_df(
559-
self.post_pred["posterior_predictive"].mu.sel(
560-
treated_units=treated_unit
561-
),
562-
hdi_prob=hdi_prob,
563-
)
564-
else:
565-
pre_hdi = get_hdi_to_df(
566-
self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
567-
)
568-
post_hdi = get_hdi_to_df(
569-
self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
570-
)
537+
# HDI intervals for predictions (always use treated_units dimension)
538+
pre_hdi = get_hdi_to_df(
539+
self.pre_pred["posterior_predictive"].mu.sel(treated_units=treated_unit),
540+
hdi_prob=hdi_prob,
541+
)
542+
post_hdi = get_hdi_to_df(
543+
self.post_pred["posterior_predictive"].mu.sel(treated_units=treated_unit),
544+
hdi_prob=hdi_prob,
545+
)
571546

572547
# Extract only the lower and upper columns and ensure proper indexing
573548
pre_lower_upper = pre_hdi.iloc[:, [0, -1]].values # Get first and last columns
@@ -587,17 +562,13 @@ def get_plot_data_bayesian(
587562
.sel(treated_units=treated_unit)
588563
.values
589564
)
590-
# Impact HDI intervals - use primary unit
591-
if len(self.treated_units) > 1:
592-
pre_impact_hdi = get_hdi_to_df(
593-
self.pre_impact.sel(treated_units=treated_unit), hdi_prob=hdi_prob
594-
)
595-
post_impact_hdi = get_hdi_to_df(
596-
self.post_impact.sel(treated_units=treated_unit), hdi_prob=hdi_prob
597-
)
598-
else:
599-
pre_impact_hdi = get_hdi_to_df(self.pre_impact, hdi_prob=hdi_prob)
600-
post_impact_hdi = get_hdi_to_df(self.post_impact, hdi_prob=hdi_prob)
565+
# Impact HDI intervals (always use treated_units dimension)
566+
pre_impact_hdi = get_hdi_to_df(
567+
self.pre_impact.sel(treated_units=treated_unit), hdi_prob=hdi_prob
568+
)
569+
post_impact_hdi = get_hdi_to_df(
570+
self.post_impact.sel(treated_units=treated_unit), hdi_prob=hdi_prob
571+
)
601572

602573
# Extract only the lower and upper columns for impact HDI
603574
pre_impact_lower_upper = pre_impact_hdi.iloc[:, [0, -1]].values
@@ -614,7 +585,7 @@ def _get_score_title(self, round_to=None):
614585
"""Generate appropriate score title based on model type and number of treated units"""
615586
if isinstance(self.model, PyMCModel):
616587
if isinstance(self.score, pd.Series):
617-
# Check if it's multi-unit format (has unit-specific keys)
588+
# Now consistently has unit-specific keys for all cases
618589
if len(self.treated_units) > 1:
619590
mean_r2 = self.score.filter(regex=r".*_r2$").mean()
620591
mean_r2_std = self.score.filter(regex=r".*_r2_std$").mean()
@@ -623,10 +594,11 @@ def _get_score_title(self, round_to=None):
623594
(avg std = {round_num(mean_r2_std, round_to)})
624595
"""
625596
else:
626-
# Single treated unit - Series has 'r2' and 'r2_std' keys
597+
# Single treated unit - use unit-specific keys
598+
unit_name = self.treated_units[0]
627599
return f"""
628-
Pre-intervention Bayesian $R^2$: {round_num(self.score["r2"], round_to)}
629-
(std = {round_num(self.score["r2_std"], round_to)})
600+
Pre-intervention Bayesian $R^2$: {round_num(self.score[f"{unit_name}_r2"], round_to)}
601+
(std = {round_num(self.score[f"{unit_name}_r2_std"], round_to)})
630602
"""
631603
else:
632604
# Fallback for non-Series score (shouldn't happen with WeightedSumFitter)

causalpy/pymc_models.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _data_setter(self, X) -> None:
110110

111111
with self:
112112
if has_treated_units:
113-
# Multiple treated units - get the number from the model coordinates
113+
# Get the number of treated units from the model coordinates
114114
treated_units_coord = getattr(self, "coords", {}).get(
115115
"treated_units", []
116116
)
@@ -122,7 +122,7 @@ def _data_setter(self, X) -> None:
122122
coords={"obs_ind": np.arange(new_no_of_observations)},
123123
)
124124
else:
125-
# Single treated unit case
125+
# Legacy case - this shouldn't happen with new WeightedSumFitter
126126
pm.set_data(
127127
{"X": X, "y": np.zeros(new_no_of_observations)},
128128
coords={"obs_ind": np.arange(new_no_of_observations)},
@@ -378,28 +378,29 @@ def build_model(self, X, y, coords):
378378
n_predictors = X.shape[1]
379379
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
380380

381-
# Check if we have multiple treated units
382-
if y.ndim > 1 and y.shape[1] > 1:
383-
# Multiple treated units case
384-
y = pm.Data("y", y, dims=["obs_ind", "treated_units"])
385-
beta = pm.Dirichlet(
386-
"beta", a=np.ones(n_predictors), dims=["treated_units", "coeffs"]
387-
)
388-
sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
389-
mu = pm.Deterministic(
390-
"mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"]
391-
)
392-
pm.Normal(
393-
"y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"]
394-
)
381+
# Always use treated_units dimension for consistency
382+
# Convert to numpy array if it's an xarray DataArray
383+
if hasattr(y, "values"):
384+
y_data = y.values
395385
else:
396-
# Single treated unit case (backward compatibility)
397-
y_data = y[:, 0] if y.ndim > 1 else y
398-
y = pm.Data("y", y_data, dims="obs_ind")
399-
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
400-
sigma = pm.HalfNormal("sigma", 1)
401-
mu = pm.Deterministic("mu", pt.dot(X, beta), dims="obs_ind")
402-
pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")
386+
y_data = np.asarray(y)
387+
388+
# Ensure y_data has treated_units dimension
389+
if y_data.ndim == 1:
390+
y_data = y_data.reshape(-1, 1) # Add treated_units dimension
391+
elif y_data.ndim > 1 and y_data.shape[1] == 1:
392+
pass # Already has correct shape
393+
# If y_data.ndim > 1 and y_data.shape[1] > 1, it's multi-unit and already correct
394+
395+
y = pm.Data("y", y_data, dims=["obs_ind", "treated_units"])
396+
beta = pm.Dirichlet(
397+
"beta", a=np.ones(n_predictors), dims=["treated_units", "coeffs"]
398+
)
399+
sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
400+
mu = pm.Deterministic(
401+
"mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"]
402+
)
403+
pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"])
403404

404405

405406
class InstrumentalVariableRegression(PyMCModel):

causalpy/tests/test_multi_unit_sc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ def test_single_unit_backward_compatibility(self, single_unit_sc_data):
210210

211211
# Score should still work
212212
assert isinstance(sc.score, pd.Series)
213-
assert "r2" in sc.score.index
214-
assert "r2_std" in sc.score.index
213+
assert "treated_0_r2" in sc.score.index
214+
assert "treated_0_r2_std" in sc.score.index
215215

216216
def test_multi_unit_plotting(self, multi_unit_sc_data):
217217
"""Test that plotting works with multiple treated units."""

causalpy/tests/test_multi_unit_wsf.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,9 @@ def test_backward_compatibility_single_unit(self, single_treated_data):
233233
# Test prediction
234234
pred = wsf.predict(X)
235235

236-
# For single unit, should not have treated_units dimension in some places
237-
# but should still work correctly
236+
# Now always has treated_units dimension, even for single unit
238237
mu_shape = pred["posterior_predictive"]["mu"].shape
239-
expected_shape = (sample_kwargs["chains"], sample_kwargs["draws"], len(X))
238+
expected_shape = (sample_kwargs["chains"], sample_kwargs["draws"], len(X), 1)
240239
assert mu_shape == expected_shape
241240

242241
def test_print_coefficients_multi_unit(self, synthetic_control_data, capsys):
@@ -297,14 +296,16 @@ def test_scoring_single_unit(self, single_treated_data):
297296
# Test scoring
298297
score = wsf.score(X, y)
299298

300-
# For single unit, should have the same format as before
299+
# Now consistently uses treated unit name prefix even for single unit
301300
assert isinstance(score, pd.Series)
302-
assert "r2" in score.index
303-
assert "r2_std" in score.index
301+
assert "treated_0_r2" in score.index
302+
assert "treated_0_r2_std" in score.index
304303

305304
# R2 should be reasonable
306-
assert score["r2"] >= -1 # R2 can be negative for very bad fits
307-
assert score["r2_std"] >= 0 # Standard deviation should be non-negative
305+
assert score["treated_0_r2"] >= -1 # R2 can be negative for very bad fits
306+
assert (
307+
score["treated_0_r2_std"] >= 0
308+
) # Standard deviation should be non-negative
308309

309310
def test_r2_scores_differ_across_units(self, rng):
310311
"""Test that R² scores are different for different treated units.

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)