Skip to content

Commit 8d607b8

Browse files
committed
updating notebook with examples and adding time_variable_name parameter
1 parent 2d4d158 commit 8d607b8

File tree

4 files changed

+564
-149
lines changed

4 files changed

+564
-149
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,40 +48,50 @@ def data_preprocessing(self, data, treatment_time, model):
4848
model.set_time_range(treatment_time, data)
4949
return data
5050

51-
def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
51+
def data_postprocessing(self, model, data, idata, treatment_time, pre_y, pre_X):
5252
"""
5353
Postprocess data based on the inferred treatment time for further analysis and plotting.
5454
"""
55+
# --- Getting the time_variable_name ---
56+
time_variable_name = model.get_time_variable_name()
57+
5558
# --- Inferred treatment time ---
5659
treatment_time_mean = idata.posterior["treatment_time"].mean().item()
5760
inferred_treatment_time = int(treatment_time_mean)
58-
idx_treatment_time = data.index[data["t"] == inferred_treatment_time][0]
61+
idx_treatment_time = data.index[
62+
data[time_variable_name] == inferred_treatment_time
63+
][0]
5964

6065
# --- HDI bounds (credible interval) ---
6166
hdi_bounds = az.hdi(idata, var_names=["treatment_time"])[
6267
"treatment_time"
6368
].values
6469
hdi_start_time = int(hdi_bounds[0])
65-
indice = data.index.get_loc(data.index[data["t"] == hdi_start_time][0])
70+
indice = data.index.get_loc(
71+
data.index[data[time_variable_name] == hdi_start_time][0]
72+
)
6673

6774
# --- Slicing ---
68-
datapre = data[data["t"] < hdi_start_time]
69-
datapost = data[data["t"] >= hdi_start_time]
75+
datapre = data[data[time_variable_name] < hdi_start_time]
76+
datapost = data[data[time_variable_name] >= hdi_start_time]
7077

7178
truncated_y = pre_y.isel(obs_ind=slice(0, indice))
7279
truncated_X = pre_X.isel(obs_ind=slice(0, indice))
7380

7481
return datapre, datapost, truncated_y, truncated_X, idx_treatment_time
7582

76-
def plot_intervention_line(self, ax, idata, datapost, treatment_time):
83+
def plot_intervention_line(self, ax, model, idata, datapost, treatment_time):
7784
"""
7885
Plot a vertical line at the inferred treatment time, along with a shaded area
7986
representing the Highest Density Interval (HDI) of the inferred time.
8087
"""
88+
# --- Getting the time_variable_name ---
89+
time_variable_name = model.get_time_variable_name()
90+
8191
# Extract the HDI (uncertainty interval) of the treatment time
8292
hdi = az.hdi(idata, var_names=["treatment_time"])["treatment_time"].values
83-
x1 = datapost.index[datapost["t"] == int(hdi[0])][0]
84-
x2 = datapost.index[datapost["t"] == int(hdi[1])][0]
93+
x1 = datapost.index[datapost[time_variable_name] == int(hdi[0])][0]
94+
x2 = datapost.index[datapost[time_variable_name] == int(hdi[1])][0]
8595

8696
for i in [0, 1, 2]:
8797
ymin, ymax = ax[i].get_ylim()
@@ -119,7 +129,7 @@ def plot_treated_counterfactual(
119129
plot_hdi_kwargs={"color": "yellowgreen"},
120130
)
121131
handles.append((h_line, h_patch))
122-
labels.append("treated counterfactual")
132+
labels.append("Treated counterfactual")
123133

124134

125135
class HandlerKTT:
@@ -135,7 +145,7 @@ def data_preprocessing(self, data, treatment_time, model):
135145
# Use only data before treatment for training the model
136146
return data[data.index < treatment_time]
137147

138-
def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
148+
def data_postprocessing(self, model, data, idata, treatment_time, pre_y, pre_X):
139149
"""
140150
Split data into pre- and post-treatment periods using the known treatment time.
141151
"""
@@ -147,7 +157,7 @@ def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
147157
treatment_time,
148158
)
149159

150-
def plot_intervention_line(self, ax, idata, datapost, treatment_time):
160+
def plot_intervention_line(self, model, ax, idata, datapost, treatment_time):
151161
"""
152162
Plot a vertical line at the known treatment time on provided axes.
153163
"""
@@ -276,7 +286,7 @@ def __init__(
276286
# Postprocessing with handler
277287
self.datapre, self.datapost, self.pre_y, self.pre_X, self.treatment_time = (
278288
self.handler.data_postprocessing(
279-
data, idata, treatment_time, self.pre_y, self.pre_X
289+
self.model, data, idata, treatment_time, self.pre_y, self.pre_X
280290
)
281291
)
282292

@@ -443,7 +453,7 @@ def _bayesian_plot(
443453

444454
# Plot vertical line marking treatment time (with HDI if it's inferred)
445455
self.handler.plot_intervention_line(
446-
ax, self.idata, self.datapost, self.treatment_time
456+
ax, self.model, self.idata, self.datapost, self.treatment_time
447457
)
448458

449459
ax[0].legend(

causalpy/pymc_models.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -566,10 +566,13 @@ class InterventionTimeEstimator(PyMCModel):
566566
Inference ...
567567
"""
568568

569-
def __init__(self, treatment_type_effect=None, sample_kwargs=None):
569+
def __init__(
570+
self, time_variable_name: str, treatment_type_effect=None, sample_kwargs=None
571+
):
570572
"""
571573
Initializes the InterventionTimeEstimator model.
572574
575+
:param time_variable_name: name of the column representing time among the covariates
573576
:param treatment_type_effect: Optional dictionary that specifies prior parameters for the
574577
intervention effects. Expected keys are:
575578
- "level": [mu, sigma]
@@ -579,9 +582,10 @@ def __init__(self, treatment_type_effect=None, sample_kwargs=None):
579582
If the associated list is incomplete, default values will be used.
580583
:param sample_kwargs: Optional dictionary of arguments passed to pm.sample().
581584
"""
585+
self.time_variable_name = time_variable_name
586+
582587
if treatment_type_effect is None:
583588
treatment_type_effect = {}
584-
585589
if not isinstance(treatment_type_effect, dict):
586590
raise TypeError("treatment_type_effect must be a dictionary.")
587591

@@ -609,7 +613,7 @@ def build_model(self, X, y, coords):
609613
with self:
610614
self.add_coords(coords)
611615

612-
t = pm.Data("t", X.sel(coeffs="t"), dims="obs_ind")
616+
t = pm.Data("t", X.sel(coeffs=self.time_variable_name), dims="obs_ind")
613617
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
614618
y = pm.Data("y", y, dims="obs_ind")
615619

@@ -751,7 +755,11 @@ def _data_setter(self, X) -> None:
751755
new_no_of_observations = X.shape[0]
752756
with self:
753757
pm.set_data(
754-
{"X": X, "t": X.sel(coeffs="t"), "y": np.zeros(new_no_of_observations)},
758+
{
759+
"X": X,
760+
"t": X.sel(coeffs=self.time_variable_name),
761+
"y": np.zeros(new_no_of_observations),
762+
},
755763
coords={"obs_ind": np.arange(new_no_of_observations)},
756764
)
757765

@@ -771,12 +779,18 @@ def set_time_range(self, time_range, data):
771779
:param time_range: tuple or None
772780
If not None, a tuple of two values (start_label, end_label) that correspond
773781
to index labels in the 't' column of the `data` DataFrame
774-
:param data: pandas.DataFrame which contains a column "t".
782+
:param data: pandas.DataFrame.
775783
"""
776784
if time_range is None:
777-
self.time_range = data["t"].min(), data["t"].max()
785+
self.time_range = (
786+
data[self.time_variable_name].min(),
787+
data[self.time_variable_name].max(),
788+
)
778789
else:
779790
self.time_range = (
780-
data["t"].loc[time_range[0]],
781-
data["t"].loc[time_range[1]],
791+
data[self.time_variable_name].loc[time_range[0]],
792+
data[self.time_variable_name].loc[time_range[1]],
782793
)
794+
795+
def get_time_variable_name(self):
796+
return self.time_variable_name

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

docs/source/notebooks/its_no_treatment_time.ipynb

Lines changed: 516 additions & 125 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)