Skip to content

Commit 51e19f8

Browse files
committed
simplify _build_data a bit + remove accessor methods
1 parent 85b6939 commit 51e19f8

File tree

2 files changed

+68
-76
lines changed

2 files changed

+68
-76
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 65 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,27 @@ def algorithm(self) -> None:
105105
if isinstance(self.model, PyMCModel):
106106
COORDS = {
107107
"coeffs": self.labels,
108-
"obs_ind": np.arange(self.pre_X.shape[0]),
108+
"obs_ind": np.arange(self.data.X.sel(period="pre").shape[0]),
109109
"treated_units": ["unit_0"],
110110
}
111-
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
111+
self.model.fit(
112+
X=self.data.X.sel(period="pre"),
113+
y=self.data.y.sel(period="pre"),
114+
coords=COORDS,
115+
)
112116
elif isinstance(self.model, RegressorMixin):
113117
# For OLS models, use 1D y data
114-
self.model.fit(X=self.pre_X, y=self.pre_y.isel(treated_units=0))
118+
self.model.fit(
119+
X=self.data.X.sel(period="pre"),
120+
y=self.data.y.sel(period="pre").isel(treated_units=0),
121+
)
115122
else:
116123
raise ValueError("Model type not recognized")
117124

118125
# 2. Score the goodness of fit to the pre-intervention data
119-
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
126+
self.score = self.model.score(
127+
X=self.data.X.sel(period="pre"), y=self.data.y.sel(period="pre")
128+
)
120129

121130
# 3. Generate predictions for the full dataset using unified approach
122131
# This creates predictions aligned with our complete time series
@@ -187,53 +196,26 @@ def _build_data(self, data: pd.DataFrame) -> xr.Dataset:
187196
# Create period coordinate based on treatment time
188197
period_coord = xr.where(data.index < self.treatment_time, "pre", "post")
189198

190-
# Return complete time series as a single xarray Dataset
191-
X_array = xr.DataArray(
192-
np.asarray(X_full),
193-
dims=["obs_ind", "coeffs"],
194-
coords={
195-
"obs_ind": data.index,
196-
"coeffs": self.labels,
197-
"period": ("obs_ind", period_coord),
198-
},
199-
)
200-
201-
y_array = xr.DataArray(
202-
np.asarray(y_full),
203-
dims=["obs_ind", "treated_units"],
204-
coords={
205-
"obs_ind": data.index,
206-
"treated_units": ["unit_0"],
207-
"period": ("obs_ind", period_coord),
208-
},
209-
)
210-
211-
# Create dataset and use set_xindex to make period selectable with .sel()
212-
dataset = xr.Dataset({"X": X_array, "y": y_array})
213-
dataset = dataset.set_xindex("period")
214-
215-
return dataset
216-
217-
# Properties for pre/post intervention data access
218-
@property
219-
def pre_X(self) -> xr.DataArray:
220-
"""Pre-intervention features."""
221-
return self.data.X.sel(period="pre")
222-
223-
@property
224-
def pre_y(self) -> xr.DataArray:
225-
"""Pre-intervention outcomes."""
226-
return self.data.y.sel(period="pre")
227-
228-
@property
229-
def post_X(self) -> xr.DataArray:
230-
"""Post-intervention features."""
231-
return self.data.X.sel(period="post")
232-
233-
@property
234-
def post_y(self) -> xr.DataArray:
235-
"""Post-intervention outcomes."""
236-
return self.data.y.sel(period="post")
199+
# Return as a xarray.Dataset
200+
common_coords = {
201+
"obs_ind": data.index,
202+
"period": ("obs_ind", period_coord),
203+
}
204+
205+
return xr.Dataset(
206+
{
207+
"X": xr.DataArray(
208+
np.asarray(X_full),
209+
dims=["obs_ind", "coeffs"],
210+
coords={**common_coords, "coeffs": self.labels},
211+
),
212+
"y": xr.DataArray(
213+
np.asarray(y_full),
214+
dims=["obs_ind", "treated_units"],
215+
coords={**common_coords, "treated_units": ["unit_0"]},
216+
),
217+
}
218+
).set_xindex("period")
237219

238220
def input_validation(self, data, treatment_time):
239221
"""Validate the input data and model formula for correctness"""
@@ -285,7 +267,7 @@ def _bayesian_plot(
285267
# TOP PLOT --------------------------------------------------
286268
# pre-intervention period
287269
h_line, h_patch = plot_xY(
288-
self.pre_X.obs_ind,
270+
self.data.X.sel(period="pre").obs_ind,
289271
pre_pred.mu.isel(treated_units=0),
290272
ax=ax[0],
291273
plot_hdi_kwargs={"color": "C0"},
@@ -294,8 +276,8 @@ def _bayesian_plot(
294276
labels = ["Pre-intervention period"]
295277

296278
(h,) = ax[0].plot(
297-
self.pre_X.obs_ind,
298-
self.pre_y.isel(treated_units=0),
279+
self.data.X.sel(period="pre").obs_ind,
280+
self.data.y.sel(period="pre").isel(treated_units=0),
299281
"k.",
300282
label="Observations",
301283
)
@@ -304,7 +286,7 @@ def _bayesian_plot(
304286

305287
# post intervention period
306288
h_line, h_patch = plot_xY(
307-
self.post_X.obs_ind,
289+
self.data.X.sel(period="post").obs_ind,
308290
post_pred.mu.isel(treated_units=0),
309291
ax=ax[0],
310292
plot_hdi_kwargs={"color": "C1"},
@@ -313,17 +295,17 @@ def _bayesian_plot(
313295
labels.append(counterfactual_label)
314296

315297
ax[0].plot(
316-
self.post_X.obs_ind,
317-
self.post_y.isel(treated_units=0),
298+
self.data.X.sel(period="post").obs_ind,
299+
self.data.y.sel(period="post").isel(treated_units=0),
318300
"k.",
319301
)
320302

321303
# Shaded causal effect - use direct calculation
322304
post_pred_mu = post_pred.mu.mean(dim=["chain", "draw"]).isel(treated_units=0)
323305
h = ax[0].fill_between(
324-
self.post_X.obs_ind,
306+
self.data.X.sel(period="post").obs_ind,
325307
y1=post_pred_mu,
326-
y2=self.post_y.isel(treated_units=0),
308+
y2=self.data.y.sel(period="post").isel(treated_units=0),
327309
color="C0",
328310
alpha=0.25,
329311
)
@@ -339,20 +321,20 @@ def _bayesian_plot(
339321

340322
# MIDDLE PLOT -----------------------------------------------
341323
plot_xY(
342-
self.pre_X.obs_ind,
324+
self.data.X.sel(period="pre").obs_ind,
343325
self.impact.sel(period="pre").isel(treated_units=0),
344326
ax=ax[1],
345327
plot_hdi_kwargs={"color": "C0"},
346328
)
347329
plot_xY(
348-
self.post_X.obs_ind,
330+
self.data.X.sel(period="post").obs_ind,
349331
self.impact.sel(period="post").isel(treated_units=0),
350332
ax=ax[1],
351333
plot_hdi_kwargs={"color": "C1"},
352334
)
353335
ax[1].axhline(y=0, c="k")
354336
ax[1].fill_between(
355-
self.post_X.obs_ind,
337+
self.data.X.sel(period="post").obs_ind,
356338
y1=self.impact.sel(period="post")
357339
.mean(["chain", "draw"])
358340
.isel(treated_units=0),
@@ -365,7 +347,7 @@ def _bayesian_plot(
365347
# BOTTOM PLOT -----------------------------------------------
366348
ax[2].set(title="Cumulative Causal Impact")
367349
plot_xY(
368-
self.post_X.obs_ind,
350+
self.data.X.sel(period="post").obs_ind,
369351
self.post_impact_cumulative.isel(treated_units=0),
370352
ax=ax[2],
371353
plot_hdi_kwargs={"color": "C1"},
@@ -424,12 +406,18 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
424406
pre_pred = self.predictions.sel(period="pre")
425407
post_pred = self.predictions.sel(period="post")
426408

427-
ax[0].plot(self.pre_X.obs_ind, self.pre_y, "k.")
428-
ax[0].plot(self.post_X.obs_ind, self.post_y, "k.")
409+
ax[0].plot(
410+
self.data.X.sel(period="pre").obs_ind, self.data.y.sel(period="pre"), "k."
411+
)
412+
ax[0].plot(
413+
self.data.X.sel(period="post").obs_ind, self.data.y.sel(period="post"), "k."
414+
)
429415

430-
ax[0].plot(self.pre_X.obs_ind, pre_pred, c="k", label="model fit")
431416
ax[0].plot(
432-
self.post_X.obs_ind,
417+
self.data.X.sel(period="pre").obs_ind, pre_pred, c="k", label="model fit"
418+
)
419+
ax[0].plot(
420+
self.data.X.sel(period="post").obs_ind,
433421
post_pred,
434422
label=counterfactual_label,
435423
ls=":",
@@ -439,31 +427,35 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
439427
title=f"$R^2$ on pre-intervention data = {round_num(self.score, round_to)}"
440428
)
441429

442-
ax[1].plot(self.pre_X.obs_ind, self.impact.sel(period="pre"), "k.")
443430
ax[1].plot(
444-
self.post_X.obs_ind,
431+
self.data.X.sel(period="pre").obs_ind, self.impact.sel(period="pre"), "k."
432+
)
433+
ax[1].plot(
434+
self.data.X.sel(period="post").obs_ind,
445435
self.impact.sel(period="post"),
446436
"k.",
447437
label=counterfactual_label,
448438
)
449439
ax[1].axhline(y=0, c="k")
450440
ax[1].set(title="Causal Impact")
451441

452-
ax[2].plot(self.post_X.obs_ind, self.post_impact_cumulative, c="k")
442+
ax[2].plot(
443+
self.data.X.sel(period="post").obs_ind, self.post_impact_cumulative, c="k"
444+
)
453445
ax[2].axhline(y=0, c="k")
454446
ax[2].set(title="Cumulative Causal Impact")
455447

456448
# Shaded causal effect
457449
ax[0].fill_between(
458-
self.post_X.obs_ind,
450+
self.data.X.sel(period="post").obs_ind,
459451
y1=np.squeeze(post_pred),
460-
y2=np.squeeze(self.post_y),
452+
y2=np.squeeze(self.data.y.sel(period="post")),
461453
color="C0",
462454
alpha=0.25,
463455
label="Causal impact",
464456
)
465457
ax[1].fill_between(
466-
self.post_X.obs_ind,
458+
self.data.X.sel(period="post").obs_ind,
467459
y1=np.squeeze(self.impact.sel(period="post")),
468460
color="C0",
469461
alpha=0.25,

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)