Skip to content

Commit 020f679

Browse files
committed
fixing merging issues
1 parent 1da80fd commit 020f679

File tree

6 files changed

+278
-468
lines changed

6 files changed

+278
-468
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 43 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -40,54 +40,38 @@ class HandlerUTT:
4040
with unknown treatment intervention times.
4141
"""
4242

43-
def data_preprocessing(self, data, treatment_time, formula, model):
43+
def data_preprocessing(self, data, treatment_time, model):
4444
"""
45-
Preprocess the data using patsy for fitting into the model and update the model with the required information
45+
Preprocess the input data and update the model's treatment time constraints.
4646
"""
47-
y, X = dmatrices(formula, data)
4847
# Restrict model's treatment time inference to given range
4948
model.set_time_range(treatment_time, data)
50-
# Needed to track time evolution across model predictions
51-
model.set_timeline(X.design_info.column_names.index("t"))
52-
return y, X
49+
return data
5350

5451
def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
5552
"""
56-
Postprocess the data according to the inferred treatment time for calculation and plotting purposes
53+
Postprocess data based on the inferred treatment time for further analysis and plotting.
5754
"""
58-
# Retrieve posterior mean of inferred treatment time
55+
# --- Inferred treatment time ---
5956
treatment_time_mean = idata.posterior["treatment_time"].mean().item()
60-
inferred_time = int(treatment_time_mean)
57+
inferred_treatment_time = int(treatment_time_mean)
58+
idx_treatment_time = data.index[data["t"] == inferred_treatment_time][0]
6159

62-
# Safety check: ensure the inferred time is present in the dataset
63-
if inferred_time not in data["t"].values:
64-
raise ValueError(
65-
f"Inferred treatment time {inferred_time} not found in data['t']."
66-
)
67-
68-
# Convert the inferred time to its corresponding DataFrame index
69-
inferred_index = data[data["t"] == inferred_time].index[0]
70-
71-
# Retrieve HDI bounds of treatment time (uncertainty interval)
60+
# --- HDI bounds (credible interval) ---
7261
hdi_bounds = az.hdi(idata, var_names=["treatment_time"])[
7362
"treatment_time"
7463
].values
7564
hdi_start_time = int(hdi_bounds[0])
65+
indice = data.index.get_loc(data.index[data["t"] == hdi_start_time][0])
7666

77-
# Convert HDI lower bound to DataFrame index for slicing
78-
if hdi_start_time not in data["t"].values:
79-
raise ValueError(f"HDI start time {hdi_start_time} not found in data['t'].")
80-
81-
hdi_start_idx_df = data[data["t"] == hdi_start_time].index[0]
82-
hdi_start_idx_np = data.index.get_loc(hdi_start_idx_df)
67+
# --- Slicing ---
68+
datapre = data[data["t"] < hdi_start_time]
69+
datapost = data[data["t"] >= hdi_start_time]
8370

84-
# Slice both pandas and numpy objects accordingly
85-
df_pre = data[data.index < hdi_start_idx_df]
86-
df_post = data[data.index >= hdi_start_idx_df]
87-
truncated_y = pre_y[:hdi_start_idx_np]
88-
truncated_X = pre_X[:hdi_start_idx_np]
71+
truncated_y = pre_y.isel(obs_ind=slice(0, indice))
72+
truncated_X = pre_X.isel(obs_ind=slice(0, indice))
8973

90-
return df_pre, df_post, truncated_y, truncated_X, inferred_index
74+
return datapre, datapost, truncated_y, truncated_X, idx_treatment_time
9175

9276
def plot_intervention_line(self, ax, idata, datapost, treatment_time):
9377
"""
@@ -144,16 +128,16 @@ class HandlerKTT:
144128
where the treatment time is known in advance.
145129
"""
146130

147-
def data_preprocessing(self, data, treatment_time, formula, model):
131+
def data_preprocessing(self, data, treatment_time, model):
148132
"""
149-
Preprocess the data using patsy for fitting into the model
133+
Preprocess the data by selecting only the pre-treatment period for model fitting.
150134
"""
151135
# Use only data before treatment for training the model
152-
return dmatrices(formula, data[data.index < treatment_time])
136+
return data[data.index < treatment_time]
153137

154138
def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
155139
"""
156-
Postprocess data by splitting it into pre- and post-intervention periods, using the known treatment time.
140+
Split data into pre- and post-treatment periods using the known treatment time.
157141
"""
158142
return (
159143
data[data.index < treatment_time],
@@ -165,7 +149,7 @@ def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
165149

166150
def plot_intervention_line(self, ax, idata, datapost, treatment_time):
167151
"""
168-
Plot a vertical line at the known treatment time.
152+
Plot a vertical line at the known treatment time on provided axes.
169153
"""
170154
# --- Plot a vertical line at the known treatment time
171155
for i in [0, 1, 2]:
@@ -177,7 +161,7 @@ def plot_treated_counterfactual(
177161
self, sax, handles, labels, datapost, post_pred, post_y
178162
):
179163
"""
180-
Placeholder method to maintain interface compatibility.
164+
Placeholder method to maintain interface compatibility with HandlerUTT.
181165
"""
182166
pass
183167

@@ -236,7 +220,6 @@ def __init__(
236220
# rename the index to "obs_ind"
237221
data.index.name = "obs_ind"
238222
self.input_validation(data, treatment_time, model)
239-
self.treatment_time = treatment_time
240223
# set experiment type - usually done in subclasses
241224
self.expt_type = "Pre-Post Fit"
242225

@@ -249,27 +232,41 @@ def __init__(
249232
else:
250233
self.handler = HandlerKTT()
251234

252-
# set experiment type - usually done in subclasses
253-
self.expt_type = "Pre-Post Fit"
254-
255235
# Preprocessing based on handler type
256-
y, X = self.handler.data_preprocessing(
257-
data, self.treatment_time, formula, self.model
236+
self.datapre = self.handler.data_preprocessing(
237+
data, self.treatment_time, self.model
258238
)
259239

240+
y, X = dmatrices(formula, self.datapre)
260241
# set things up with pre-intervention data
261242
self.outcome_variable_name = y.design_info.column_names[0]
262243
self._y_design_info = y.design_info
263244
self._x_design_info = X.design_info
264245
self.labels = X.design_info.column_names
265246
self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)
266247

248+
# turn into xarray.DataArray's
249+
self.pre_X = xr.DataArray(
250+
self.pre_X,
251+
dims=["obs_ind", "coeffs"],
252+
coords={
253+
"obs_ind": self.datapre.index,
254+
"coeffs": self.labels,
255+
},
256+
)
257+
self.pre_y = xr.DataArray(
258+
self.pre_y[:, 0],
259+
dims=["obs_ind"],
260+
coords={"obs_ind": self.datapre.index},
261+
)
262+
267263
# fit the model to the observed (pre-intervention) data
268264
if isinstance(self.model, PyMCModel):
269-
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.pre_X.shape[0])}
270-
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
265+
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(X.shape[0])}
266+
idata = self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
271267
elif isinstance(self.model, RegressorMixin):
272268
self.model.fit(X=self.pre_X, y=self.pre_y)
269+
idata = None
273270
else:
274271
raise ValueError("Model type not recognized")
275272

@@ -279,7 +276,7 @@ def __init__(
279276
# Postprocessing with handler
280277
self.datapre, self.datapost, self.pre_y, self.pre_X, self.treatment_time = (
281278
self.handler.data_postprocessing(
282-
data, self.idata, treatment_time, self.pre_y, self.pre_X
279+
data, idata, treatment_time, self.pre_y, self.pre_X
283280
)
284281
)
285282

@@ -292,20 +289,6 @@ def __init__(
292289
)
293290
self.post_X = np.asarray(new_x)
294291
self.post_y = np.asarray(new_y)
295-
# turn into xarray.DataArray's
296-
self.pre_X = xr.DataArray(
297-
self.pre_X,
298-
dims=["obs_ind", "coeffs"],
299-
coords={
300-
"obs_ind": self.datapre.index,
301-
"coeffs": self.labels,
302-
},
303-
)
304-
self.pre_y = xr.DataArray(
305-
self.pre_y[:, 0],
306-
dims=["obs_ind"],
307-
coords={"obs_ind": self.datapre.index},
308-
)
309292
self.post_X = xr.DataArray(
310293
self.post_X,
311294
dims=["obs_ind", "coeffs"],

0 commit comments

Comments
 (0)