2323import xarray as xr
2424from arviz import r2_score
2525from patsy import dmatrix
26+ from pymc_extras .prior import Prior
2627
2728from causalpy .utils import round_num
2829
@@ -90,7 +91,87 @@ class PyMCModel(pm.Model):
9091 Inference data...
9192 """
9293
93- def __init__ (self , sample_kwargs : Optional [Dict [str , Any ]] = None ):
    # Class-level fallback priors, keyed by parameter name. Merged at lowest
    # precedence in __init__ (user-supplied priors override these); subclasses
    # override this mapping with their own defaults.
    default_priors = {}
96+ def priors_from_data (self , X , y ) -> Dict [str , Any ]:
97+ """
98+ Generate priors dynamically based on the input data.
99+
100+ This method allows models to set sensible priors that adapt to the scale
101+ and characteristics of the actual data being analyzed. It's called during
102+ the `fit()` method before model building, allowing data-driven prior
103+ specification that can improve model performance and convergence.
104+
105+ The priors returned by this method are merged with any user-specified
106+ priors (passed via the `priors` parameter in `__init__`), with
107+ user-specified priors taking precedence in case of conflicts.
108+
109+ Parameters
110+ ----------
111+ X : xarray.DataArray
112+ Input features/covariates with dimensions ["obs_ind", "coeffs"].
113+ Used to understand the scale and structure of predictors.
114+ y : xarray.DataArray
115+ Target variable with dimensions ["obs_ind", "treated_units"].
116+ Used to understand the scale and structure of the outcome.
117+
118+ Returns
119+ -------
120+ Dict[str, Prior]
121+ Dictionary mapping parameter names to Prior objects. The keys should
122+ match parameter names used in the model's `build_model()` method.
123+
124+ Notes
125+ -----
126+ The base implementation returns an empty dictionary, meaning no
127+ data-driven priors are set by default. Subclasses should override
128+ this method to implement data-adaptive prior specification.
129+
130+ **Priority Order for Priors:**
131+ 1. User-specified priors (passed to `__init__`)
132+ 2. Data-driven priors (from this method)
133+ 3. Default priors (from `default_priors` property)
134+
135+ Examples
136+ --------
137+ A typical implementation might scale priors based on data variance:
138+
139+ >>> def priors_from_data(self, X, y):
140+ ... y_std = float(y.std())
141+ ... return {
142+ ... "sigma": Prior("HalfNormal", sigma=y_std, dims="treated_units"),
143+ ... "beta": Prior(
144+ ... "Normal",
145+ ... mu=0,
146+ ... sigma=2 * y_std,
147+ ... dims=["treated_units", "coeffs"],
148+ ... ),
149+ ... }
150+
151+ Or set shape parameters based on data dimensions:
152+
153+ >>> def priors_from_data(self, X, y):
154+ ... n_predictors = X.shape[1]
155+ ... return {
156+ ... "beta": Prior(
157+ ... "Dirichlet",
158+ ... a=np.ones(n_predictors),
159+ ... dims=["treated_units", "coeffs"],
160+ ... )
161+ ... }
162+
163+ See Also
164+ --------
165+ WeightedSumFitter.priors_from_data : Example implementation that sets
166+ Dirichlet prior shape based on number of control units.
167+ """
168+ return {}
169+
170+ def __init__ (
171+ self ,
172+ sample_kwargs : Optional [Dict [str , Any ]] = None ,
173+ priors : dict [str , Any ] | None = None ,
174+ ):
94175 """
95176 :param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the
96177 :func:`pymc.sample` function. Defaults to an empty dictionary.
@@ -99,9 +180,13 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
99180 self .idata = None
100181 self .sample_kwargs = sample_kwargs if sample_kwargs is not None else {}
101182
183+ self .priors = {** self .default_priors , ** (priors or {})}
184+
    def build_model(self, X, y, coords) -> None:
        """Build the model, must be implemented by subclass.

        :param X: predictor data
        :param y: outcome data
        :param coords: coordinate mapping registered via ``self.add_coords``
            in subclass implementations
        """
        raise NotImplementedError(
            "This method must be implemented by a subclass"
        )  # pragma: no cover
105190
106191 def _data_setter (self , X : xr .DataArray ) -> None :
107192 """
@@ -144,6 +229,10 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
144229 # sample_posterior_predictive() if provided in sample_kwargs.
145230 random_seed = self .sample_kwargs .get ("random_seed" , None )
146231
232+ # Merge priors with precedence: user-specified > data-driven > defaults
233+ # Data-driven priors are computed first, then user-specified priors override them
234+ self .priors = {** self .priors_from_data (X , y ), ** self .priors }
235+
147236 self .build_model (X , y , coords )
148237 with self :
149238 self .idata = pm .sample (** self .sample_kwargs )
@@ -239,26 +328,36 @@ def print_coefficients_for_unit(
239328 ) -> None :
240329 """Print coefficients for a single unit"""
241330 # Determine the width of the longest label
242- max_label_length = max (len (name ) for name in labels + ["sigma " ])
331+ max_label_length = max (len (name ) for name in labels + ["y_hat_sigma " ])
243332
244333 for name in labels :
245334 coeff_samples = unit_coeffs .sel (coeffs = name )
246335 print_row (max_label_length , name , coeff_samples , round_to )
247336
248337 # Add coefficient for measurement std
249- print_row (max_label_length , "sigma " , unit_sigma , round_to )
338+ print_row (max_label_length , "y_hat_sigma " , unit_sigma , round_to )
250339
251340 print ("Model coefficients:" )
252341 coeffs = az .extract (self .idata .posterior , var_names = "beta" )
253342
254- # Always has treated_units dimension - no branching needed!
343+ # Check if sigma or y_hat_sigma variable exists
344+ sigma_var_name = None
345+ if "sigma" in self .idata .posterior :
346+ sigma_var_name = "sigma"
347+ elif "y_hat_sigma" in self .idata .posterior :
348+ sigma_var_name = "y_hat_sigma"
349+ else :
350+ raise ValueError (
351+ "Neither 'sigma' nor 'y_hat_sigma' found in posterior"
352+ ) # pragma: no cover
353+
255354 treated_units = coeffs .coords ["treated_units" ].values
256355 for unit in treated_units :
257356 if len (treated_units ) > 1 :
258357 print (f"\n Treated unit: { unit } " )
259358
260359 unit_coeffs = coeffs .sel (treated_units = unit )
261- unit_sigma = az .extract (self .idata .posterior , var_names = "sigma" ).sel (
360+ unit_sigma = az .extract (self .idata .posterior , var_names = sigma_var_name ).sel (
262361 treated_units = unit
263362 )
264363 print_coefficients_for_unit (unit_coeffs , unit_sigma , labels , round_to or 2 )
@@ -301,6 +400,15 @@ class LinearRegression(PyMCModel):
301400 Inference data...
302401 """ # noqa: W605
303402
    default_priors = {
        # Regression coefficients: one weight per (treated unit, predictor),
        # weakly informative wide Normal.
        "beta": Prior("Normal", mu=0, sigma=50, dims=["treated_units", "coeffs"]),
        # Gaussian likelihood; the nested HalfNormal is the per-treated-unit
        # observation-noise scale.
        "y_hat": Prior(
            "Normal",
            sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
            dims=["obs_ind", "treated_units"],
        ),
    }
411+
304412 def build_model (self , X , y , coords ):
305413 """
306414 Defines the PyMC model
@@ -314,12 +422,11 @@ def build_model(self, X, y, coords):
314422 self .add_coords (coords )
315423 X = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
316424 y = pm .Data ("y" , y , dims = ["obs_ind" , "treated_units" ])
317- beta = pm .Normal ("beta" , 0 , 50 , dims = ["treated_units" , "coeffs" ])
318- sigma = pm .HalfNormal ("sigma" , 1 , dims = "treated_units" )
425+ beta = self .priors ["beta" ].create_variable ("beta" )
319426 mu = pm .Deterministic (
320427 "mu" , pt .dot (X , beta .T ), dims = ["obs_ind" , "treated_units" ]
321428 )
322- pm . Normal ("y_hat" , mu , sigma , observed = y , dims = [ "obs_ind" , "treated_units" ] )
429+ self . priors [ "y_hat" ]. create_likelihood_variable ("y_hat" , mu = mu , observed = y )
323430
324431
325432class WeightedSumFitter (PyMCModel ):
@@ -362,23 +469,56 @@ class WeightedSumFitter(PyMCModel):
362469 Inference data...
363470 """ # noqa: W605
364471
    default_priors = {
        # Gaussian likelihood with a per-treated-unit HalfNormal noise scale.
        # The "beta" (control-unit weights) prior is intentionally absent here:
        # it is supplied by priors_from_data(), which needs the data's shape.
        "y_hat": Prior(
            "Normal",
            sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
            dims=["obs_ind", "treated_units"],
        ),
    }
479+
480+ def priors_from_data (self , X , y ) -> Dict [str , Any ]:
481+ """
482+ Set Dirichlet prior for weights based on number of control units.
483+
484+ For synthetic control models, this method sets the shape parameter of the
485+ Dirichlet prior on the control unit weights (`beta`) to be uniform across
486+ all available control units. This ensures that all control units have
487+ equal prior probability of contributing to the synthetic control.
488+
489+ Parameters
490+ ----------
491+ X : xarray.DataArray
492+ Control unit data with shape (n_obs, n_control_units).
493+ y : xarray.DataArray
494+ Treated unit outcome data.
495+
496+ Returns
497+ -------
498+ Dict[str, Prior]
499+ Dictionary containing:
500+ - "beta": Dirichlet prior with shape=(1,...,1) for n_control_units
501+ """
502+ n_predictors = X .shape [1 ]
503+ return {
504+ "beta" : Prior (
505+ "Dirichlet" , a = np .ones (n_predictors ), dims = ["treated_units" , "coeffs" ]
506+ ),
507+ }
508+
365509 def build_model (self , X , y , coords ):
366510 """
367511 Defines the PyMC model
368512 """
369513 with self :
370514 self .add_coords (coords )
371- n_predictors = X .sizes ["coeffs" ]
372515 X = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
373516 y = pm .Data ("y" , y , dims = ["obs_ind" , "treated_units" ])
374- beta = pm .Dirichlet (
375- "beta" , a = np .ones (n_predictors ), dims = ["treated_units" , "coeffs" ]
376- )
377- sigma = pm .HalfNormal ("sigma" , 1 , dims = "treated_units" )
517+ beta = self .priors ["beta" ].create_variable ("beta" )
378518 mu = pm .Deterministic (
379519 "mu" , pt .dot (X , beta .T ), dims = ["obs_ind" , "treated_units" ]
380520 )
381- pm . Normal ("y_hat" , mu , sigma , observed = y , dims = [ "obs_ind" , "treated_units" ] )
521+ self . priors [ "y_hat" ]. create_likelihood_variable ("y_hat" , mu = mu , observed = y )
382522
383523
384524class InstrumentalVariableRegression (PyMCModel ):
@@ -568,21 +708,18 @@ class PropensityScore(PyMCModel):
568708 Inference...
569709 """ # noqa: W605
570710
571- def build_model (self , X , t , coords , prior , noncentred ):
    default_priors = {
        # Logistic-regression coefficients on the logit scale, one per covariate.
        "b": Prior("Normal", mu=0, sigma=1, dims="coeffs"),
    }
714+
    def build_model(self, X, t, coords, prior=None, noncentred=True):
        """Defines the PyMC propensity model.

        :param X: covariate data with dims ["obs_ind", "coeffs"]
        :param t: treatment-assignment indicators; flattened and modelled as
            Bernoulli observations
        :param coords: coordinate mapping registered on the model
        :param prior: accepted for backward compatibility but not read by this
            body — the coefficient prior now comes from ``self.priors["b"]``
        :param noncentred: accepted for backward compatibility but not read by
            this body; the non-centred parameterisation branch was removed
        """
        # NOTE(review): `prior` and `noncentred` are silently ignored — confirm
        # callers were migrated to pass priors via __init__, or deprecate these
        # parameters explicitly.
        with self:
            self.add_coords(coords)
            X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"])
            t_data = pm.Data("t", t.flatten(), dims="obs_ind")
            b = self.priors["b"].create_variable("b")
            # Linear predictor on the logit scale, squashed to a probability.
            mu = pt.dot(X_data, b)
            p = pm.Deterministic("p", pm.math.invlogit(mu))
            pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind")
588725
0 commit comments