Skip to content

Commit b9c3a97

Browse files
committed
Using Priors from Pymc-experimental for the cp_param_effects
1 parent 4211ddc commit b9c3a97

File tree

4 files changed

+175
-134
lines changed

4 files changed

+175
-134
lines changed

causalpy/experiments/change_point_detection.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -312,14 +312,14 @@ def _bayesian_plot(
312312
# Plot predicted values after change point (with HDI)
313313
h_line, h_patch = plot_xY(
314314
self.datapre.index,
315-
self.pre_pred["posterior_predictive"].mu_ts.isel(treated_units=0),
315+
self.pre_pred["posterior_predictive"].mu.isel(treated_units=0),
316316
ax=ax[0],
317317
plot_hdi_kwargs={"color": "yellowgreen"},
318318
)
319319

320320
h_line, h_patch = plot_xY(
321321
self.datapost.index,
322-
self.post_pred["posterior_predictive"].mu_ts.isel(treated_units=0),
322+
self.post_pred["posterior_predictive"].mu.isel(treated_units=0),
323323
ax=ax[0],
324324
plot_hdi_kwargs={"color": "yellowgreen"},
325325
)
@@ -330,7 +330,7 @@ def _bayesian_plot(
330330
# pre-intervention period
331331
h_line, h_patch = plot_xY(
332332
self.datapre.index,
333-
self.pre_pred["posterior_predictive"].mu.isel(treated_units=0),
333+
self.pre_pred["posterior_predictive"].mu_ts.isel(treated_units=0),
334334
ax=ax[0],
335335
plot_hdi_kwargs={"color": "C0"},
336336
)
@@ -351,7 +351,7 @@ def _bayesian_plot(
351351
# post intervention period
352352
h_line, h_patch = plot_xY(
353353
self.datapost.index,
354-
self.post_pred["posterior_predictive"].mu.isel(treated_units=0),
354+
self.post_pred["posterior_predictive"].mu_ts.isel(treated_units=0),
355355
ax=ax[0],
356356
plot_hdi_kwargs={"color": "C1"},
357357
)
@@ -367,7 +367,7 @@ def _bayesian_plot(
367367
)
368368
# Shaded causal effect
369369
post_pred_mu = (
370-
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
370+
az.extract(self.post_pred, group="posterior_predictive", var_names="mu_ts")
371371
.isel(treated_units=0)
372372
.mean("sample")
373373
) # Add .mean("sample") to get 1D array

causalpy/pymc_models.py

Lines changed: 39 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ class initialisation.
890890

891891
class LinearChangePointDetection(PyMCModel):
892892
r"""
893-
Custom PyMC model to estimate the time an intervention took place.
893+
Custom PyMC model to estimate one ChangePoint in time series.
894894
895895
This model implements three types of changepoints: level shift, trend change, and impulse response.
896896
While the underlying mathematical framework could theoretically be applied to other changepoint
@@ -965,6 +965,14 @@ class LinearChangePointDetection(PyMCModel):
965965
Inference ...
966966
"""
967967

968+
default_priors = {
969+
"beta": Prior("Normal", mu=0, sigma=5, dims=["treated_units", "coeffs"]),
970+
"level": Prior("Normal", mu=0, sigma=5),
971+
"trend": Prior("Normal", mu=0, sigma=0.5),
972+
"impulse_amplitude": Prior("Normal", mu=0, sigma=5),
973+
"impulse_decay_rate": Prior("HalfNormal", sigma=5),
974+
}
975+
968976
def __init__(
969977
self,
970978
cp_effect_type: str | list[str],
@@ -984,50 +992,14 @@ def __init__(
984992
:param sample_kwargs: Optional dictionary of arguments passed to pm.sample().
985993
"""
986994

987-
super().__init__(sample_kwargs)
988-
989-
# Hardcoded default priors
990-
self.DEFAULT_BETA_PRIOR = (0, 5)
991-
self.DEFAULT_LEVEL_PRIOR = (0, 5)
992-
self.DEFAULT_TREND_PRIOR = (0, 0.5)
993-
self.DEFAULT_IMPULSE_PRIOR = (0, 5, 5)
995+
super().__init__(sample_kwargs, cp_effect_param)
994996

995997
# Make sure we get a list of all expected effects
996998
if isinstance(cp_effect_type, str):
997999
self.cp_effect_type = [cp_effect_type]
9981000
else:
9991001
self.cp_effect_type = cp_effect_type
10001002

1001-
# Defining the priors here
1002-
self.cp_effect_param = {} if cp_effect_param is None else cp_effect_param
1003-
1004-
if "level" in self.cp_effect_type:
1005-
if (
1006-
"level" not in self.cp_effect_param
1007-
or len(self.cp_effect_param["level"]) != 2
1008-
):
1009-
self.cp_effect_param["level"] = self.DEFAULT_LEVEL_PRIOR
1010-
else:
1011-
self.cp_effect_param["level"] = self.cp_effect_param["level"]
1012-
1013-
if "trend" in self.cp_effect_type:
1014-
if (
1015-
"trend" not in self.cp_effect_param
1016-
or len(self.cp_effect_param["trend"]) != 2
1017-
):
1018-
self.cp_effect_param["trend"] = self.DEFAULT_TREND_PRIOR
1019-
else:
1020-
self.cp_effect_param["trend"] = self.cp_effect_param["trend"]
1021-
1022-
if "impulse" in self.cp_effect_type:
1023-
if (
1024-
"impulse" not in self.cp_effect_param
1025-
or len(self.cp_effect_param["impulse"]) != 3
1026-
):
1027-
self.cp_effect_param["impulse"] = self.DEFAULT_IMPULSE_PRIOR
1028-
else:
1029-
self.cp_effect_param["impulse"] = self.cp_effect_param["impulse"]
1030-
10311003
def build_model(self, X, y, coords):
10321004
"""
10331005
Defines the PyMC model
@@ -1061,39 +1033,25 @@ def build_model(self, X, y, coords):
10611033
var=(t - change_point),
10621034
dims=["obs_ind"],
10631035
)
1064-
beta = pm.Normal(
1065-
name="beta",
1066-
mu=self.DEFAULT_BETA_PRIOR[0],
1067-
sigma=self.DEFAULT_BETA_PRIOR[1],
1068-
dims=["treated_units", "coeffs"],
1069-
)
1036+
beta = self.priors["beta"].create_variable("beta")
10701037

10711038
# --- Intervention effect ---
10721039
mu_in_components = []
10731040

1074-
if "level" in self.cp_effect_param:
1075-
level = pm.Normal(
1076-
"level",
1077-
mu=self.cp_effect_param["level"][0],
1078-
sigma=self.cp_effect_param["level"][1],
1079-
)
1041+
if "level" in self.cp_effect_type:
1042+
level = self.priors["level"].create_variable("level")
10801043
mu_in_components.append(level)
1081-
if "trend" in self.cp_effect_param:
1082-
trend = pm.Normal(
1083-
"trend",
1084-
mu=self.cp_effect_param["trend"][0],
1085-
sigma=self.cp_effect_param["trend"][1],
1086-
)
1044+
1045+
if "trend" in self.cp_effect_type:
1046+
trend = self.priors["trend"].create_variable("trend")
10871047
mu_in_components.append(trend * delta_t)
1088-
if "impulse" in self.cp_effect_param:
1089-
impulse_amplitude = pm.Normal(
1090-
"impulse_amplitude",
1091-
mu=self.cp_effect_param["impulse"][0],
1092-
sigma=self.cp_effect_param["impulse"][1],
1048+
1049+
if "impulse" in self.cp_effect_type:
1050+
impulse_amplitude = self.priors["impulse_amplitude"].create_variable(
1051+
"impulse_amplitude"
10931052
)
1094-
decay_rate = pm.HalfNormal(
1095-
"decay_rate",
1096-
sigma=self.cp_effect_param["impulse"][2],
1053+
decay_rate = self.priors["impulse_decay_rate"].create_variable(
1054+
"impulse_decay_rate"
10971055
)
10981056
impulse = pm.Deterministic(
10991057
"impulse",
@@ -1104,8 +1062,8 @@ def build_model(self, X, y, coords):
11041062
# --- Parameterization ---
11051063
weight = pm.math.sigmoid(delta_t)
11061064
# Compute and store the base time series
1107-
mu = pm.Deterministic(
1108-
name="mu", var=pm.math.dot(X, beta.T), dims=["obs_ind", "treated_units"]
1065+
mu_ts = pm.Deterministic(
1066+
name="mu_ts", var=pt.dot(X, beta.T), dims=["obs_ind", "treated_units"]
11091067
)
11101068
# Compute and store the modelled intervention effect
11111069
mu_in = (
@@ -1120,20 +1078,20 @@ def build_model(self, X, y, coords):
11201078
)
11211079
)
11221080
# Compute and store the sum of the base time series and the intervention's effect
1123-
mu_ts = pm.Deterministic(
1124-
"mu_ts",
1125-
mu + (weight * mu_in)[:, None],
1081+
mu = pm.Deterministic(
1082+
"mu",
1083+
mu_ts + (weight * mu_in)[:, None],
11261084
dims=["obs_ind", "treated_units"],
11271085
)
1128-
sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
11291086

11301087
# --- Likelihood ---
1088+
sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
11311089
# Likelihood of the base time series
1132-
pm.Normal("y_hat", mu=mu, sigma=sigma, dims=["obs_ind", "treated_units"])
1090+
pm.Normal("y_ts", mu=mu_ts, sigma=sigma, dims=["obs_ind", "treated_units"])
11331091
# Likelihodd of the base time series and the intervention's effect
11341092
pm.Normal(
1135-
"y_ts",
1136-
mu=mu_ts,
1093+
"y_hat",
1094+
mu=mu,
11371095
sigma=sigma,
11381096
observed=y,
11391097
dims=["obs_ind", "treated_units"],
@@ -1193,8 +1151,8 @@ def score(self, X, y) -> pd.Series:
11931151
"""
11941152
Score the Bayesian :math:`R^2` given inputs ``X`` and outputs ``y``.
11951153
"""
1196-
mu_ts = self.predict(X)
1197-
mu_data = az.extract(mu_ts, group="posterior_predictive", var_names="mu_ts")
1154+
mu = self.predict(X)
1155+
mu_data = az.extract(mu, group="posterior_predictive", var_names="mu")
11981156

11991157
scores = {}
12001158

@@ -1208,6 +1166,12 @@ def score(self, X, y) -> pd.Series:
12081166

12091167
return pd.Series(scores)
12101168

1169+
def calculate_impact(
1170+
self, y_true: xr.DataArray, y_pred: az.InferenceData
1171+
) -> xr.DataArray:
1172+
impact = y_true - y_pred["posterior_predictive"]["y_ts"]
1173+
return impact.transpose(..., "obs_ind")
1174+
12111175
def set_time_range(self, time_range, data):
12121176
"""
12131177
Set time_range.

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

docs/source/notebooks/cp_covid.ipynb

Lines changed: 128 additions & 51 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)