Skip to content

Commit 5c849bb

Browse files
committed
use set_xindex for cleaner code :)
1 parent 37bee48 commit 5c849bb

File tree

1 file changed

+104
-74
lines changed

1 file changed

+104
-74
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 104 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -123,59 +123,55 @@ def algorithm(self) -> None:
123123
if isinstance(self.model, PyMCModel):
124124
# PyMC models expect xarray DataArrays
125125
self.predictions = self.model.predict(X=self.data.X)
126-
# Add period coordinate to predictions - key insight for unified operations!
126+
# Add period coordinate to predictions - InferenceData handles multiple data arrays
127127
self.predictions = self.predictions.assign_coords(
128128
period=("obs_ind", self.data.period.data)
129129
)
130130
else:
131131
# Sklearn models expect numpy arrays
132132
pred_array = self.model.predict(X=self.data.X.values)
133-
# Create xarray DataArray with period coordinate for unified operations
133+
# Create xarray DataArray with period coordinate
134134
self.predictions = xr.DataArray(
135135
pred_array,
136136
dims=["obs_ind"],
137137
coords={
138138
"obs_ind": self.data.obs_ind,
139139
"period": ("obs_ind", self.data.period.data),
140140
},
141-
)
141+
).set_xindex("period")
142142

143-
# 4. Use native xarray operations on unified predictions with period coordinate
144-
# No more manual indexing - leverage xarray's .where() operations!
143+
# 4. Calculate unified impact with period coordinate - no more splitting!
145144
if isinstance(self.model, PyMCModel):
146-
# For PyMC models, use .where() on the posterior_predictive dataset
147-
pp = self.predictions.posterior_predictive
148-
pre_pp = pp.where(pp.period == "pre", drop=True)
149-
post_pp = pp.where(pp.period == "post", drop=True)
150-
151-
# Create new InferenceData objects for pre/post with the filtered data
152-
import arviz as az
153-
154-
self.pre_pred = az.InferenceData(posterior_predictive=pre_pp)
155-
self.post_pred = az.InferenceData(posterior_predictive=post_pp)
156-
157-
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
158-
self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred)
145+
# Calculate impact for the entire time series at once
146+
self.impact = self.model.calculate_impact(self.data.y, self.predictions)
147+
# Assign period coordinate to unified impact and set index
148+
self.impact = self.impact.assign_coords(
149+
period=("obs_ind", self.data.period.data)
150+
).set_xindex("period")
159151
else:
160-
# For sklearn models, same clean .where() approach
161-
self.pre_pred = self.predictions.where(
162-
self.predictions.period == "pre", drop=True
163-
)
164-
self.post_pred = self.predictions.where(
165-
self.predictions.period == "post", drop=True
166-
)
152+
# For sklearn: calculate unified impact as DataArray
153+
observed_values = self.data.y.isel(treated_units=0).values
154+
predicted_values = self.predictions.values
155+
impact_values = observed_values - predicted_values
167156

168-
self.pre_impact = self.model.calculate_impact(
169-
self.pre_y.isel(treated_units=0), self.pre_pred
170-
)
171-
self.post_impact = self.model.calculate_impact(
172-
self.post_y.isel(treated_units=0), self.post_pred
173-
)
157+
self.impact = xr.DataArray(
158+
impact_values,
159+
dims=["obs_ind"],
160+
coords={
161+
"obs_ind": self.data.obs_ind,
162+
"period": ("obs_ind", self.data.period.data),
163+
},
164+
).set_xindex("period")
174165

175-
# 4b. Calculate cumulative impact
176-
self.post_impact_cumulative = self.model.calculate_cumulative_impact(
177-
self.post_impact
178-
)
166+
# 5. Calculate cumulative impact (only on post-intervention period)
167+
post_impact = self.impact.sel(period="post")
168+
if isinstance(self.model, PyMCModel):
169+
self.post_impact_cumulative = self.model.calculate_cumulative_impact(
170+
post_impact
171+
)
172+
else:
173+
# For sklearn: simple cumulative sum
174+
self.post_impact_cumulative = post_impact.cumsum()
179175

180176
def _build_data(self, data: pd.DataFrame) -> xr.Dataset:
181177
"""Build the experiment dataset as unified time series with period coordinate."""
@@ -198,6 +194,7 @@ def _build_data(self, data: pd.DataFrame) -> xr.Dataset:
198194
coords={
199195
"obs_ind": data.index,
200196
"coeffs": self.labels,
197+
"period": ("obs_ind", period_coord),
201198
},
202199
)
203200

@@ -207,35 +204,47 @@ def _build_data(self, data: pd.DataFrame) -> xr.Dataset:
207204
coords={
208205
"obs_ind": data.index,
209206
"treated_units": ["unit_0"],
207+
"period": ("obs_ind", period_coord),
210208
},
211209
)
212210

213-
# Create dataset and add period as a coordinate
211+
# Create dataset and use set_xindex to make period selectable with .sel()
214212
dataset = xr.Dataset({"X": X_array, "y": y_array})
215-
dataset = dataset.assign_coords(period=("obs_ind", period_coord))
213+
dataset = dataset.set_xindex("period")
216214

217215
return dataset
218216

219217
# Properties for pre/post intervention data access
220218
@property
221219
def pre_X(self) -> xr.DataArray:
222220
"""Pre-intervention features."""
223-
return self.data.X.where(self.data.period == "pre", drop=True)
221+
return self.data.X.sel(period="pre")
224222

225223
@property
226224
def pre_y(self) -> xr.DataArray:
227225
"""Pre-intervention outcomes."""
228-
return self.data.y.where(self.data.period == "pre", drop=True)
226+
return self.data.y.sel(period="pre")
229227

230228
@property
231229
def post_X(self) -> xr.DataArray:
232230
"""Post-intervention features."""
233-
return self.data.X.where(self.data.period == "post", drop=True)
231+
return self.data.X.sel(period="post")
234232

235233
@property
236234
def post_y(self) -> xr.DataArray:
237235
"""Post-intervention outcomes."""
238-
return self.data.y.where(self.data.period == "post", drop=True)
236+
return self.data.y.sel(period="post")
237+
238+
# Simple backward-compatible properties for impact only (still used in plotting)
239+
@property
240+
def pre_impact(self):
241+
"""Pre-intervention impact (backward compatibility)."""
242+
return self.impact.sel(period="pre")
243+
244+
@property
245+
def post_impact(self):
246+
"""Post-intervention impact (backward compatibility)."""
247+
return self.impact.sel(period="post")
239248

240249
def input_validation(self, data, treatment_time):
241250
"""Validate the input data and model formula for correctness"""
@@ -266,19 +275,29 @@ def _bayesian_plot(
266275
self, round_to=None, **kwargs
267276
) -> tuple[plt.Figure, List[plt.Axes]]:
268277
"""
269-
Plot the results
278+
Plot the results using unified predictions with period coordinates.
270279
271280
:param round_to:
272281
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
273282
"""
274283
counterfactual_label = "Counterfactual"
275284

276285
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
286+
287+
# Extract pre/post predictions - InferenceData doesn't support .sel() with period
288+
# but .where() works fine with coordinates
289+
pre_pred = self.predictions["posterior_predictive"].where(
290+
self.predictions["posterior_predictive"].period == "pre", drop=True
291+
)
292+
post_pred = self.predictions["posterior_predictive"].where(
293+
self.predictions["posterior_predictive"].period == "post", drop=True
294+
)
295+
277296
# TOP PLOT --------------------------------------------------
278297
# pre-intervention period
279298
h_line, h_patch = plot_xY(
280299
self.pre_X.obs_ind,
281-
self.pre_pred["posterior_predictive"].mu.isel(treated_units=0),
300+
pre_pred.mu.isel(treated_units=0),
282301
ax=ax[0],
283302
plot_hdi_kwargs={"color": "C0"},
284303
)
@@ -287,9 +306,7 @@ def _bayesian_plot(
287306

288307
(h,) = ax[0].plot(
289308
self.pre_X.obs_ind,
290-
self.pre_y.isel(treated_units=0)
291-
if hasattr(self.pre_y, "isel")
292-
else self.pre_y[:, 0],
309+
self.pre_y.isel(treated_units=0),
293310
"k.",
294311
label="Observations",
295312
)
@@ -299,7 +316,7 @@ def _bayesian_plot(
299316
# post intervention period
300317
h_line, h_patch = plot_xY(
301318
self.post_X.obs_ind,
302-
self.post_pred["posterior_predictive"].mu.isel(treated_units=0),
319+
post_pred.mu.isel(treated_units=0),
303320
ax=ax[0],
304321
plot_hdi_kwargs={"color": "C1"},
305322
)
@@ -308,23 +325,16 @@ def _bayesian_plot(
308325

309326
ax[0].plot(
310327
self.post_X.obs_ind,
311-
self.post_y.isel(treated_units=0)
312-
if hasattr(self.post_y, "isel")
313-
else self.post_y[:, 0],
328+
self.post_y.isel(treated_units=0),
314329
"k.",
315330
)
316-
# Shaded causal effect
317-
post_pred_mu = (
318-
self.post_pred["posterior_predictive"]
319-
.mu.mean(dim=["chain", "draw"])
320-
.isel(treated_units=0)
321-
)
331+
332+
# Shaded causal effect - use direct calculation
333+
post_pred_mu = post_pred.mu.mean(dim=["chain", "draw"]).isel(treated_units=0)
322334
h = ax[0].fill_between(
323335
self.post_X.obs_ind,
324336
y1=post_pred_mu,
325-
y2=self.post_y.isel(treated_units=0)
326-
if hasattr(self.post_y, "isel")
327-
else self.post_y[:, 0],
337+
y2=self.post_y.isel(treated_units=0),
328338
color="C0",
329339
alpha=0.25,
330340
)
@@ -390,7 +400,7 @@ def _bayesian_plot(
390400

391401
def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
392402
"""
393-
Plot the results
403+
Plot the results using unified predictions with period coordinates.
394404
395405
:param round_to:
396406
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
@@ -399,13 +409,37 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
399409

400410
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
401411

412+
# Extract pre/post predictions - handle PyMC vs sklearn differently
413+
if isinstance(self.model, PyMCModel):
414+
# For PyMC models, predictions is InferenceData - use .where() with coordinates
415+
pre_pred = (
416+
self.predictions["posterior_predictive"]
417+
.where(
418+
self.predictions["posterior_predictive"].period == "pre", drop=True
419+
)
420+
.mu.mean(dim=["chain", "draw"])
421+
.isel(treated_units=0)
422+
)
423+
post_pred = (
424+
self.predictions["posterior_predictive"]
425+
.where(
426+
self.predictions["posterior_predictive"].period == "post", drop=True
427+
)
428+
.mu.mean(dim=["chain", "draw"])
429+
.isel(treated_units=0)
430+
)
431+
else:
432+
# For sklearn models, predictions is DataArray - use .sel() with indexed coordinates
433+
pre_pred = self.predictions.sel(period="pre")
434+
post_pred = self.predictions.sel(period="post")
435+
402436
ax[0].plot(self.pre_X.obs_ind, self.pre_y, "k.")
403437
ax[0].plot(self.post_X.obs_ind, self.post_y, "k.")
404438

405-
ax[0].plot(self.pre_X.obs_ind, self.pre_pred, c="k", label="model fit")
439+
ax[0].plot(self.pre_X.obs_ind, pre_pred, c="k", label="model fit")
406440
ax[0].plot(
407441
self.post_X.obs_ind,
408-
self.post_pred,
442+
post_pred,
409443
label=counterfactual_label,
410444
ls=":",
411445
c="k",
@@ -431,7 +465,7 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
431465
# Shaded causal effect
432466
ax[0].fill_between(
433467
self.post_X.obs_ind,
434-
y1=np.squeeze(self.post_pred),
468+
y1=np.squeeze(post_pred),
435469
y2=np.squeeze(self.post_y),
436470
color="C0",
437471
alpha=0.25,
@@ -482,20 +516,18 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
482516
pred_mu = self.predictions["posterior_predictive"].mu.isel(treated_units=0)
483517
plot_data["prediction"] = pred_mu.mean(dim=["chain", "draw"]).values
484518

485-
# Calculate impact directly from unified data
486-
observed = self.data.y.isel(treated_units=0)
487-
predicted = pred_mu.mean(dim=["chain", "draw"])
488-
plot_data["impact"] = (observed - predicted).values
519+
# Extract impact directly from unified impact - no more calculation needed!
520+
plot_data["impact"] = (
521+
self.impact.mean(dim=["chain", "draw"]).isel(treated_units=0).values
522+
)
489523

490524
# Calculate HDI bounds directly using arviz
491525
import arviz as az
492526

493527
pred_hdi = az.hdi(pred_mu, hdi_prob=hdi_prob)
494-
impact_data = observed - pred_mu
495-
impact_hdi = az.hdi(impact_data, hdi_prob=hdi_prob)
528+
impact_hdi = az.hdi(self.impact.isel(treated_units=0), hdi_prob=hdi_prob)
496529

497530
# Extract HDI bounds from xarray Dataset results
498-
# Use the actual variable name that arviz creates (usually the first data variable)
499531
pred_var_name = list(pred_hdi.data_vars.keys())[0]
500532
impact_var_name = list(impact_hdi.data_vars.keys())[0]
501533

@@ -520,11 +552,9 @@ def get_plot_data_ols(self) -> pd.DataFrame:
520552
index=self.data.y.obs_ind.values,
521553
)
522554

523-
# With unified predictions, extract values directly (no more reconstruction needed!)
555+
# Extract directly from unified data structures - ultimate simplification!
524556
plot_data["prediction"] = self.predictions.values
525-
plot_data["impact"] = (
526-
self.data.y.isel(treated_units=0) - self.predictions
527-
).values
557+
plot_data["impact"] = self.impact.values
528558

529559
self.plot_data = plot_data
530560
return self.plot_data

0 commit comments

Comments
 (0)