Skip to content

Commit 7b473af

Browse files
committed
simplify plot code
1 parent 4fa1650 commit 7b473af

File tree

2 files changed

+69
-81
lines changed

2 files changed

+69
-81
lines changed

causalpy/experiments/synthetic_control.py

Lines changed: 63 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -213,32 +213,40 @@ def summary(self, round_to=None) -> None:
213213
self.print_coefficients(round_to)
214214

215215
def _bayesian_plot(
216-
self, round_to=None, treated_unit=None, **kwargs
216+
self, round_to=None, treated_unit: str | None = None, **kwargs
217217
) -> tuple[plt.Figure, List[plt.Axes]]:
218218
"""
219219
Plot the results for a specific treated unit
220220
221221
:param round_to:
222222
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
223223
:param treated_unit:
224-
Which treated unit to plot. Can be an integer index or string name.
224+
Which treated unit to plot. Must be a string name of the treated unit.
225225
If None, plots the first treated unit.
226226
"""
227227
counterfactual_label = "Counterfactual"
228228

229229
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
230230
# TOP PLOT --------------------------------------------------
231231
# pre-intervention period
232-
primary_unit_idx = self._get_primary_treated_unit_index(treated_unit)
233-
primary_unit_name = self.treated_units[primary_unit_idx]
232+
233+
# Get treated unit name - default to first unit if None
234+
primary_unit_name = (
235+
treated_unit if treated_unit is not None else self.treated_units[0]
236+
)
237+
238+
if primary_unit_name not in self.treated_units:
239+
raise ValueError(
240+
f"treated_unit '{primary_unit_name}' not found. Available units: {self.treated_units}"
241+
)
234242

235243
# For multi-unit, select primary unit for main plot
236244
if len(self.treated_units) > 1:
237-
pre_pred_plot = self.pre_pred["posterior_predictive"].mu.isel(
238-
treated_units=primary_unit_idx
245+
pre_pred_plot = self.pre_pred["posterior_predictive"].mu.sel(
246+
treated_units=primary_unit_name
239247
)
240-
post_pred_plot = self.post_pred["posterior_predictive"].mu.isel(
241-
treated_units=primary_unit_idx
248+
post_pred_plot = self.post_pred["posterior_predictive"].mu.sel(
249+
treated_units=primary_unit_name
242250
)
243251
else:
244252
pre_pred_plot = self.pre_pred["posterior_predictive"].mu
@@ -256,12 +264,12 @@ def _bayesian_plot(
256264
# Plot observations for primary treated unit
257265
(h,) = ax[0].plot(
258266
self.datapre.index,
259-
self.datapre_treated.isel(treated_units=primary_unit_idx),
267+
self.datapre_treated.sel(treated_units=primary_unit_name),
260268
"k.",
261-
label=f"Observations ({self.treated_units[primary_unit_idx]})",
269+
label=f"Observations ({primary_unit_name})",
262270
)
263271
handles.append(h)
264-
labels.append(f"Observations ({self.treated_units[primary_unit_idx]})")
272+
labels.append(f"Observations ({primary_unit_name})")
265273

266274
# post intervention period
267275
h_line, h_patch = plot_xY(
@@ -275,14 +283,14 @@ def _bayesian_plot(
275283

276284
ax[0].plot(
277285
self.datapost.index,
278-
self.datapost_treated.isel(treated_units=primary_unit_idx),
286+
self.datapost_treated.sel(treated_units=primary_unit_name),
279287
"k.",
280288
)
281289
# Shaded causal effect for primary treated unit
282290
h = ax[0].fill_between(
283291
self.datapost.index,
284292
y1=post_pred_plot.mean(dim=["chain", "draw"]).values,
285-
y2=self.datapost_treated.isel(treated_units=primary_unit_idx).values,
293+
y2=self.datapost_treated.sel(treated_units=primary_unit_name).values,
286294
color="C2",
287295
alpha=0.25,
288296
label="Causal impact",
@@ -295,21 +303,21 @@ def _bayesian_plot(
295303
# MIDDLE PLOT -----------------------------------------------
296304
plot_xY(
297305
self.datapre.index,
298-
self.pre_impact.sel(treated_units=self.treated_units[primary_unit_idx]),
306+
self.pre_impact.sel(treated_units=primary_unit_name),
299307
ax=ax[1],
300308
plot_hdi_kwargs={"color": "C0"},
301309
)
302310
plot_xY(
303311
self.datapost.index,
304-
self.post_impact.sel(treated_units=self.treated_units[primary_unit_idx]),
312+
self.post_impact.sel(treated_units=primary_unit_name),
305313
ax=ax[1],
306314
plot_hdi_kwargs={"color": "C1"},
307315
)
308316
ax[1].axhline(y=0, c="k")
309317
ax[1].fill_between(
310318
self.datapost.index,
311319
y1=self.post_impact.mean(["chain", "draw"]).sel(
312-
treated_units=self.treated_units[primary_unit_idx]
320+
treated_units=primary_unit_name
313321
),
314322
color="C0",
315323
alpha=0.25,
@@ -321,9 +329,7 @@ def _bayesian_plot(
321329
ax[2].set(title=f"Cumulative Causal Impact ({primary_unit_name})")
322330
plot_xY(
323331
self.datapost.index,
324-
self.post_impact_cumulative.sel(
325-
treated_units=self.treated_units[primary_unit_idx]
326-
),
332+
self.post_impact_cumulative.sel(treated_units=primary_unit_name),
327333
ax=ax[2],
328334
plot_hdi_kwargs={"color": "C1"},
329335
)
@@ -365,31 +371,39 @@ def _bayesian_plot(
365371
return fig, ax
366372

367373
def _ols_plot(
368-
self, round_to=None, treated_unit=None, **kwargs
374+
self, round_to=None, treated_unit: str | None = None, **kwargs
369375
) -> tuple[plt.Figure, List[plt.Axes]]:
370376
"""
371377
Plot the results for OLS model for a specific treated unit
372378
373379
:param round_to:
374380
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
375381
:param treated_unit:
376-
Which treated unit to plot. Can be an integer index or string name.
382+
Which treated unit to plot. Must be a string name of the treated unit.
377383
If None, plots the first treated unit.
378384
"""
379385
counterfactual_label = "Counterfactual"
380-
primary_unit_idx = self._get_primary_treated_unit_index(treated_unit)
381-
primary_unit_name = self.treated_units[primary_unit_idx]
386+
387+
# Get treated unit name - default to first unit if None
388+
primary_unit_name = (
389+
treated_unit if treated_unit is not None else self.treated_units[0]
390+
)
391+
392+
if primary_unit_name not in self.treated_units:
393+
raise ValueError(
394+
f"treated_unit '{primary_unit_name}' not found. Available units: {self.treated_units}"
395+
)
382396

383397
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
384398

385399
ax[0].plot(
386400
self.datapre_treated["obs_ind"],
387-
self.datapre_treated.isel(treated_units=primary_unit_idx),
401+
self.datapre_treated.sel(treated_units=primary_unit_name),
388402
"k.",
389403
)
390404
ax[0].plot(
391405
self.datapost_treated["obs_ind"],
392-
self.datapost_treated.isel(treated_units=primary_unit_idx),
406+
self.datapost_treated.sel(treated_units=primary_unit_name),
393407
"k.",
394408
)
395409

@@ -422,7 +436,7 @@ def _ols_plot(
422436
self.datapost.index,
423437
y1=post_pred_values,
424438
y2=np.squeeze(
425-
self.datapost_treated.isel(treated_units=primary_unit_idx).data
439+
self.datapost_treated.sel(treated_units=primary_unit_name).data
426440
),
427441
color="C0",
428442
alpha=0.25,
@@ -482,15 +496,15 @@ def get_plot_data_ols(self) -> pd.DataFrame:
482496
return self.plot_data
483497

484498
def get_plot_data_bayesian(
485-
self, hdi_prob: float = 0.94, treated_unit=None
499+
self, hdi_prob: float = 0.94, treated_unit: str | None = None
486500
) -> pd.DataFrame:
487501
"""
488502
Recover the data of the PrePostFit experiment along with the prediction and causal impact information.
489503
490504
:param hdi_prob:
491505
Prob for which the highest density interval will be computed. The default value is defined as the default from the :func:`arviz.hdi` function.
492506
:param treated_unit:
493-
Which treated unit to extract data for. Can be an integer index or string name.
507+
Which treated unit to extract data for. Must be a string name of the treated unit.
494508
If None, uses the first treated unit.
495509
"""
496510
if not isinstance(self.model, PyMCModel):
@@ -506,8 +520,15 @@ def get_plot_data_bayesian(
506520
pre_data = self.datapre.copy()
507521
post_data = self.datapost.copy()
508522

509-
# Get primary treated unit index for data extraction
510-
primary_unit_idx = self._get_primary_treated_unit_index(treated_unit)
523+
# Get treated unit name - default to first unit if None
524+
primary_unit_name = (
525+
treated_unit if treated_unit is not None else self.treated_units[0]
526+
)
527+
528+
if primary_unit_name not in self.treated_units:
529+
raise ValueError(
530+
f"treated_unit '{primary_unit_name}' not found. Available units: {self.treated_units}"
531+
)
511532

512533
# Extract predictions - handle multi-unit case
513534
pre_pred_vals = az.extract(
@@ -519,11 +540,11 @@ def get_plot_data_bayesian(
519540

520541
if len(self.treated_units) > 1:
521542
# Multi-unit case: extract primary unit
522-
pre_data["prediction"] = pre_pred_vals.isel(
523-
treated_units=primary_unit_idx
543+
pre_data["prediction"] = pre_pred_vals.sel(
544+
treated_units=primary_unit_name
524545
).values
525-
post_data["prediction"] = post_pred_vals.isel(
526-
treated_units=primary_unit_idx
546+
post_data["prediction"] = post_pred_vals.sel(
547+
treated_units=primary_unit_name
527548
).values
528549
else:
529550
# Single unit case
@@ -533,14 +554,14 @@ def get_plot_data_bayesian(
533554
# HDI intervals for predictions
534555
if len(self.treated_units) > 1:
535556
pre_hdi = get_hdi_to_df(
536-
self.pre_pred["posterior_predictive"].mu.isel(
537-
treated_units=primary_unit_idx
557+
self.pre_pred["posterior_predictive"].mu.sel(
558+
treated_units=primary_unit_name
538559
),
539560
hdi_prob=hdi_prob,
540561
)
541562
post_hdi = get_hdi_to_df(
542-
self.post_pred["posterior_predictive"].mu.isel(
543-
treated_units=primary_unit_idx
563+
self.post_pred["posterior_predictive"].mu.sel(
564+
treated_units=primary_unit_name
544565
),
545566
hdi_prob=hdi_prob,
546567
)
@@ -562,21 +583,21 @@ def get_plot_data_bayesian(
562583
# Impact data - always use primary unit for main dataframe
563584
pre_data["impact"] = (
564585
self.pre_impact.mean(dim=["chain", "draw"])
565-
.isel(treated_units=primary_unit_idx)
586+
.sel(treated_units=primary_unit_name)
566587
.values
567588
)
568589
post_data["impact"] = (
569590
self.post_impact.mean(dim=["chain", "draw"])
570-
.isel(treated_units=primary_unit_idx)
591+
.sel(treated_units=primary_unit_name)
571592
.values
572593
)
573594
# Impact HDI intervals - use primary unit
574595
if len(self.treated_units) > 1:
575596
pre_impact_hdi = get_hdi_to_df(
576-
self.pre_impact.isel(treated_units=primary_unit_idx), hdi_prob=hdi_prob
597+
self.pre_impact.sel(treated_units=primary_unit_name), hdi_prob=hdi_prob
577598
)
578599
post_impact_hdi = get_hdi_to_df(
579-
self.post_impact.isel(treated_units=primary_unit_idx), hdi_prob=hdi_prob
600+
self.post_impact.sel(treated_units=primary_unit_name), hdi_prob=hdi_prob
580601
)
581602
else:
582603
pre_impact_hdi = get_hdi_to_df(self.pre_impact, hdi_prob=hdi_prob)
@@ -617,30 +638,3 @@ def _get_score_title(self, round_to=None):
617638
else:
618639
# OLS model - score is typically a simple float
619640
return f"$R^2$ on pre-intervention data = {round_num(self.score, round_to)}"
620-
621-
def _get_primary_treated_unit_index(self, treated_unit=None):
622-
"""Get the index for the treated unit to plot.
623-
624-
:param treated_unit: Optional. Either an integer index or string name of the treated unit.
625-
If None, defaults to the first treated unit (index 0).
626-
"""
627-
if treated_unit is None:
628-
return 0
629-
elif isinstance(treated_unit, int):
630-
if 0 <= treated_unit < len(self.treated_units):
631-
return treated_unit
632-
else:
633-
raise ValueError(
634-
f"treated_unit index {treated_unit} out of range. Valid range: 0-{len(self.treated_units) - 1}"
635-
)
636-
elif isinstance(treated_unit, str):
637-
if treated_unit in self.treated_units:
638-
return self.treated_units.index(treated_unit)
639-
else:
640-
raise ValueError(
641-
f"treated_unit '{treated_unit}' not found. Available units: {self.treated_units}"
642-
)
643-
else:
644-
raise ValueError(
645-
"treated_unit must be an integer index, string name, or None"
646-
)

causalpy/tests/test_multi_unit_sc.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,10 @@ def test_multi_unit_plotting(self, multi_unit_sc_data):
231231
# Test default (first unit)
232232
fig, ax = sc.plot()
233233

234-
# Test specific unit by index
235-
fig2, ax2 = sc.plot(treated_unit=1)
236-
237234
# Test specific unit by name
235+
fig2, ax2 = sc.plot(treated_unit="treated_1")
236+
237+
# Test another specific unit by name
238238
fig3, ax3 = sc.plot(treated_unit="treated_2")
239239

240240
# Check that we got the expected plot structure
@@ -287,16 +287,10 @@ def test_multi_unit_plotting_invalid_unit(self, multi_unit_sc_data):
287287
model=model,
288288
)
289289

290-
# Test invalid index
291-
with pytest.raises(ValueError, match="treated_unit index.*out of range"):
292-
sc.plot(treated_unit=10)
293-
294290
# Test invalid name
295291
with pytest.raises(ValueError, match="treated_unit.*not found"):
296292
sc.plot(treated_unit="invalid_unit")
297293

298-
# Test invalid type
299-
with pytest.raises(
300-
ValueError, match="treated_unit must be.*integer.*string.*None"
301-
):
302-
sc.plot(treated_unit=3.14)
294+
# Test another invalid name
295+
with pytest.raises(ValueError, match="treated_unit.*not found"):
296+
sc.plot(treated_unit="treated_10")

0 commit comments

Comments
 (0)