Skip to content

Commit fcfd059

Browse files
committed
Supporting Date format and adding exceptions for model related issues
1 parent b1681da commit fcfd059

File tree

3 files changed

+20
-21
lines changed

3 files changed

+20
-21
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def __init__(
9292
# Set the data according to if the model is
9393
if treatment_time is None or isinstance(treatment_time, tuple):
9494
self.datapre = data
95-
self.model.set_time_range(self.treatment_time)
95+
self.model.set_time_range(self.treatment_time, self.datapre)
9696
else:
9797
# split data in to pre and post intervention
9898
self.datapre = data[data.index < self.treatment_time]
@@ -120,11 +120,18 @@ def __init__(
120120
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
121121

122122
if treatment_time is None or isinstance(treatment_time, tuple):
123-
self.treatment_time = int(
123+
# We're getting the inferred switchpoint as one of the values of the timeline, from the last column
124+
switchpoint = int(
124125
az.extract(idata, group="posterior", var_names="switchpoint")
125126
.mean("sample")
126127
.values
127128
)
129+
130+
# we're getting the associated index of that switchpoint
131+
last_column = data.columns[-1]
132+
self.treatment_time = data[data[last_column] == switchpoint].index[0]
133+
134+
# We're getting datapre as intended for prediction
128135
self.datapre = data[data.index < self.treatment_time]
129136
(new_y, new_x) = build_design_matrices(
130137
[self._y_design_info, self._x_design_info], self.datapre
@@ -155,22 +162,20 @@ def __init__(
155162

156163
def input_validation(self, data, treatment_time, model):
157164
"""Validate the input data and model formula for correctness"""
158-
if treatment_time is None and not hasattr(model, "set_time_range"):
159-
raise ModelException(
160-
"If treatment_time is None, provided model must have a 'set_time_range' method"
161-
)
162-
elif isinstance(treatment_time, tuple) and not hasattr(model, "set_time_range"):
165+
if isinstance(treatment_time, (type(None), tuple)) and not hasattr(
166+
model, "set_time_range"
167+
):
163168
raise ModelException(
164-
"If treatment_time is a tuple, provided model must have a 'set_time_range' method"
169+
"If treatment_time is None or a tuple, provided model must have a 'set_time_range' method"
165170
)
166-
elif isinstance(data.index, pd.DatetimeIndex) and not isinstance(
167-
treatment_time, pd.Timestamp
171+
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
172+
treatment_time, (pd.Timestamp, tuple, type(None))
168173
):
169174
raise BadIndexException(
170175
"If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
171176
)
172-
elif not isinstance(data.index, pd.DatetimeIndex) and isinstance(
173-
treatment_time, pd.Timestamp
177+
if not isinstance(data.index, pd.DatetimeIndex) and isinstance(
178+
treatment_time, (pd.Timestamp)
174179
):
175180
raise BadIndexException(
176181
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501

causalpy/pymc_models.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -693,9 +693,3 @@ def _data_setter(self, X, t) -> None:
693693
{"X": X, "t": t, "y": np.zeros(new_no_of_observations)},
694694
coords={"obs_ind": np.arange(new_no_of_observations)},
695695
)
696-
697-
def set_time_range(self, time_range):
698-
"""
699-
Set time_range.
700-
"""
701-
self.time_range = time_range

docs/source/notebooks/its_no_treatment_time.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@
176176
{
177177
"data": {
178178
"text/plain": [
179-
"[<matplotlib.lines.Line2D at 0x1ec1d5ed810>]"
179+
"[<matplotlib.lines.Line2D at 0x1cc0da5d810>]"
180180
]
181181
},
182182
"execution_count": 4,
@@ -264,7 +264,7 @@
264264
{
265265
"data": {
266266
"application/vnd.jupyter.widget-view+json": {
267-
"model_id": "5947d54eeaaa430bac3bf1ab430e3837",
267+
"model_id": "cf378625c7a04cd183378bed51e5e63a",
268268
"version_major": 2,
269269
"version_minor": 0
270270
},
@@ -289,7 +289,7 @@
289289
"name": "stderr",
290290
"output_type": "stream",
291291
"text": [
292-
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 59 seconds.\n",
292+
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 49 seconds.\n",
293293
"Sampling: [beta, decay_rate, impulse_amplitude, sigma, switchpoint, y_hat, y_ts]\n",
294294
"Sampling: [y_ts]\n",
295295
"Sampling: [y_hat, y_ts]\n",

0 commit comments

Comments
 (0)