Skip to content

Commit a18f83b

Browse files
committed
full run through
Signed-off-by: Nathaniel <[email protected]>
1 parent e9df0d5 commit a18f83b

File tree

2 files changed

+2138
-1669
lines changed

2 files changed

+2138
-1669
lines changed

causalpy/pymc_models.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,7 @@ def fit_outcome_model(
525525
priors={"b_outcome": [0, 1], "a_outcome": [0, 1], "sigma": 1},
526526
noncentred=True,
527527
normal_outcome=True,
528+
spline_component=False,
528529
):
529530
if not hasattr(self, "idata"):
530531
raise AttributeError("""Object is missing required attribute 'idata'
@@ -551,35 +552,40 @@ def fit_outcome_model(
551552
)
552553

553554
beta_ps_spline = pm.Normal("beta_ps_spline", 0, 1, size=34)
554-
beta_ps = pm.Normal("beta_ps", 0, 1)
555+
beta_ps = pm.Normal("beta_ps", 0, 1, size=2)
555556

556557
chosen = np.random.choice(range(propensity_scores.shape[1]))
557558
p = propensity_scores[:, chosen].values
558559

559-
B = dmatrix(
560-
"bs(ps, knots=knots, degree=3, include_intercept=True, lower_bound=0, upper_bound=1) - 1",
561-
{"ps": p, "knots": np.linspace(0, 1, 30)},
562-
)
563-
B_f = np.asarray(B, order="F")
564-
splines_summed = pm.Deterministic(
565-
"spline_features", pm.math.dot(B_f, beta_ps_spline.T)
566-
)
567-
568560
alpha_outcome = pm.Normal(
569561
"a_outcome", priors["a_outcome"][0], priors["a_outcome"][1]
570562
)
563+
571564
mu_outcome = (
572565
alpha_outcome
573566
+ pm.math.dot(X_data_outcome, beta)
574-
+ p * beta_ps
575-
+ splines_summed
567+
+ beta_ps[0] * p
568+
+ beta_ps[1] * (p * self.t.flatten())
576569
)
570+
571+
if spline_component:
572+
beta_ps_spline = pm.Normal("beta_ps_spline", 0, 1, size=34)
573+
B = dmatrix(
574+
"bs(ps, knots=knots, degree=3, include_intercept=True, lower_bound=0, upper_bound=1) - 1",
575+
{"ps": p, "knots": np.linspace(0, 1, 30)},
576+
)
577+
B_f = np.asarray(B, order="F")
578+
splines_summed = pm.Deterministic(
579+
"spline_features", pm.math.dot(B_f, beta_ps_spline.T)
580+
)
581+
mu_outcome = mu_outcome + splines_summed
582+
577583
sigma = pm.HalfNormal("sigma", priors["sigma"])
578584

579585
if normal_outcome:
580586
_ = pm.Normal("like", mu_outcome, sigma, observed=Y_data_)
581587
else:
582-
nu = pm.Exponential("nu", lam=1 / 30)
588+
nu = pm.Exponential("nu", lam=1 / 10)
583589
_ = pm.StudentT("like", nu=nu, mu=mu_outcome, sigma=sigma)
584590

585591
idata_outcome = pm.sample_prior_predictive(random_seed=random_seed)

0 commit comments

Comments
 (0)