Skip to content

Commit 2996331

Browse files
committed
codespell
1 parent 64c97b7 commit 2996331

File tree

4 files changed

+585
-226
lines changed

4 files changed

+585
-226
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 188 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,154 @@
3333
LEGEND_FONT_SIZE = 12
3434

3535

36+
class HandlerUTT:
    """
    Handle data preprocessing, postprocessing, and plotting steps for models
    with unknown treatment intervention times.

    The treatment time is read from the ``treatment_time`` variable of the
    posterior in the inference data produced by the fitted model.
    """

    @staticmethod
    def _treatment_time_hdi(idata):
        """
        Return the HDI bounds of ``treatment_time`` as computed by ``az.hdi``.

        Callers read element ``0`` as the lower bound and element ``1`` as the
        upper bound.
        """
        return az.hdi(idata, var_names=["treatment_time"])["treatment_time"].values

    @staticmethod
    def _index_for_time(data, time_value, description):
        """
        Return the first index label of ``data`` whose ``t`` column equals
        ``time_value``.

        Parameters
        ----------
        data : DataFrame with a ``t`` column.
        time_value : value to look up in ``data["t"]``.
        description : human-readable name of the looked-up quantity, used as
            the prefix of the error message.

        Raises
        ------
        ValueError
            If no row of ``data`` has ``t == time_value``.
        """
        matches = data.index[(data["t"] == time_value).to_numpy()]
        if len(matches) == 0:
            raise ValueError(f"{description} {time_value} not found in data['t'].")
        return matches[0]

    def data_preprocessing(self, data, treatment_time, formula, model):
        """
        Preprocess the data using patsy for fitting into the model and update
        the model with the required information.

        The design matrices are built from the whole of ``data`` (no
        pre/post split), then the model is told which time range to search
        and which design-matrix column holds the time variable ``t``.

        Returns
        -------
        (y, X) : the patsy design matrices for ``formula``.
        """
        y, X = dmatrices(formula, data)
        # Restrict model's treatment time inference to given range
        model.set_time_range(treatment_time, data)
        # Needed to track time evolution across model predictions
        model.set_timeline(X.design_info.column_names.index("t"))
        return y, X

    def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
        """
        Postprocess the data according to the inferred treatment time, for
        calculation and plotting purposes.

        The split point between "pre" and "post" is the lower HDI bound of the
        inferred treatment time, not its posterior mean. The ``treatment_time``
        argument is not used here.

        Returns
        -------
        (df_pre, df_post, truncated_y, truncated_X, inferred_index)
            ``df_pre`` / ``df_post`` are the rows of ``data`` before / from the
            lower HDI bound; ``truncated_y`` / ``truncated_X`` are ``pre_y`` /
            ``pre_X`` cut at the same position; ``inferred_index`` is the index
            label of the row whose ``t`` equals the truncated posterior mean.

        Raises
        ------
        ValueError
            If the truncated posterior mean or the truncated lower HDI bound
            is not a value of ``data["t"]``.
        """
        # Retrieve posterior mean of inferred treatment time
        treatment_time_mean = idata.posterior["treatment_time"].mean().item()
        inferred_time = int(treatment_time_mean)

        # Convert the inferred time to its corresponding DataFrame index
        # (raises if the inferred time is not present in the dataset)
        inferred_index = self._index_for_time(
            data, inferred_time, "Inferred treatment time"
        )

        # Retrieve HDI bounds of treatment time (uncertainty interval)
        hdi_bounds = self._treatment_time_hdi(idata)
        hdi_start_time = int(hdi_bounds[0])

        # Convert HDI lower bound to DataFrame index for slicing
        hdi_start_idx_df = self._index_for_time(
            data, hdi_start_time, "HDI start time"
        )
        # Positional equivalent, needed because pre_y / pre_X are sliced by
        # position. NOTE(review): assumes a unique index so that get_loc
        # returns an integer - confirm against callers.
        hdi_start_idx_np = data.index.get_loc(hdi_start_idx_df)

        # Slice both pandas and numpy objects accordingly
        df_pre = data[data.index < hdi_start_idx_df]
        df_post = data[data.index >= hdi_start_idx_df]
        truncated_y = pre_y[:hdi_start_idx_np]
        truncated_X = pre_X[:hdi_start_idx_np]

        return df_pre, df_post, truncated_y, truncated_X, inferred_index

    def plot_intervention_line(self, ax, idata, datapost, treatment_time):
        """
        Plot a vertical line at the inferred treatment time, along with a shaded area
        representing the Highest Density Interval (HDI) of the inferred time.

        Drawn on ``ax[0]``, ``ax[1]`` and ``ax[2]``.

        Raises
        ------
        ValueError
            If a truncated HDI bound is not a value of ``datapost["t"]``.
        """
        # Extract the HDI (uncertainty interval) of the treatment time and map
        # both bounds to index labels, which are the x coordinates of the plot
        hdi = self._treatment_time_hdi(idata)
        x1 = self._index_for_time(datapost, int(hdi[0]), "HDI start time")
        x2 = self._index_for_time(datapost, int(hdi[1]), "HDI end time")

        for i in [0, 1, 2]:
            # Vertical line for inferred treatment time. axvline spans the
            # full axes height whatever the final y-limits are, matching
            # what HandlerKTT draws.
            ax[i].axvline(
                x=treatment_time,
                ls="-",
                lw=3,
                color="r",
                solid_capstyle="butt",
            )

            # Shaded region for HDI of treatment time
            ax[i].axvspan(x1, x2, alpha=0.1, color="r")

    def plot_treated_counterfactual(
        self, ax, handles, labels, datapost, post_pred, post_y
    ):
        """
        Plot the inferred post-intervention trajectory (with treatment effect).

        Draws on ``ax[0]`` and appends the resulting legend handle and label
        to ``handles`` and ``labels`` in place. ``post_y`` is not used.
        """
        # --- Plot predicted trajectory under treatment (with HDI)
        h_line, h_patch = plot_xY(
            datapost.index,
            post_pred["posterior_predictive"].mu_ts,
            ax=ax[0],
            plot_hdi_kwargs={"color": "yellowgreen"},
        )
        handles.append((h_line, h_patch))
        labels.append("treated counterfactual")
138+
139+
140+
class HandlerKTT:
    """
    Handles data preprocessing, postprocessing, and plotting logic for models
    where the treatment time is known in advance.
    """

    def data_preprocessing(self, data, treatment_time, formula, model):
        """
        Build the patsy design matrices from the pre-treatment rows only.

        Rows are kept when their index is strictly below ``treatment_time``.
        ``model`` is accepted for interface compatibility and is not used.
        """
        training_rows = data[data.index < treatment_time]
        return dmatrices(formula, training_rows)

    def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
        """
        Split ``data`` into pre- and post-intervention periods at the known
        treatment time.

        ``pre_y``, ``pre_X`` and ``treatment_time`` are returned unchanged;
        ``idata`` is not used.

        Returns
        -------
        (datapre, datapost, pre_y, pre_X, treatment_time)
        """
        before = data[data.index < treatment_time]
        from_treatment_on = data[data.index >= treatment_time]
        return before, from_treatment_on, pre_y, pre_X, treatment_time

    def plot_intervention_line(self, ax, idata, datapost, treatment_time):
        """
        Plot a vertical line at the known treatment time on ``ax[0]``,
        ``ax[1]`` and ``ax[2]``. ``idata`` and ``datapost`` are not used.
        """
        line_style = {"ls": "-", "lw": 3, "color": "r", "solid_capstyle": "butt"}
        for panel in range(3):
            ax[panel].axvline(x=treatment_time, **line_style)

    def plot_treated_counterfactual(
        self, sax, handles, labels, datapost, post_pred, post_y
    ):
        """
        Placeholder method to maintain interface compatibility.

        Does nothing: with a known treatment time there is no extra
        trajectory to draw, so ``handles`` and ``labels`` are left untouched.
        """
        return None
182+
183+
36184
class InterruptedTimeSeries(BaseExperiment):
37185
"""
38186
The class for interrupted time series analysis.
@@ -86,38 +234,33 @@ def __init__(
86234
self.input_validation(data, treatment_time, model)
87235

88236
self.treatment_time = treatment_time
89-
# set experiment type - usually done in subclasses
90-
self.expt_type = "Pre-Post Fit"
91-
# set if the model is supposed to infer the treatment_time
92-
self.infer_treatment_time = isinstance(self.treatment_time, (type(None), tuple))
237+
self.formula = formula
93238

94-
# Set the data according to if the model is fitted on the whole bunch or not
95-
if self.infer_treatment_time:
96-
self.datapre = data
239+
# Getting the right handler
240+
if treatment_time is None or isinstance(treatment_time, tuple):
241+
self.handler = HandlerUTT()
97242
else:
98-
# split data in to pre and post intervention
99-
self.datapre = data[data.index < self.treatment_time]
243+
self.handler = HandlerKTT()
100244

101-
self.formula = formula
245+
# set experiment type - usually done in subclasses
246+
self.expt_type = "Pre-Post Fit"
247+
248+
# Preprocessing based on handler type
249+
y, X = self.handler.data_preprocessing(
250+
data, self.treatment_time, formula, self.model
251+
)
102252

103253
# set things up with pre-intervention data
104-
y, X = dmatrices(formula, self.datapre)
105254
self.outcome_variable_name = y.design_info.column_names[0]
106255
self._y_design_info = y.design_info
107256
self._x_design_info = X.design_info
108257
self.labels = X.design_info.column_names
109258
self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)
110259

111-
# Setting the time range in which the model infers treatment_time
112-
# Setting the timeline index so that the model can keep of time track between predicts
113-
if self.infer_treatment_time:
114-
self.model.set_time_range(self.treatment_time, self.datapre)
115-
self.model.set_timeline(self.labels.index("t"))
116-
117260
# fit the model to the observed (pre-intervention) data
118261
if isinstance(self.model, PyMCModel):
119262
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.pre_X.shape[0])}
120-
idata = self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
263+
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
121264
elif isinstance(self.model, RegressorMixin):
122265
self.model.fit(X=self.pre_X, y=self.pre_y)
123266
else:
@@ -126,29 +269,17 @@ def __init__(
126269
# score the goodness of fit to the pre-intervention data
127270
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
128271

129-
if self.infer_treatment_time:
130-
# We're getting the inferred switchpoint as one of the values of the timeline, from the last column
131-
switchpoint = int(
132-
az.extract(idata, group="posterior", var_names="switchpoint")
133-
.mean("sample")
134-
.values
135-
)
136-
# we're getting the associated index of that switchpoint
137-
self.treatment_time = data[data["t"] == switchpoint].index[0]
138-
139-
# We're getting datapre as intended for prediction
140-
self.datapre = data[data.index < self.treatment_time]
141-
(new_y, new_x) = build_design_matrices(
142-
[self._y_design_info, self._x_design_info], self.datapre
272+
# Postprocessing with handler
273+
self.datapre, self.datapost, self.pre_y, self.pre_X, self.treatment_time = (
274+
self.handler.data_postprocessing(
275+
data, self.idata, treatment_time, self.pre_y, self.pre_X
143276
)
144-
self.pre_X = np.asarray(new_x)
145-
self.pre_y = np.asarray(new_y)
277+
)
146278

147279
# get the model predictions of the observed (pre-intervention) data
148280
self.pre_pred = self.model.predict(X=self.pre_X)
149-
# process post-intervention data
150-
self.datapost = data[data.index >= self.treatment_time]
151281

282+
# process post-intervention data
152283
(new_y, new_x) = build_design_matrices(
153284
[self._y_design_info, self._x_design_info], self.datapost
154285
)
@@ -211,6 +342,7 @@ def _bayesian_plot(
211342

212343
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
213344
# TOP PLOT --------------------------------------------------
345+
214346
# pre-intervention period
215347
h_line, h_patch = plot_xY(
216348
self.datapre.index,
@@ -225,6 +357,11 @@ def _bayesian_plot(
225357
handles.append(h)
226358
labels.append("Observations")
227359

360+
# Green line for treated counterfactual (if unknown treatment time)
361+
self.handler.plot_treated_counterfactual(
362+
ax, handles, labels, self.datapost, self.post_pred, self.post_y
363+
)
364+
228365
# post intervention period
229366
h_line, h_patch = plot_xY(
230367
self.datapost.index,
@@ -289,14 +426,10 @@ def _bayesian_plot(
289426
)
290427
ax[2].axhline(y=0, c="k")
291428

292-
# Intervention line
293-
for i in [0, 1, 2]:
294-
ax[i].axvline(
295-
x=self.treatment_time,
296-
ls="-",
297-
lw=3,
298-
color="r",
299-
)
429+
# Plot vertical line marking treatment time (with HDI if it's inferred)
430+
self.handler.plot_intervention_line(
431+
ax, self.idata, self.datapost, self.treatment_time
432+
)
300433

301434
ax[0].legend(
302435
handles=(h_tuple for h_tuple in handles),
@@ -441,3 +574,14 @@ def get_plot_data_ols(self) -> pd.DataFrame:
441574
self.plot_data = pd.concat([pre_data, post_data])
442575

443576
return self.plot_data
577+
578+
def plot_treatment_time(self):
    """
    Display the posterior estimates of the treatment time.

    Draws an ArviZ trace plot of the ``treatment_time`` variable from
    ``self.idata``.

    Raises
    ------
    ValueError
        If the posterior in ``self.idata`` has no ``treatment_time`` variable.
    """
    posterior_variables = self.idata.posterior.data_vars
    if "treatment_time" in posterior_variables:
        az.plot_trace(self.idata, var_names="treatment_time")
        return

    raise ValueError(
        "Variable 'treatment_time' not found in inference data (idata)."
    )

0 commit comments

Comments
 (0)