Skip to content

Commit a6573ef

Browse files
committed
Adding Priors from experimental
1 parent b9c3a97 commit a6573ef

File tree

2 files changed

+242
-210
lines changed

2 files changed

+242
-210
lines changed

causalpy/pymc_models.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -968,9 +968,14 @@ class LinearChangePointDetection(PyMCModel):
968968
default_priors = {
969969
"beta": Prior("Normal", mu=0, sigma=5, dims=["treated_units", "coeffs"]),
970970
"level": Prior("Normal", mu=0, sigma=5),
971-
"trend": Prior("Normal", mu=0, sigma=0.5),
971+
"trend": Prior("Normal", mu=0, sigma=5),
972972
"impulse_amplitude": Prior("Normal", mu=0, sigma=5),
973973
"impulse_decay_rate": Prior("HalfNormal", sigma=5),
974+
"y_hat": Prior(
975+
"Normal",
976+
sigma=Prior("HalfNormal", sigma=5),
977+
dims=["obs_ind", "treated_units"],
978+
),
974979
}
975980

976981
def __init__(
@@ -1025,9 +1030,13 @@ def build_model(self, X, y, coords):
10251030
upper_bound = pm.Data("upper_bound", self.time_range[1])
10261031

10271032
# --- Priors ---
1028-
change_point = pm.Uniform(
1029-
"change_point", lower=lower_bound, upper=upper_bound
1033+
# --- change_point unconstrained mapping ---
1034+
tau_un = pm.Normal("tau_un", 0, 1)
1035+
change_point = pm.Deterministic(
1036+
"change_point",
1037+
lower_bound + (upper_bound - 1 - lower_bound) * pm.math.sigmoid(tau_un),
10301038
)
1039+
10311040
delta_t = pm.Deterministic(
10321041
name="delta_t",
10331042
var=(t - change_point),
@@ -1077,25 +1086,17 @@ def build_model(self, X, y, coords):
10771086
vars=0,
10781087
)
10791088
)
1080-
# Compute and store the sum of the base time series and the intervention's effect
1089+
10811090
mu = pm.Deterministic(
10821091
"mu",
10831092
mu_ts + (weight * mu_in)[:, None],
10841093
dims=["obs_ind", "treated_units"],
10851094
)
10861095

10871096
# --- Likelihood ---
1088-
sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
1089-
# Likelihood of the base time series
1090-
pm.Normal("y_ts", mu=mu_ts, sigma=sigma, dims=["obs_ind", "treated_units"])
1097+
10911098
# Likelihodd of the base time series and the intervention's effect
1092-
pm.Normal(
1093-
"y_hat",
1094-
mu=mu,
1095-
sigma=sigma,
1096-
observed=y,
1097-
dims=["obs_ind", "treated_units"],
1098-
)
1099+
self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)
10991100

11001101
def predict(self, X):
11011102
"""
@@ -1112,7 +1113,7 @@ def predict(self, X):
11121113
with self: # sample with new input data
11131114
pp = pm.sample_posterior_predictive(
11141115
self.idata,
1115-
var_names=["y_hat", "y_ts", "mu", "mu_ts", "mu_in"],
1116+
var_names=["y_hat", "mu", "mu_ts", "mu_in"],
11161117
progressbar=False,
11171118
random_seed=random_seed,
11181119
)
@@ -1169,7 +1170,7 @@ def score(self, X, y) -> pd.Series:
11691170
def calculate_impact(
11701171
self, y_true: xr.DataArray, y_pred: az.InferenceData
11711172
) -> xr.DataArray:
1172-
impact = y_true - y_pred["posterior_predictive"]["y_ts"]
1173+
impact = y_true - y_pred["posterior_predictive"]["mu_ts"]
11731174
return impact.transpose(..., "obs_ind")
11741175

11751176
def set_time_range(self, time_range, data):

docs/source/notebooks/cp_covid.ipynb

Lines changed: 225 additions & 194 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)