Skip to content

Commit 47e0a2d

Browse files
committed
rename primary_unit_name -> treated_unit + revert to no "Unit" title
1 parent be01357 commit 47e0a2d

File tree

1 file changed

+31
-35
lines changed

1 file changed

+31
-35
lines changed

causalpy/experiments/synthetic_control.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -231,22 +231,22 @@ def _bayesian_plot(
231231
# pre-intervention period
232232

233233
# Get treated unit name - default to first unit if None
234-
primary_unit_name = (
234+
treated_unit = (
235235
treated_unit if treated_unit is not None else self.treated_units[0]
236236
)
237237

238-
if primary_unit_name not in self.treated_units:
238+
if treated_unit not in self.treated_units:
239239
raise ValueError(
240-
f"treated_unit '{primary_unit_name}' not found. Available units: {self.treated_units}"
240+
f"treated_unit '{treated_unit}' not found. Available units: {self.treated_units}"
241241
)
242242

243243
# For multi-unit, select primary unit for main plot
244244
if len(self.treated_units) > 1:
245245
pre_pred_plot = self.pre_pred["posterior_predictive"].mu.sel(
246-
treated_units=primary_unit_name
246+
treated_units=treated_unit
247247
)
248248
post_pred_plot = self.post_pred["posterior_predictive"].mu.sel(
249-
treated_units=primary_unit_name
249+
treated_units=treated_unit
250250
)
251251
else:
252252
pre_pred_plot = self.pre_pred["posterior_predictive"].mu
@@ -264,7 +264,7 @@ def _bayesian_plot(
264264
# Plot observations for primary treated unit
265265
(h,) = ax[0].plot(
266266
self.datapre.index,
267-
self.datapre_treated.sel(treated_units=primary_unit_name),
267+
self.datapre_treated.sel(treated_units=treated_unit),
268268
"k.",
269269
label="Observations",
270270
)
@@ -283,42 +283,40 @@ def _bayesian_plot(
283283

284284
ax[0].plot(
285285
self.datapost.index,
286-
self.datapost_treated.sel(treated_units=primary_unit_name),
286+
self.datapost_treated.sel(treated_units=treated_unit),
287287
"k.",
288288
)
289289
# Shaded causal effect for primary treated unit
290290
h = ax[0].fill_between(
291291
self.datapost.index,
292292
y1=post_pred_plot.mean(dim=["chain", "draw"]).values,
293-
y2=self.datapost_treated.sel(treated_units=primary_unit_name).values,
293+
y2=self.datapost_treated.sel(treated_units=treated_unit).values,
294294
color="C2",
295295
alpha=0.25,
296296
label="Causal impact",
297297
)
298298
handles.append(h)
299299
labels.append("Causal impact")
300300

301-
ax[0].set(title=f"{self._get_score_title(round_to)}\nUnit")
301+
ax[0].set(title=f"{self._get_score_title(round_to)}")
302302

303303
# MIDDLE PLOT -----------------------------------------------
304304
plot_xY(
305305
self.datapre.index,
306-
self.pre_impact.sel(treated_units=primary_unit_name),
306+
self.pre_impact.sel(treated_units=treated_unit),
307307
ax=ax[1],
308308
plot_hdi_kwargs={"color": "C0"},
309309
)
310310
plot_xY(
311311
self.datapost.index,
312-
self.post_impact.sel(treated_units=primary_unit_name),
312+
self.post_impact.sel(treated_units=treated_unit),
313313
ax=ax[1],
314314
plot_hdi_kwargs={"color": "C1"},
315315
)
316316
ax[1].axhline(y=0, c="k")
317317
ax[1].fill_between(
318318
self.datapost.index,
319-
y1=self.post_impact.mean(["chain", "draw"]).sel(
320-
treated_units=primary_unit_name
321-
),
319+
y1=self.post_impact.mean(["chain", "draw"]).sel(treated_units=treated_unit),
322320
color="C0",
323321
alpha=0.25,
324322
label="Causal impact",
@@ -329,7 +327,7 @@ def _bayesian_plot(
329327
ax[2].set(title="Cumulative Causal Impact")
330328
plot_xY(
331329
self.datapost.index,
332-
self.post_impact_cumulative.sel(treated_units=primary_unit_name),
330+
self.post_impact_cumulative.sel(treated_units=treated_unit),
333331
ax=ax[2],
334332
plot_hdi_kwargs={"color": "C1"},
335333
)
@@ -385,25 +383,25 @@ def _ols_plot(
385383
counterfactual_label = "Counterfactual"
386384

387385
# Get treated unit name - default to first unit if None
388-
primary_unit_name = (
386+
treated_unit = (
389387
treated_unit if treated_unit is not None else self.treated_units[0]
390388
)
391389

392-
if primary_unit_name not in self.treated_units:
390+
if treated_unit not in self.treated_units:
393391
raise ValueError(
394-
f"treated_unit '{primary_unit_name}' not found. Available units: {self.treated_units}"
392+
f"treated_unit '{treated_unit}' not found. Available units: {self.treated_units}"
395393
)
396394

397395
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
398396

399397
ax[0].plot(
400398
self.datapre_treated["obs_ind"],
401-
self.datapre_treated.sel(treated_units=primary_unit_name),
399+
self.datapre_treated.sel(treated_units=treated_unit),
402400
"k.",
403401
)
404402
ax[0].plot(
405403
self.datapost_treated["obs_ind"],
406-
self.datapost_treated.sel(treated_units=primary_unit_name),
404+
self.datapost_treated.sel(treated_units=treated_unit),
407405
"k.",
408406
)
409407

@@ -415,7 +413,7 @@ def _ols_plot(
415413
ls=":",
416414
c="k",
417415
)
418-
ax[0].set(title=f"{self._get_score_title(round_to)}\nUnit")
416+
ax[0].set(title=f"{self._get_score_title(round_to)}")
419417
# Shaded causal effect - handle different prediction formats
420418
try:
421419
# For OLS, predictions might be simple arrays
@@ -435,9 +433,7 @@ def _ols_plot(
435433
ax[0].fill_between(
436434
self.datapost.index,
437435
y1=post_pred_values,
438-
y2=np.squeeze(
439-
self.datapost_treated.sel(treated_units=primary_unit_name).data
440-
),
436+
y2=np.squeeze(self.datapost_treated.sel(treated_units=treated_unit).data),
441437
color="C0",
442438
alpha=0.25,
443439
label="Causal impact",
@@ -521,13 +517,13 @@ def get_plot_data_bayesian(
521517
post_data = self.datapost.copy()
522518

523519
# Get treated unit name - default to first unit if None
524-
primary_unit_name = (
520+
treated_unit = (
525521
treated_unit if treated_unit is not None else self.treated_units[0]
526522
)
527523

528-
if primary_unit_name not in self.treated_units:
524+
if treated_unit not in self.treated_units:
529525
raise ValueError(
530-
f"treated_unit '{primary_unit_name}' not found. Available units: {self.treated_units}"
526+
f"treated_unit '{treated_unit}' not found. Available units: {self.treated_units}"
531527
)
532528

533529
# Extract predictions - handle multi-unit case
@@ -541,10 +537,10 @@ def get_plot_data_bayesian(
541537
if len(self.treated_units) > 1:
542538
# Multi-unit case: extract primary unit
543539
pre_data["prediction"] = pre_pred_vals.sel(
544-
treated_units=primary_unit_name
540+
treated_units=treated_unit
545541
).values
546542
post_data["prediction"] = post_pred_vals.sel(
547-
treated_units=primary_unit_name
543+
treated_units=treated_unit
548544
).values
549545
else:
550546
# Single unit case
@@ -555,13 +551,13 @@ def get_plot_data_bayesian(
555551
if len(self.treated_units) > 1:
556552
pre_hdi = get_hdi_to_df(
557553
self.pre_pred["posterior_predictive"].mu.sel(
558-
treated_units=primary_unit_name
554+
treated_units=treated_unit
559555
),
560556
hdi_prob=hdi_prob,
561557
)
562558
post_hdi = get_hdi_to_df(
563559
self.post_pred["posterior_predictive"].mu.sel(
564-
treated_units=primary_unit_name
560+
treated_units=treated_unit
565561
),
566562
hdi_prob=hdi_prob,
567563
)
@@ -583,21 +579,21 @@ def get_plot_data_bayesian(
583579
# Impact data - always use primary unit for main dataframe
584580
pre_data["impact"] = (
585581
self.pre_impact.mean(dim=["chain", "draw"])
586-
.sel(treated_units=primary_unit_name)
582+
.sel(treated_units=treated_unit)
587583
.values
588584
)
589585
post_data["impact"] = (
590586
self.post_impact.mean(dim=["chain", "draw"])
591-
.sel(treated_units=primary_unit_name)
587+
.sel(treated_units=treated_unit)
592588
.values
593589
)
594590
# Impact HDI intervals - use primary unit
595591
if len(self.treated_units) > 1:
596592
pre_impact_hdi = get_hdi_to_df(
597-
self.pre_impact.sel(treated_units=primary_unit_name), hdi_prob=hdi_prob
593+
self.pre_impact.sel(treated_units=treated_unit), hdi_prob=hdi_prob
598594
)
599595
post_impact_hdi = get_hdi_to_df(
600-
self.post_impact.sel(treated_units=primary_unit_name), hdi_prob=hdi_prob
596+
self.post_impact.sel(treated_units=treated_unit), hdi_prob=hdi_prob
601597
)
602598
else:
603599
pre_impact_hdi = get_hdi_to_df(self.pre_impact, hdi_prob=hdi_prob)

0 commit comments

Comments
 (0)