Skip to content

Commit 64c97b7

Browse files
committed
changing column index restriction to label restriction
1 parent fcfd059 commit 64c97b7

File tree

3 files changed

+47
-22
lines changed

3 files changed

+47
-22
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,12 @@ def __init__(
8888
self.treatment_time = treatment_time
8989
# set experiment type - usually done in subclasses
9090
self.expt_type = "Pre-Post Fit"
91+
# set if the model is supposed to infer the treatment_time
92+
self.infer_treatment_time = isinstance(self.treatment_time, (type(None), tuple))
9193

92-
# Set the data according to if the model is
93-
if treatment_time is None or isinstance(treatment_time, tuple):
94+
# Set the data according to if the model is fitted on the whole bunch or not
95+
if self.infer_treatment_time:
9496
self.datapre = data
95-
self.model.set_time_range(self.treatment_time, self.datapre)
9697
else:
9798
# split data in to pre and post intervention
9899
self.datapre = data[data.index < self.treatment_time]
@@ -107,6 +108,12 @@ def __init__(
107108
self.labels = X.design_info.column_names
108109
self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)
109110

111+
# Setting the time range in which the model infers treatment_time
112+
# Setting the timeline index so that the model can keep of time track between predicts
113+
if self.infer_treatment_time:
114+
self.model.set_time_range(self.treatment_time, self.datapre)
115+
self.model.set_timeline(self.labels.index("t"))
116+
110117
# fit the model to the observed (pre-intervention) data
111118
if isinstance(self.model, PyMCModel):
112119
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.pre_X.shape[0])}
@@ -119,17 +126,15 @@ def __init__(
119126
# score the goodness of fit to the pre-intervention data
120127
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
121128

122-
if treatment_time is None or isinstance(treatment_time, tuple):
129+
if self.infer_treatment_time:
123130
# We're getting the inferred switchpoint as one of the values of the timeline, from the last column
124131
switchpoint = int(
125132
az.extract(idata, group="posterior", var_names="switchpoint")
126133
.mean("sample")
127134
.values
128135
)
129-
130136
# we're getting the associated index of that switchpoint
131-
last_column = data.columns[-1]
132-
self.treatment_time = data[data[last_column] == switchpoint].index[0]
137+
self.treatment_time = data[data["t"] == switchpoint].index[0]
133138

134139
# We're getting datapre as intended for prediction
135140
self.datapre = data[data.index < self.treatment_time]
@@ -162,11 +167,13 @@ def __init__(
162167

163168
def input_validation(self, data, treatment_time, model):
164169
"""Validate the input data and model formula for correctness"""
165-
if isinstance(treatment_time, (type(None), tuple)) and not hasattr(
166-
model, "set_time_range"
167-
):
170+
if treatment_time is None and not hasattr(model, "set_time_range"):
171+
raise ModelException(
172+
"If treatment_time is None, provided model must have a 'set_time_range' method"
173+
)
174+
if isinstance(treatment_time, tuple) and not hasattr(model, "set_time_range"):
168175
raise ModelException(
169-
"If treatment_time is None or a tuple, provided model must have a 'set_time_range' method"
176+
"If treatment_time is a tuple, provided model must have a 'set_time_range' method"
170177
)
171178
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
172179
treatment_time, (pd.Timestamp, tuple, type(None))

causalpy/pymc_models.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,9 @@ class InterventionTimeEstimator(PyMCModel):
539539
>>> labels = X.design_info.column_names
540540
>>> _y, _X = np.asarray(y), np.asarray(X)
541541
>>> COORDS = {"coeffs":labels, "obs_ind": np.arange(_X.shape[0])}
542-
>>> model = ITE(sample_kwargs={"draws" : 10, "tune":10, "progressbar":False}) # For a quick overview. Remove sample_kwargs parameter for better performance
543-
>>> model.set_time_range(None)
542+
>>> model = ITE(sample_kwargs={"draws" : 10, "tune":10, "progressbar":False})
543+
>>> model.set_time_range(None, data)
544+
>>> model.set_timeline(-1)
544545
>>> model.fit(X=_X, y=_y, coords=COORDS)
545546
Inference ...
546547
"""
@@ -549,7 +550,7 @@ def __init__(self, priors={}, sample_kwargs=None):
549550
super().__init__(sample_kwargs)
550551
self.priors = priors
551552

552-
def build_model(self, X, t, y, coords):
553+
def build_model(self, X, y, coords):
553554
"""
554555
Defines the PyMC model
555556
@@ -568,7 +569,7 @@ def build_model(self, X, t, y, coords):
568569
with self:
569570
self.add_coords(coords)
570571

571-
t = pm.Data("t", t, dims="obs_ind")
572+
t = pm.Data("t", X[:, self.timeline], dims="obs_ind")
572573
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
573574
y = pm.Data("y", y[:, 0], dims="obs_ind")
574575
lower_bound = pm.Data("lower_bound", self.time_range[0])
@@ -644,10 +645,9 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
644645
# Ensure random_seed is used in sample_prior_predictive() and
645646
# sample_posterior_predictive() if provided in sample_kwargs.
646647
random_seed = self.sample_kwargs.get("random_seed", None)
647-
t = X[:, -1]
648648
if self.time_range is None:
649-
self.time_range = (t.min(), t.max())
650-
self.build_model(X, t, y, coords)
649+
self.time_range = (X[:, self.timeline].min(), X[:, self.timeline].max())
650+
self.build_model(X, y, coords)
651651
with self:
652652
self.idata = pm.sample(max_treedepth=15, **self.sample_kwargs)
653653
self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
@@ -669,7 +669,7 @@ def predict(self, X):
669669
# Ensure random_seed is used in sample_prior_predictive() and
670670
# sample_posterior_predictive() if provided in sample_kwargs.
671671
random_seed = self.sample_kwargs.get("random_seed", None)
672-
t = X[:, -1]
672+
t = X[:, self.timeline]
673673
self._data_setter(X, t)
674674
with self: # sample with new input data
675675
post_pred = pm.sample_posterior_predictive(
@@ -693,3 +693,21 @@ def _data_setter(self, X, t) -> None:
693693
{"X": X, "t": t, "y": np.zeros(new_no_of_observations)},
694694
coords={"obs_ind": np.arange(new_no_of_observations)},
695695
)
696+
697+
def set_time_range(self, time_range, data):
698+
"""
699+
Set time_range.
700+
"""
701+
if time_range is None:
702+
self.time_range = time_range
703+
else:
704+
self.time_range = (
705+
data["t"].loc[time_range[0]],
706+
data["t"].loc[time_range[1]],
707+
)
708+
709+
def set_timeline(self, index):
710+
"""
711+
Set the index of the timeline in the given covariates
712+
"""
713+
self.timeline = index

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)