2222import pytensor .tensor as pt
2323import xarray as xr
2424from arviz import r2_score
25+ from pymc_extras .prior import Prior
2526
2627from causalpy .utils import round_num
2728
@@ -68,7 +69,13 @@ class PyMCModel(pm.Model):
6869 Inference data...
6970 """
7071
71- def __init__ (self , sample_kwargs : Optional [Dict [str , Any ]] = None ):
72+ default_priors : dict [str , Any ]
73+
74+ def __init__ (
75+ self ,
76+ sample_kwargs : Optional [Dict [str , Any ]] = None ,
77+ priors : dict [str , Any ] | None = None ,
78+ ):
7279 """
7380 :param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the
7481 :func:`pymc.sample` function. Defaults to an empty dictionary.
@@ -77,6 +84,8 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
7784 self .idata = None
7885 self .sample_kwargs = sample_kwargs if sample_kwargs is not None else {}
7986
87+ self .priors = {** self .default_priors , ** (priors or {})}
88+
8089 def build_model (self , X , y , coords ) -> None :
8190 """Build the model, must be implemented by subclass."""
8291 raise NotImplementedError ("This method must be implemented by a subclass" )
@@ -237,6 +246,11 @@ class LinearRegression(PyMCModel):
237246 Inference data...
238247 """ # noqa: W605
239248
249+ default_priors = {
250+ "beta" : Prior ("Normal" , mu = 0 , sigma = 50 , dims = "coeffs" ),
251+ "y_hat" : Prior ("Normal" , sigma = Prior ("HalfNormal" , sigma = 1 )),
252+ }
253+
240254 def build_model (self , X , y , coords ):
241255 """
242256 Defines the PyMC model
@@ -245,10 +259,9 @@ def build_model(self, X, y, coords):
245259 self .add_coords (coords )
246260 X = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
247261 y = pm .Data ("y" , y , dims = "obs_ind" )
248- beta = pm .Normal ("beta" , 0 , 50 , dims = "coeffs" )
249- sigma = pm .HalfNormal ("sigma" , 1 )
262+ beta = self .priors ["beta" ].create_variable ("beta" )
250263 mu = pm .Deterministic ("mu" , pm .math .dot (X , beta ), dims = "obs_ind" )
251- pm . Normal ("y_hat" , mu , sigma , observed = y , dims = "obs_ind" )
264+ self . priors [ "y_hat" ]. create_likelihood_variable ("y_hat" , mu = mu , observed = y )
252265
253266
254267class WeightedSumFitter (PyMCModel ):
@@ -276,6 +289,10 @@ class WeightedSumFitter(PyMCModel):
276289 Inference data...
277290 """ # noqa: W605
278291
292+ default_priors = {
293+ "y_hat" : Prior ("Normal" , sigma = Prior ("HalfNormal" , sigma = 1 )),
294+ }
295+
279296 def build_model (self , X , y , coords ):
280297 """
281298 Defines the PyMC model
@@ -286,9 +303,8 @@ def build_model(self, X, y, coords):
286303 X = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
287304 y = pm .Data ("y" , y [:, 0 ], dims = "obs_ind" )
288305 beta = pm .Dirichlet ("beta" , a = np .ones (n_predictors ), dims = "coeffs" )
289- sigma = pm .HalfNormal ("sigma" , 1 )
290306 mu = pm .Deterministic ("mu" , pm .math .dot (X , beta ), dims = "obs_ind" )
291- pm . Normal ("y_hat" , mu , sigma , observed = y , dims = "obs_ind" )
307+ self . priors [ "y_hat" ]. create_likelihood_variable ("y_hat" , mu = mu , observed = y )
292308
293309
294310class InstrumentalVariableRegression (PyMCModel ):
@@ -477,13 +493,17 @@ class PropensityScore(PyMCModel):
477493 Inference...
478494 """ # noqa: W605
479495
496+ default_priors = {
497+ "b" : Prior ("Normal" , mu = 0 , sigma = 1 , dims = "coeffs" ),
498+ }
499+
480500 def build_model (self , X , t , coords ):
481501 "Defines the PyMC propensity model"
482502 with self :
483503 self .add_coords (coords )
484504 X_data = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
485505 t_data = pm .Data ("t" , t .flatten (), dims = "obs_ind" )
486- b = pm . Normal ( "b" , mu = 0 , sigma = 1 , dims = "coeffs " )
506+ b = self . priors [ "b" ]. create_variable ( "b " )
487507 mu = pm .math .dot (X_data , b )
488508 p = pm .Deterministic ("p" , pm .math .invlogit (mu ))
489509 pm .Bernoulli ("t_pred" , p = p , observed = t_data , dims = "obs_ind" )
0 commit comments