Skip to content
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
87 commits
Select commit Hold shift + click to select a range
10a017e
New feature : Model to estimate when a intervention had effect
JeanVanDyk May 28, 2025
69d79b3
New feature : Model to estimate when a intervention had effect
JeanVanDyk May 28, 2025
bf4eaaa
Minor fix in docstring
JeanVanDyk May 29, 2025
3420c9a
Minor fix in docstring
JeanVanDyk May 29, 2025
3dc23b3
Minor fix in docstring
JeanVanDyk May 29, 2025
d739b4a
Minor fix in docstring
JeanVanDyk May 29, 2025
d48f0c3
Minor fix in docstring
JeanVanDyk May 29, 2025
14afe09
Minor fix in docstring
JeanVanDyk May 29, 2025
60357a5
Minor fix in docstring
JeanVanDyk May 29, 2025
7f57b13
Minor fix in docstring
JeanVanDyk May 29, 2025
2cb92fc
Minor fix in docstring
JeanVanDyk May 29, 2025
d9c06ac
Minor fix in docstring
JeanVanDyk May 29, 2025
52cc0fa
Minor fix in docstring
JeanVanDyk May 29, 2025
faf085b
Minor fix in docstring
JeanVanDyk May 29, 2025
cc9a1f4
Minor fix in docstring
JeanVanDyk May 29, 2025
dea9d6e
Minor fix in docstring
JeanVanDyk May 29, 2025
5e9cde6
fix : hiding progressbar
JeanVanDyk May 30, 2025
ee701f2
Enhancement : Adding the possibility for the user to indicate priors …
JeanVanDyk May 30, 2025
5ee3cb4
Minor fix in docstring
JeanVanDyk Jun 4, 2025
08c520c
updating example notebook
JeanVanDyk Jun 4, 2025
b1681da
updating example notebook
JeanVanDyk Jun 4, 2025
fcfd059
Supporting Date format and adding exceptions for model related issues
JeanVanDyk Jun 4, 2025
64c97b7
changing column index restriction to label restriction
JeanVanDyk Jun 5, 2025
2996331
codespell
JeanVanDyk Jun 17, 2025
1da80fd
resolved merge
JeanVanDyk Jun 17, 2025
020f679
fixing merging issues
JeanVanDyk Jun 18, 2025
5039fda
fixing merging issues
JeanVanDyk Jun 18, 2025
4761b7e
codespell
JeanVanDyk Jun 18, 2025
bec5cd8
codespell
JeanVanDyk Jun 18, 2025
2d4d158
updating notebook
JeanVanDyk Jun 19, 2025
8d607b8
updating notebook with examples and adding time_variable_name parameter
JeanVanDyk Jun 20, 2025
d00f828
Merge branch 'main' into pr/480
drbenvincent Jun 20, 2025
942a1d5
fixing example
JeanVanDyk Jun 20, 2025
4aef14b
revert changes in docs and fixing issues
JeanVanDyk Jun 20, 2025
2b2cbdf
Removing the overriding of fit and calculate_impact, adding a test an…
JeanVanDyk Jun 20, 2025
6769aa7
Using all samples for uncertainty
JeanVanDyk Jun 23, 2025
692d85c
uml and docs
JeanVanDyk Jun 24, 2025
cc0979e
Changig Handler's name
JeanVanDyk Jul 8, 2025
72ebb8d
Adding the abstract class
JeanVanDyk Jul 8, 2025
d0f4a58
Updating the notebook
JeanVanDyk Jul 8, 2025
411aac7
Updating treatment type effect input
JeanVanDyk Jul 8, 2025
59db689
Updating integration test
JeanVanDyk Jul 8, 2025
4a10196
updating doctest
JeanVanDyk Jul 9, 2025
6de9707
Removing time variable
JeanVanDyk Jul 9, 2025
db8051f
Updating integration test
JeanVanDyk Jul 9, 2025
101d62a
Updating integration test to meet coverage reco
JeanVanDyk Jul 9, 2025
950623c
Updating notebook
JeanVanDyk Jul 9, 2025
5c4eb13
Updating integration test
JeanVanDyk Jul 9, 2025
6b1552a
typo
JeanVanDyk Jul 9, 2025
3ec5d69
updating notebook
JeanVanDyk Sep 23, 2025
ead382a
resolving conflicts
JeanVanDyk Sep 24, 2025
fd78418
resolving conflicts
JeanVanDyk Sep 24, 2025
1e5670b
resolving conflicts
JeanVanDyk Sep 24, 2025
0b669c9
resolving conflicts
JeanVanDyk Sep 24, 2025
3a693a1
resolving conflicts
JeanVanDyk Sep 24, 2025
42a7d1b
resolving conflicts
JeanVanDyk Sep 24, 2025
70c3426
resolving conflicts
JeanVanDyk Sep 24, 2025
e7b089a
resolving conflicts
JeanVanDyk Sep 24, 2025
eef8acb
resolving conflicts
JeanVanDyk Sep 24, 2025
9370efe
resolving conflicts
JeanVanDyk Sep 24, 2025
0c55851
resolving conflicts
JeanVanDyk Sep 24, 2025
f1a6622
resolving conflicts
JeanVanDyk Sep 24, 2025
d681a43
resolving conflicts
JeanVanDyk Sep 24, 2025
8b93362
resolving conflicts
JeanVanDyk Sep 24, 2025
67d696e
resolving conflicts
JeanVanDyk Sep 24, 2025
082d4d5
resolving conflicts
JeanVanDyk Sep 24, 2025
0a1b01e
resolving conflicts
JeanVanDyk Sep 24, 2025
c1fd388
resolving conflicts
JeanVanDyk Sep 24, 2025
d009b15
resolving conflicts
JeanVanDyk Sep 24, 2025
701fe13
resolving conflicts
JeanVanDyk Sep 24, 2025
46b453f
resolving conflicts
JeanVanDyk Sep 24, 2025
5318add
resolving conflicts
JeanVanDyk Sep 24, 2025
47cf44e
resolving conflicts
JeanVanDyk Sep 24, 2025
9bbc4cb
resolving conflicts
JeanVanDyk Sep 24, 2025
cf2a6f7
resolving conflicts
JeanVanDyk Sep 24, 2025
e5ee32c
resolve conflicts
JeanVanDyk Sep 24, 2025
7da8f91
removing errors in the notebook
JeanVanDyk Sep 26, 2025
c535061
Adding plot forest and comments to better compare models with and wit…
JeanVanDyk Sep 26, 2025
812cd4b
Adding mathjax formulas for examples
JeanVanDyk Sep 26, 2025
704066d
Typo
JeanVanDyk Sep 26, 2025
08463b5
Changing variables name in preprocessing
JeanVanDyk Sep 29, 2025
837e670
improving docstring
JeanVanDyk Sep 29, 2025
049c21b
adding references
JeanVanDyk Sep 29, 2025
cd51404
refining notebook
JeanVanDyk Sep 29, 2025
07ce191
adding ; after plots
JeanVanDyk Oct 1, 2025
7315283
small docstring update
drbenvincent Oct 3, 2025
bae5a12
add developer focussed module level docstring for ITS
drbenvincent Oct 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions causalpy/pymc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,3 +497,147 @@
)
)
return self.idata


class InterventionTimeEstimator(PyMCModel):
r"""
Custom PyMC model to estimate the time an intervetnion took place.

defines the PyMC model :

.. math::
\alpha &\sim \mathrm{Normal}(0, 1) \\
\beta &\sim \mathrm{Normal}(0, 1) \\
s(t) &= \gamma_{i(t)} \quad \textrm{with} \quad \gamma_{k \in [0, ..., n_{seasons}-1]} \sim \mathrm{Normal}(0, 1)\\
base_{\mu}(t) &= \alpha + \beta \cdot t + s_t\\
\\
\tau &\sim \mathrm{Uniform}(0, 1) \\
w(t) &= sigmoid(t-\tau) \\
\\
level &\sim \mathrm{Normal}(0, 1) \\
trend &\sim \mathrm{Normal}(0, 1) \\
A &\sim \mathrm{Normal}(0, 1) \\
\lambda &\sim \mathrm{HalfNormal}(0, 1) \\
impulse(t) &= A \cdot exp(-\lambda \cdot |t-\tau|) \\
intervention(t) &= level + trend \cdot (t-\tau) + impulse_t\\
\\
\sigma &\sim \mathrm{Normal}(0, 1) \\
\mu(t) &= base_{\mu}(t) + w(t) \cdot intervention(t) \\
\\
y(t) &\sim \mathrm{Normal}(\mu (t), \sigma)

Example
--------
>>> import causalpy as cp
>>> import numpy as np
>>> from causalpy.pymc_models import InterventionTimeEstimator
>>> df = cp.load_data("its")
>>> y = df["y"].values
>>> t = df["t"].values
>>> coords = {"seasons": range(12)} # The data is monthly
>>> estimator = InterventionTimeEstimator()
>>> # We are trying to capture an impulse in the number of death per month due to Covid.
>>> estimator.fit(
... t,
... y,
... coords,
... priors={"impulse":[]}
... )
Inference data...
"""

def build_model(self, t, y, coords, time_range, grain_season, priors):
"""
Defines the PyMC model

:param t: An array of values representing the time over which y is spread
:param y: An array of values representing our outcome y
:param coords: An optional dictionary with the coordinate names for our instruments.
In particular, used to determine the number of seasons.
:param time_range: An optional tuple providing a specific time_range where the
intervention effect should have taken place.
:param priors: An optional dictionary of priors for the parameters of the
different distributions.
:code:`priors = {"alpha":[0, 5], "beta":[0,2], "level":[5, 5], "impulse":[1, 2 ,3]}`
"""

with self:
self.add_coords(coords)

Check warning on line 565 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L564-L565

Added lines #L564 - L565 were not covered by tests

if time_range is None:
time_range = (t.min(), t.max())

Check warning on line 568 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L567-L568

Added lines #L567 - L568 were not covered by tests

# --- Priors ---
switchpoint = pm.Uniform(

Check warning on line 571 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L571

Added line #L571 was not covered by tests
"switchpoint", lower=time_range[0], upper=time_range[1]
)
alpha = pm.Normal(name="alpha", mu=0, sigma=50)
beta = pm.Normal(name="beta", mu=0, sigma=50)
seasons = 0
if "seasons" in coords and len(coords["seasons"]) > 0:
season_idx = np.arange(len(y)) // grain_season % len(coords["seasons"])
seasons_effect = pm.Normal(

Check warning on line 579 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L574-L579

Added lines #L574 - L579 were not covered by tests
"seasons_effect", mu=0, sigma=50, dims="seasons"
)
seasons = seasons_effect[season_idx]

Check warning on line 582 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L582

Added line #L582 was not covered by tests

# --- Intervention effect ---
level = trend = impulse = 0

Check warning on line 585 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L585

Added line #L585 was not covered by tests

if "level" in priors:
mu, sigma = (

Check warning on line 588 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L587-L588

Added lines #L587 - L588 were not covered by tests
(0, 50)
if len(priors["level"]) != 2
else (priors["level"][0], priors["level"][1])
)
level = pm.Normal(

Check warning on line 593 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L593

Added line #L593 was not covered by tests
"level",
mu=mu,
sigma=sigma,
)
if "trend" in priors:
mu, sigma = (

Check warning on line 599 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L598-L599

Added lines #L598 - L599 were not covered by tests
(0, 50)
if len(priors["trend"]) != 2
else (priors["trend"][0], priors["trend"][1])
)
trend = pm.Normal("trend", mu=mu, sigma=sigma)
if "impulse" in priors:
mu, sigma1, sigma2 = (

Check warning on line 606 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L604-L606

Added lines #L604 - L606 were not covered by tests
(0, 50, 50)
if len(priors["impulse"]) != 3
else (
priors["impulse"][0],
priors["impulse"][1],
priors["impulse"][2],
)
)
impulse_amplitude = pm.Normal("impulse_amplitude", mu=mu, sigma=sigma1)
decay_rate = pm.HalfNormal("decay_rate", sigma=sigma2)
impulse = impulse_amplitude * pm.math.exp(

Check warning on line 617 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L615-L617

Added lines #L615 - L617 were not covered by tests
-decay_rate * abs(t - switchpoint)
)

# --- Parameterization ---
weight = pm.math.sigmoid(t - switchpoint)

Check warning on line 622 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L622

Added line #L622 was not covered by tests
# Compute and store the modelled time series
mu_ts = pm.Deterministic(name="mu_ts", var=alpha + beta * t + seasons)

Check warning on line 624 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L624

Added line #L624 was not covered by tests
# Compute and store the modelled intervention effect
mu_in = pm.Deterministic(

Check warning on line 626 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L626

Added line #L626 was not covered by tests
name="mu_in", var=level + trend * (t - switchpoint) + impulse
)
# Compute and store the the sum of the intervention and the time series
mu = pm.Deterministic("mu", mu_ts + weight * mu_in)
sigma = pm.HalfNormal("sigma", 1)

Check warning on line 631 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L630-L631

Added lines #L630 - L631 were not covered by tests

# --- Likelihood ---
pm.Normal("y_hat", mu=mu, sigma=sigma, observed=y)

Check warning on line 634 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L634

Added line #L634 was not covered by tests

def fit(self, t, y, coords, time_range=None, grain_season=1, priors={}, n=1000):
"""
Draw samples from posterior distribution
"""
self.build_model(t, y, coords, time_range, grain_season, priors)
with self:
self.idata = pm.sample(n, progressbar=False, **self.sample_kwargs)
return self.idata

Check warning on line 643 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L640-L643

Added lines #L640 - L643 were not covered by tests
6 changes: 3 additions & 3 deletions docs/source/_static/interrogate_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.