Skip to content

Commit 76ef685

Browse files
committed
Restore intterupte_time_series.py to its original version
1 parent e5f13ab commit 76ef685

File tree

1 file changed

+87
-33
lines changed

1 file changed

+87
-33
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 87 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ def __init__(
120120
},
121121
)
122122
self.pre_y = xr.DataArray(
123-
self.pre_y[:, 0],
124-
dims=["obs_ind"],
125-
coords={"obs_ind": self.datapre.index},
123+
self.pre_y, # Keep 2D shape
124+
dims=["obs_ind", "treated_units"],
125+
coords={"obs_ind": self.datapre.index, "treated_units": ["unit_0"]},
126126
)
127127
self.post_X = xr.DataArray(
128128
self.post_X,
@@ -133,17 +133,22 @@ def __init__(
133133
},
134134
)
135135
self.post_y = xr.DataArray(
136-
self.post_y[:, 0],
137-
dims=["obs_ind"],
138-
coords={"obs_ind": self.datapost.index},
136+
self.post_y, # Keep 2D shape
137+
dims=["obs_ind", "treated_units"],
138+
coords={"obs_ind": self.datapost.index, "treated_units": ["unit_0"]},
139139
)
140140

141141
# fit the model to the observed (pre-intervention) data
142142
if isinstance(self.model, PyMCModel):
143-
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.pre_X.shape[0])}
143+
COORDS = {
144+
"coeffs": self.labels,
145+
"obs_ind": np.arange(self.pre_X.shape[0]),
146+
"treated_units": ["unit_0"],
147+
}
144148
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
145149
elif isinstance(self.model, RegressorMixin):
146-
self.model.fit(X=self.pre_X, y=self.pre_y)
150+
# For OLS models, use 1D y data
151+
self.model.fit(X=self.pre_X, y=self.pre_y.isel(treated_units=0))
147152
else:
148153
raise ValueError("Model type not recognized")
149154

@@ -155,8 +160,21 @@ def __init__(
155160

156161
# calculate the counterfactual
157162
self.post_pred = self.model.predict(X=self.post_X)
158-
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
159-
self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred)
163+
164+
# calculate impact - use appropriate y data format for each model type
165+
if isinstance(self.model, PyMCModel):
166+
# PyMC models work with 2D data
167+
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
168+
self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred)
169+
elif isinstance(self.model, RegressorMixin):
170+
# SKL models work with 1D data
171+
self.pre_impact = self.model.calculate_impact(
172+
self.pre_y.isel(treated_units=0), self.pre_pred
173+
)
174+
self.post_impact = self.model.calculate_impact(
175+
self.post_y.isel(treated_units=0), self.post_pred
176+
)
177+
160178
self.post_impact_cumulative = self.model.calculate_cumulative_impact(
161179
self.post_impact
162180
)
@@ -202,35 +220,53 @@ def _bayesian_plot(
202220
# pre-intervention period
203221
h_line, h_patch = plot_xY(
204222
self.datapre.index,
205-
self.pre_pred["posterior_predictive"].mu,
223+
self.pre_pred["posterior_predictive"].mu.isel(treated_units=0),
206224
ax=ax[0],
207225
plot_hdi_kwargs={"color": "C0"},
208226
)
209227
handles = [(h_line, h_patch)]
210228
labels = ["Pre-intervention period"]
211229

212-
(h,) = ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
230+
(h,) = ax[0].plot(
231+
self.datapre.index,
232+
self.pre_y.isel(treated_units=0)
233+
if hasattr(self.pre_y, "isel")
234+
else self.pre_y[:, 0],
235+
"k.",
236+
label="Observations",
237+
)
213238
handles.append(h)
214239
labels.append("Observations")
215240

216241
# post intervention period
217242
h_line, h_patch = plot_xY(
218243
self.datapost.index,
219-
self.post_pred["posterior_predictive"].mu,
244+
self.post_pred["posterior_predictive"].mu.isel(treated_units=0),
220245
ax=ax[0],
221246
plot_hdi_kwargs={"color": "C1"},
222247
)
223248
handles.append((h_line, h_patch))
224249
labels.append(counterfactual_label)
225250

226-
ax[0].plot(self.datapost.index, self.post_y, "k.")
251+
ax[0].plot(
252+
self.datapost.index,
253+
self.post_y.isel(treated_units=0)
254+
if hasattr(self.post_y, "isel")
255+
else self.post_y[:, 0],
256+
"k.",
257+
)
227258
# Shaded causal effect
259+
post_pred_mu = (
260+
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
261+
.isel(treated_units=0)
262+
.mean("sample")
263+
) # Add .mean("sample") to get 1D array
228264
h = ax[0].fill_between(
229265
self.datapost.index,
230-
y1=az.extract(
231-
self.post_pred, group="posterior_predictive", var_names="mu"
232-
).mean("sample"),
233-
y2=np.squeeze(self.post_y),
266+
y1=post_pred_mu,
267+
y2=self.post_y.isel(treated_units=0)
268+
if hasattr(self.post_y, "isel")
269+
else self.post_y[:, 0],
234270
color="C0",
235271
alpha=0.25,
236272
)
@@ -239,28 +275,28 @@ def _bayesian_plot(
239275

240276
ax[0].set(
241277
title=f"""
242-
Pre-intervention Bayesian $R^2$: {round_num(self.score.r2, round_to)}
243-
(std = {round_num(self.score.r2_std, round_to)})
278+
Pre-intervention Bayesian $R^2$: {round_num(self.score["unit_0_r2"], round_to)}
279+
(std = {round_num(self.score["unit_0_r2_std"], round_to)})
244280
"""
245281
)
246282

247283
# MIDDLE PLOT -----------------------------------------------
248284
plot_xY(
249285
self.datapre.index,
250-
self.pre_impact,
286+
self.pre_impact.isel(treated_units=0),
251287
ax=ax[1],
252288
plot_hdi_kwargs={"color": "C0"},
253289
)
254290
plot_xY(
255291
self.datapost.index,
256-
self.post_impact,
292+
self.post_impact.isel(treated_units=0),
257293
ax=ax[1],
258294
plot_hdi_kwargs={"color": "C1"},
259295
)
260296
ax[1].axhline(y=0, c="k")
261297
ax[1].fill_between(
262298
self.datapost.index,
263-
y1=self.post_impact.mean(["chain", "draw"]),
299+
y1=self.post_impact.mean(["chain", "draw"]).isel(treated_units=0),
264300
color="C0",
265301
alpha=0.25,
266302
label="Causal impact",
@@ -271,7 +307,7 @@ def _bayesian_plot(
271307
ax[2].set(title="Cumulative Causal Impact")
272308
plot_xY(
273309
self.datapost.index,
274-
self.post_impact_cumulative,
310+
self.post_impact_cumulative.isel(treated_units=0),
275311
ax=ax[2],
276312
plot_hdi_kwargs={"color": "C1"},
277313
)
@@ -387,27 +423,45 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
387423
pre_data["prediction"] = (
388424
az.extract(self.pre_pred, group="posterior_predictive", var_names="mu")
389425
.mean("sample")
426+
.isel(treated_units=0)
390427
.values
391428
)
392429
post_data["prediction"] = (
393430
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
394431
.mean("sample")
432+
.isel(treated_units=0)
395433
.values
396434
)
397-
pre_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
435+
hdi_pre_pred = get_hdi_to_df(
398436
self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
399-
).set_index(pre_data.index)
400-
post_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
437+
)
438+
hdi_post_pred = get_hdi_to_df(
401439
self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
440+
)
441+
# Select the single unit from the MultiIndex results
442+
pre_data[[pred_lower_col, pred_upper_col]] = hdi_pre_pred.xs(
443+
"unit_0", level="treated_units"
444+
).set_index(pre_data.index)
445+
post_data[[pred_lower_col, pred_upper_col]] = hdi_post_pred.xs(
446+
"unit_0", level="treated_units"
402447
).set_index(post_data.index)
403448

404-
pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
405-
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
406-
pre_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
407-
self.pre_impact, hdi_prob=hdi_prob
449+
pre_data["impact"] = (
450+
self.pre_impact.mean(dim=["chain", "draw"]).isel(treated_units=0).values
451+
)
452+
post_data["impact"] = (
453+
self.post_impact.mean(dim=["chain", "draw"])
454+
.isel(treated_units=0)
455+
.values
456+
)
457+
hdi_pre_impact = get_hdi_to_df(self.pre_impact, hdi_prob=hdi_prob)
458+
hdi_post_impact = get_hdi_to_df(self.post_impact, hdi_prob=hdi_prob)
459+
# Select the single unit from the MultiIndex results
460+
pre_data[[impact_lower_col, impact_upper_col]] = hdi_pre_impact.xs(
461+
"unit_0", level="treated_units"
408462
).set_index(pre_data.index)
409-
post_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
410-
self.post_impact, hdi_prob=hdi_prob
463+
post_data[[impact_lower_col, impact_upper_col]] = hdi_post_impact.xs(
464+
"unit_0", level="treated_units"
411465
).set_index(post_data.index)
412466

413467
self.plot_data = pd.concat([pre_data, post_data])

0 commit comments

Comments
 (0)