Skip to content

Commit 6de9707

Browse files
committed
Removing time variable
1 parent 4a10196 commit 6de9707

File tree

2 files changed

+16
-21
lines changed

2 files changed

+16
-21
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,11 @@ def data_postprocessing(
9393
# --- Return ---
9494
res = {}
9595

96-
# --- Retrieve timeline and inferred treatment time ---
97-
time_var = model.time_variable_name
98-
timeline = data[time_var]
99-
10096
tt_samples = idata.posterior["treatment_time"].values
10197
tt_mean = int(tt_samples.mean().item())
10298

10399
# Actual timestamp (index) corresponding to inferred treatment
104-
tt = data[timeline == tt_mean].index[0]
100+
tt = data.index[tt_mean]
105101
# Index of the inferred treatment time in the data
106102
tt_idx = data.index.get_loc(tt)
107103
res["treatment_time"] = tt
@@ -128,9 +124,13 @@ def data_postprocessing(
128124

129125
# --- Create a mask to isolate post-treatment period ---
130126
# Timeline reshaped to match broadcasting with treatment time
131-
timeline_reshape = timeline.values.reshape(1, 1, -1)
127+
timeline = [
128+
[[i for i in range(len(data))] for _ in range(len(tt_samples[0]))]
129+
for _ in range(len(tt_samples))
130+
]
131+
timeline_broadcast = np.array(timeline)
132132
tt_broadcast = tt_samples[:, :, None].astype(int)
133-
mask = (timeline_reshape >= tt_broadcast).astype(int)
133+
mask = (timeline_broadcast >= tt_broadcast).astype(int)
134134

135135
# --- Compute cumulative post-treatment impact ---
136136
post_impact = impact * mask
@@ -184,12 +184,10 @@ def plot_intervention_line(
184184
Draw a vertical line at the inferred treatment time and shade the HDI interval around it.
185185
"""
186186
data = pd.concat([datapre, datapost])
187-
timeline = data[model.time_variable_name]
188-
189187
# Extract the HDI (uncertainty interval) of the treatment time
190188
hdi = az.hdi(idata, var_names=["treatment_time"])["treatment_time"].values
191-
x1 = data[timeline == int(hdi[0])].index[0]
192-
x2 = data[timeline == int(hdi[1])].index[0]
189+
x1 = data.index[int(hdi[0])]
190+
x2 = data.index[int(hdi[1])]
193191

194192
for i in [0, 1, 2]:
195193
ymin, ymax = ax[i].get_ylim()

causalpy/pymc_models.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -560,23 +560,21 @@ class InterventionTimeEstimator(PyMCModel):
560560
... coords={"obs_ind": data.index},
561561
... )
562562
>>> COORDS = {"coeffs":labels, "obs_ind": np.arange(_X.shape[0])}
563-
>>> model = ITE(time_variable_name="t", treatment_effect_type="level", sample_kwargs={"draws" : 10, "tune":10, "progressbar":False})
563+
>>> model = ITE(treatment_effect_type="level", sample_kwargs={"draws" : 10, "tune":10, "progressbar":False})
564564
>>> model.set_time_range(None, data)
565565
>>> model.fit(X=_X, y=_y, coords=COORDS)
566566
Inference ...
567567
"""
568568

569569
def __init__(
570570
self,
571-
time_variable_name: str,
572571
treatment_effect_type: str | list[str],
573572
treatment_effect_param=None,
574573
sample_kwargs=None,
575574
):
576575
"""
577576
Initializes the InterventionTimeEstimator model.
578577
579-
:param time_variable_name: name of the column representing time among the covariates
580578
:param treatment_effect_type: Optional dictionary that specifies prior parameters for the
581579
intervention effects. Expected keys are:
582580
- "level": [mu, sigma]
@@ -586,7 +584,6 @@ def __init__(
586584
If the associated list is incomplete, default values will be used.
587585
:param sample_kwargs: Optional dictionary of arguments passed to pm.sample().
588586
"""
589-
self.time_variable_name = time_variable_name
590587

591588
super().__init__(sample_kwargs)
592589

@@ -657,7 +654,7 @@ def build_model(self, X, y, coords):
657654
with self:
658655
self.add_coords(coords)
659656

660-
t = pm.Data("t", X.sel(coeffs=self.time_variable_name), dims="obs_ind")
657+
t = pm.Data("t", np.arange(len(X)), dims="obs_ind")
661658
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
662659
y = pm.Data("y", y, dims="obs_ind")
663660

@@ -768,7 +765,7 @@ def _data_setter(self, X) -> None:
768765
pm.set_data(
769766
{
770767
"X": X,
771-
"t": X.sel(coeffs=self.time_variable_name),
768+
"t": np.arange(len(X)),
772769
"y": np.zeros(new_no_of_observations),
773770
},
774771
coords={"obs_ind": np.arange(new_no_of_observations)},
@@ -794,11 +791,11 @@ def set_time_range(self, time_range, data):
794791
"""
795792
if time_range is None:
796793
self.time_range = (
797-
data[self.time_variable_name].min(),
798-
data[self.time_variable_name].max(),
794+
0,
795+
len(data),
799796
)
800797
else:
801798
self.time_range = (
802-
data[self.time_variable_name].loc[time_range[0]],
803-
data[self.time_variable_name].loc[time_range[1]],
799+
data.index.get_loc(time_range[0]),
800+
data.index.get_loc(time_range[1]),
804801
)

0 commit comments

Comments
 (0)