Skip to content

Commit b7300e7

Browse files
committed
add default_priors and support for custom priors
1 parent b35001b commit b7300e7

File tree

1 file changed

+27
-7
lines changed

1 file changed

+27
-7
lines changed

causalpy/pymc_models.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pytensor.tensor as pt
2323
import xarray as xr
2424
from arviz import r2_score
25+
from pymc_extras.prior import Prior
2526

2627
from causalpy.utils import round_num
2728

@@ -68,7 +69,13 @@ class PyMCModel(pm.Model):
6869
Inference data...
6970
"""
7071

71-
def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
72+
default_priors: dict[str, Any]
73+
74+
def __init__(
75+
self,
76+
sample_kwargs: Optional[Dict[str, Any]] = None,
77+
priors: dict[str, Any] | None = None,
78+
):
7279
"""
7380
:param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the
7481
:func:`pymc.sample` function. Defaults to an empty dictionary.
@@ -77,6 +84,8 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
7784
self.idata = None
7885
self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {}
7986

87+
self.priors = {**self.default_priors, **(priors or {})}
88+
8089
def build_model(self, X, y, coords) -> None:
8190
"""Build the model, must be implemented by subclass."""
8291
raise NotImplementedError("This method must be implemented by a subclass")
@@ -237,6 +246,11 @@ class LinearRegression(PyMCModel):
237246
Inference data...
238247
""" # noqa: W605
239248

249+
default_priors = {
250+
"beta": Prior("Normal", mu=0, sigma=50, dims="coeffs"),
251+
"y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1)),
252+
}
253+
240254
def build_model(self, X, y, coords):
241255
"""
242256
Defines the PyMC model
@@ -245,10 +259,9 @@ def build_model(self, X, y, coords):
245259
self.add_coords(coords)
246260
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
247261
y = pm.Data("y", y, dims="obs_ind")
248-
beta = pm.Normal("beta", 0, 50, dims="coeffs")
249-
sigma = pm.HalfNormal("sigma", 1)
262+
beta = self.priors["beta"].create_variable("beta")
250263
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
251-
pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")
264+
self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)
252265

253266

254267
class WeightedSumFitter(PyMCModel):
@@ -276,6 +289,10 @@ class WeightedSumFitter(PyMCModel):
276289
Inference data...
277290
""" # noqa: W605
278291

292+
default_priors = {
293+
"y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1)),
294+
}
295+
279296
def build_model(self, X, y, coords):
280297
"""
281298
Defines the PyMC model
@@ -286,9 +303,8 @@ def build_model(self, X, y, coords):
286303
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
287304
y = pm.Data("y", y[:, 0], dims="obs_ind")
288305
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
289-
sigma = pm.HalfNormal("sigma", 1)
290306
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
291-
pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")
307+
self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)
292308

293309

294310
class InstrumentalVariableRegression(PyMCModel):
@@ -477,13 +493,17 @@ class PropensityScore(PyMCModel):
477493
Inference...
478494
""" # noqa: W605
479495

496+
default_priors = {
497+
"b": Prior("Normal", mu=0, sigma=1, dims="coeffs"),
498+
}
499+
480500
def build_model(self, X, t, coords):
481501
"Defines the PyMC propensity model"
482502
with self:
483503
self.add_coords(coords)
484504
X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"])
485505
t_data = pm.Data("t", t.flatten(), dims="obs_ind")
486-
b = pm.Normal("b", mu=0, sigma=1, dims="coeffs")
506+
b = self.priors["b"].create_variable("b")
487507
mu = pm.math.dot(X_data, b)
488508
p = pm.Deterministic("p", pm.math.invlogit(mu))
489509
pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind")

0 commit comments

Comments
 (0)