Skip to content

Commit 5ee3cb4

Browse files
committed
Minor fix in docstring
1 parent ee701f2 commit 5ee3cb4

File tree

5 files changed

+751
-68
lines changed

5 files changed

+751
-68
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,11 @@
2525
from sklearn.base import RegressorMixin
2626

2727
from causalpy.custom_exceptions import BadIndexException
28+
from causalpy.experiments.base import BaseExperiment
2829
from causalpy.plot_utils import get_hdi_to_df, plot_xY
2930
from causalpy.pymc_models import PyMCModel
3031
from causalpy.utils import round_num
3132

32-
from .base import BaseExperiment
33-
3433
LEGEND_FONT_SIZE = 12
3534

3635

@@ -78,19 +77,27 @@ class InterruptedTimeSeries(BaseExperiment):
7877
def __init__(
7978
self,
8079
data: pd.DataFrame,
81-
treatment_time: Union[int, float, pd.Timestamp],
80+
treatment_time: Union[int, float, pd.Timestamp, tuple, None],
8281
formula: str,
8382
model=None,
8483
**kwargs,
8584
) -> None:
8685
super().__init__(model=model)
87-
self.input_validation(data, treatment_time)
86+
# Input validation. TODO: for the moment, only valid for a given treatment time
87+
if treatment_time is not None or not isinstance(treatment_time, tuple):
88+
self.input_validation(data, treatment_time)
89+
8890
self.treatment_time = treatment_time
8991
# set experiment type - usually done in subclasses
9092
self.expt_type = "Pre-Post Fit"
91-
# split data in to pre and post intervention
92-
self.datapre = data[data.index < self.treatment_time]
93-
self.datapost = data[data.index >= self.treatment_time]
93+
94+
# Set the data according to whether the treatment time is given or must be inferred by the model
95+
if treatment_time is None or isinstance(treatment_time, tuple):
96+
self.datapre = data
97+
self.model.set_time_range(self.treatment_time)
98+
else:
99+
# split data into pre and post intervention
100+
self.datapre = data[data.index < self.treatment_time]
94101

95102
self.formula = formula
96103

@@ -101,17 +108,11 @@ def __init__(
101108
self._x_design_info = X.design_info
102109
self.labels = X.design_info.column_names
103110
self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)
104-
# process post-intervention data
105-
(new_y, new_x) = build_design_matrices(
106-
[self._y_design_info, self._x_design_info], self.datapost
107-
)
108-
self.post_X = np.asarray(new_x)
109-
self.post_y = np.asarray(new_y)
110111

111112
# fit the model to the observed (pre-intervention) data
112113
if isinstance(self.model, PyMCModel):
113114
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.pre_X.shape[0])}
114-
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
115+
idata = self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
115116
elif isinstance(self.model, RegressorMixin):
116117
self.model.fit(X=self.pre_X, y=self.pre_y)
117118
else:
@@ -120,8 +121,29 @@ def __init__(
120121
# score the goodness of fit to the pre-intervention data
121122
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
122123

124+
if treatment_time is None or isinstance(treatment_time, tuple):
125+
self.treatment_time = int(
126+
az.extract(idata, group="posterior", var_names="switchpoint")
127+
.mean("sample")
128+
.values
129+
)
130+
self.datapre = data[data.index < self.treatment_time]
131+
(new_y, new_x) = build_design_matrices(
132+
[self._y_design_info, self._x_design_info], self.datapre
133+
)
134+
self.pre_X = np.asarray(new_x)
135+
self.pre_y = np.asarray(new_y)
136+
123137
# get the model predictions of the observed (pre-intervention) data
124138
self.pre_pred = self.model.predict(X=self.pre_X)
139+
# process post-intervention data
140+
self.datapost = data[data.index >= self.treatment_time]
141+
142+
(new_y, new_x) = build_design_matrices(
143+
[self._y_design_info, self._x_design_info], self.datapost
144+
)
145+
self.post_X = np.asarray(new_x)
146+
self.post_y = np.asarray(new_y)
125147

126148
# calculate the counterfactual
127149
self.post_pred = self.model.predict(X=self.post_X)

causalpy/pymc_models.py

Lines changed: 109 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -530,26 +530,30 @@ class InterventionTimeEstimator(PyMCModel):
530530
--------
531531
>>> import causalpy as cp
532532
>>> import numpy as np
533-
>>> from causalpy.pymc_models import InterventionTimeEstimator
534-
>>> df = cp.load_data("its")
535-
>>> y = df["y"].values
536-
>>> t = df["t"].values
537-
>>> coords = {"seasons": range(12)} # The data is monthly
538-
>>> estimator = InterventionTimeEstimator()
539-
>>> # We are trying to capture an impulse in the number of death per month due to Covid.
540-
>>> estimator.fit(
541-
... t,
542-
... y,
543-
... coords,
544-
... priors={"impulse":[]}
545-
... )
546-
Inference data...
533+
>>> from patsy import build_design_matrices, dmatrices
534+
>>> from causalpy.pymc_models import InterventionTimeEstimator as ITE
535+
>>> data = cp.load_data("its")
536+
>>> formula="y ~ 1 + t + C(month)"
537+
>>> y, X = dmatrices(formula, data)
538+
>>> outcome_variable_name = y.design_info.column_names[0]
539+
>>> labels = X.design_info.column_names
540+
>>> _y, _X = np.asarray(y), np.asarray(X)
541+
>>> COORDS = {"coeffs":labels, "obs_ind": np.arange(_X.shape[0])}
542+
>>> model = ITE(sample_kwargs={"draws" : 10, "tune":10, "progressbar":False}) # For a quick overview. Remove sample_kwargs parameter for better performance
543+
>>> model.set_time_range(None)
544+
>>> model.fit(X=_X, y=_y, coords=COORDS)
545+
Inference ...
547546
"""
548547

549-
def build_model(self, t, y, coords, time_range, grain_season, priors):
548+
def __init__(self, priors={}, sample_kwargs=None):
549+
super().__init__(sample_kwargs)
550+
self.priors = priors
551+
552+
def build_model(self, X, t, y, coords):
550553
"""
551554
Defines the PyMC model
552555
556+
:param X: An array (design matrix) of the covariates
553557
:param t: An array of values representing the time over which y is spread
554558
:param y: An array of values representing our outcome y
555559
:param coords: An optional dictionary with the coordinate names for our instruments.
@@ -564,80 +568,134 @@ def build_model(self, t, y, coords, time_range, grain_season, priors):
564568
with self:
565569
self.add_coords(coords)
566570

567-
if time_range is None:
568-
time_range = (t.min(), t.max())
569-
571+
t = pm.Data("t", t, dims="obs_ind")
572+
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
573+
y = pm.Data("y", y[:, 0], dims="obs_ind")
574+
lower_bound = pm.Data("lower_bound", self.time_range[0])
575+
upper_bound = pm.Data("upper_bound", self.time_range[1])
570576
# --- Priors ---
571577
switchpoint = pm.Uniform(
572-
"switchpoint", lower=time_range[0], upper=time_range[1]
578+
"switchpoint", lower=lower_bound, upper=upper_bound
573579
)
574-
alpha = pm.Normal(name="alpha", mu=0, sigma=50)
575-
beta = pm.Normal(name="beta", mu=0, sigma=50)
576-
seasons = 0
577-
if "seasons" in coords and len(coords["seasons"]) > 0:
578-
season_idx = np.arange(len(y)) // grain_season % len(coords["seasons"])
579-
seasons_effect = pm.Normal(
580-
"seasons_effect", mu=0, sigma=50, dims="seasons"
581-
)
582-
seasons = seasons_effect[season_idx]
580+
beta = pm.Normal(name="beta", mu=0, sigma=50, dims="coeffs")
583581

584582
# --- Intervention effect ---
585583
level = trend = impulse = 0
586584

587-
if "level" in priors:
585+
if "level" in self.priors:
588586
mu, sigma = (
589587
(0, 50)
590-
if len(priors["level"]) != 2
591-
else (priors["level"][0], priors["level"][1])
588+
if len(self.priors["level"]) != 2
589+
else (self.priors["level"][0], self.priors["level"][1])
592590
)
593591
level = pm.Normal(
594592
"level",
595593
mu=mu,
596594
sigma=sigma,
597595
)
598-
if "trend" in priors:
596+
if "trend" in self.priors:
599597
mu, sigma = (
600598
(0, 50)
601-
if len(priors["trend"]) != 2
602-
else (priors["trend"][0], priors["trend"][1])
599+
if len(self.priors["trend"]) != 2
600+
else (self.priors["trend"][0], self.priors["trend"][1])
603601
)
604602
trend = pm.Normal("trend", mu=mu, sigma=sigma)
605-
if "impulse" in priors:
603+
if "impulse" in self.priors:
606604
mu, sigma1, sigma2 = (
607605
(0, 50, 50)
608-
if len(priors["impulse"]) != 3
606+
if len(self.priors["impulse"]) != 3
609607
else (
610-
priors["impulse"][0],
611-
priors["impulse"][1],
612-
priors["impulse"][2],
608+
self.priors["impulse"][0],
609+
self.priors["impulse"][1],
610+
self.priors["impulse"][2],
613611
)
614612
)
615613
impulse_amplitude = pm.Normal("impulse_amplitude", mu=mu, sigma=sigma1)
616614
decay_rate = pm.HalfNormal("decay_rate", sigma=sigma2)
617-
impulse = impulse_amplitude * pm.math.exp(
618-
-decay_rate * abs(t - switchpoint)
615+
impulse = pm.Deterministic(
616+
"impulse",
617+
impulse_amplitude
618+
* pm.math.exp(-decay_rate * pm.math.abs(t - switchpoint)),
619619
)
620620

621621
# --- Parameterization ---
622622
weight = pm.math.sigmoid(t - switchpoint)
623-
# Compute and store the modelled time series
624-
mu_ts = pm.Deterministic(name="mu_ts", var=alpha + beta * t + seasons)
623+
# Compute and store the base time series
624+
mu = pm.Deterministic(name="mu", var=pm.math.dot(X, beta))
625625
# Compute and store the modelled intervention effect
626626
mu_in = pm.Deterministic(
627627
name="mu_in", var=level + trend * (t - switchpoint) + impulse
628628
)
629-
# Compute and store the the sum of the intervention and the time series
630-
mu = pm.Deterministic("mu", mu_ts + weight * mu_in)
629+
# Compute and store the sum of the base time series and the intervention's effect
630+
mu_ts = pm.Deterministic("mu_ts", mu + weight * mu_in)
631631
sigma = pm.HalfNormal("sigma", 1)
632632

633633
# --- Likelihood ---
634-
pm.Normal("y_hat", mu=mu, sigma=sigma, observed=y)
634+
# Likelihood of the base time series
635+
pm.Normal("y_hat", mu=mu, sigma=sigma, dims="obs_ind")
636+
# Likelihood of the base time series and the intervention's effect
637+
pm.Normal("y_ts", mu=mu_ts, sigma=sigma, observed=y, dims="obs_ind")
635638

636-
def fit(self, t, y, coords, time_range=None, grain_season=1, priors={}, n=1000):
637-
"""
638-
Draw samples from posterior distribution
639+
def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
640+
"""Draw samples from posterior, prior predictive, and posterior predictive
641+
distributions, placing them in the model's idata attribute.
639642
"""
640-
self.build_model(t, y, coords, time_range, grain_season, priors)
643+
644+
# Ensure random_seed is used in sample_prior_predictive() and
645+
# sample_posterior_predictive() if provided in sample_kwargs.
646+
random_seed = self.sample_kwargs.get("random_seed", None)
647+
t = X[:, -1]
648+
if self.time_range is None:
649+
self.time_range = (t.min(), t.max())
650+
self.build_model(X, t, y, coords)
641651
with self:
642-
self.idata = pm.sample(n, progressbar=False, **self.sample_kwargs)
652+
self.idata = pm.sample(max_treedepth=15, **self.sample_kwargs)
653+
self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
654+
self.idata.extend(
655+
pm.sample_posterior_predictive(
656+
self.idata, progressbar=False, random_seed=random_seed
657+
)
658+
)
643659
return self.idata
660+
661+
def predict(self, X):
662+
"""
663+
Predict data given input data `X`
664+
665+
.. caution::
666+
Results in KeyError if model hasn't been fit.
667+
"""
668+
669+
# Ensure random_seed is used in sample_prior_predictive() and
670+
# sample_posterior_predictive() if provided in sample_kwargs.
671+
random_seed = self.sample_kwargs.get("random_seed", None)
672+
t = X[:, -1]
673+
self._data_setter(X, t)
674+
with self: # sample with new input data
675+
post_pred = pm.sample_posterior_predictive(
676+
self.idata,
677+
var_names=["y_hat", "y_ts", "mu", "mu_ts", "mu_in"],
678+
progressbar=False,
679+
random_seed=random_seed,
680+
)
681+
return post_pred
682+
683+
def _data_setter(self, X, t) -> None:
684+
"""
685+
Set data for the model.
686+
687+
This method is used internally to register new data for the model for
688+
prediction.
689+
"""
690+
new_no_of_observations = X.shape[0]
691+
with self:
692+
pm.set_data(
693+
{"X": X, "t": t, "y": np.zeros(new_no_of_observations)},
694+
coords={"obs_ind": np.arange(new_no_of_observations)},
695+
)
696+
697+
def set_time_range(self, time_range):
698+
"""
699+
Set time_range.
700+
"""
701+
self.time_range = time_range

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

docs/source/notebooks/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ did_pymc_banks.ipynb
4040
its_skl.ipynb
4141
its_pymc.ipynb
4242
its_covid.ipynb
43+
its_no_treatment_time.ipynb
4344
:::
4445

4546
:::{toctree}

0 commit comments

Comments
 (0)