2323import xarray as xr
2424from arviz import r2_score
2525from patsy import dmatrix
26+ from pymc_extras .prior import Prior
2627
2728from causalpy .utils import round_num
2829
@@ -90,7 +91,87 @@ class PyMCModel(pm.Model):
9091 Inference data...
9192 """
9293
93- def __init__ (self , sample_kwargs : Optional [Dict [str , Any ]] = None ):
94+ default_priors = {}
95+
96+ def priors_from_data (self , X , y ) -> Dict [str , Any ]:
97+ """
98+ Generate priors dynamically based on the input data.
99+
100+ This method allows models to set sensible priors that adapt to the scale
101+ and characteristics of the actual data being analyzed. It's called during
102+ the `fit()` method before model building, allowing data-driven prior
103+ specification that can improve model performance and convergence.
104+
105+ The priors returned by this method are merged with any user-specified
106+ priors (passed via the `priors` parameter in `__init__`), with
107+ user-specified priors taking precedence in case of conflicts.
108+
109+ Parameters
110+ ----------
111+ X : xarray.DataArray
112+ Input features/covariates with dimensions ["obs_ind", "coeffs"].
113+ Used to understand the scale and structure of predictors.
114+ y : xarray.DataArray
115+ Target variable with dimensions ["obs_ind", "treated_units"].
116+ Used to understand the scale and structure of the outcome.
117+
118+ Returns
119+ -------
120+ Dict[str, Prior]
121+ Dictionary mapping parameter names to Prior objects. The keys should
122+ match parameter names used in the model's `build_model()` method.
123+
124+ Notes
125+ -----
126+ The base implementation returns an empty dictionary, meaning no
127+ data-driven priors are set by default. Subclasses should override
128+ this method to implement data-adaptive prior specification.
129+
130+ **Priority Order for Priors:**
131+ 1. User-specified priors (passed to `__init__`)
132+ 2. Data-driven priors (from this method)
133+ 3. Default priors (from `default_priors` property)
134+
135+ Examples
136+ --------
137+ A typical implementation might scale priors based on data variance:
138+
139+ >>> def priors_from_data(self, X, y):
140+ ... y_std = float(y.std())
141+ ... return {
142+ ... "sigma": Prior("HalfNormal", sigma=y_std, dims="treated_units"),
143+ ... "beta": Prior(
144+ ... "Normal",
145+ ... mu=0,
146+ ... sigma=2 * y_std,
147+ ... dims=["treated_units", "coeffs"],
148+ ... ),
149+ ... }
150+
151+ Or set shape parameters based on data dimensions:
152+
153+ >>> def priors_from_data(self, X, y):
154+ ... n_predictors = X.shape[1]
155+ ... return {
156+ ... "beta": Prior(
157+ ... "Dirichlet",
158+ ... a=np.ones(n_predictors),
159+ ... dims=["treated_units", "coeffs"],
160+ ... )
161+ ... }
162+
163+ See Also
164+ --------
165+ WeightedSumFitter.priors_from_data : Example implementation that sets
166+ Dirichlet prior shape based on number of control units.
167+ """
168+ return {}
169+
170+ def __init__ (
171+ self ,
172+ sample_kwargs : Optional [Dict [str , Any ]] = None ,
173+ priors : dict [str , Any ] | None = None ,
174+ ):
94175 """
95176 :param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the
96177 :func:`pymc.sample` function. Defaults to an empty dictionary.
@@ -99,9 +180,13 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
99180 self .idata = None
100181 self .sample_kwargs = sample_kwargs if sample_kwargs is not None else {}
101182
183+ self .priors = {** self .default_priors , ** (priors or {})}
184+
102185 def build_model (self , X , y , coords ) -> None :
103186 """Build the model, must be implemented by subclass."""
104- raise NotImplementedError ("This method must be implemented by a subclass" )
187+ raise NotImplementedError (
188+ "This method must be implemented by a subclass"
189+ ) # pragma: no cover
105190
106191 def _data_setter (self , X : xr .DataArray ) -> None :
107192 """
@@ -144,6 +229,10 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
144229 # sample_posterior_predictive() if provided in sample_kwargs.
145230 random_seed = self .sample_kwargs .get ("random_seed" , None )
146231
232+ # Merge priors with precedence: user-specified > data-driven > defaults
233+ # Data-driven priors are computed first, then user-specified priors override them
234+ self .priors = {** self .priors_from_data (X , y ), ** self .priors }
235+
147236 self .build_model (X , y , coords )
148237 with self :
149238 self .idata = pm .sample (** self .sample_kwargs )
@@ -260,26 +349,36 @@ def print_coefficients_for_unit(
260349 ) -> None :
261350 """Print coefficients for a single unit"""
262351 # Determine the width of the longest label
263- max_label_length = max (len (name ) for name in labels + ["sigma " ])
352+ max_label_length = max (len (name ) for name in labels + ["y_hat_sigma " ])
264353
265354 for name in labels :
266355 coeff_samples = unit_coeffs .sel (coeffs = name )
267356 print_row (max_label_length , name , coeff_samples , round_to )
268357
269358 # Add coefficient for measurement std
270- print_row (max_label_length , "sigma " , unit_sigma , round_to )
359+ print_row (max_label_length , "y_hat_sigma " , unit_sigma , round_to )
271360
272361 print ("Model coefficients:" )
273362 coeffs = az .extract (self .idata .posterior , var_names = "beta" )
274363
275- # Always has treated_units dimension - no branching needed!
364+ # Check if sigma or y_hat_sigma variable exists
365+ sigma_var_name = None
366+ if "sigma" in self .idata .posterior :
367+ sigma_var_name = "sigma"
368+ elif "y_hat_sigma" in self .idata .posterior :
369+ sigma_var_name = "y_hat_sigma"
370+ else :
371+ raise ValueError (
372+ "Neither 'sigma' nor 'y_hat_sigma' found in posterior"
373+ ) # pragma: no cover
374+
276375 treated_units = coeffs .coords ["treated_units" ].values
277376 for unit in treated_units :
278377 if len (treated_units ) > 1 :
279378 print (f"\n Treated unit: { unit } " )
280379
281380 unit_coeffs = coeffs .sel (treated_units = unit )
282- unit_sigma = az .extract (self .idata .posterior , var_names = "sigma" ).sel (
381+ unit_sigma = az .extract (self .idata .posterior , var_names = sigma_var_name ).sel (
283382 treated_units = unit
284383 )
285384 print_coefficients_for_unit (unit_coeffs , unit_sigma , labels , round_to or 2 )
@@ -322,6 +421,15 @@ class LinearRegression(PyMCModel):
322421 Inference data...
323422 """ # noqa: W605
324423
424+ default_priors = {
425+ "beta" : Prior ("Normal" , mu = 0 , sigma = 50 , dims = ["treated_units" , "coeffs" ]),
426+ "y_hat" : Prior (
427+ "Normal" ,
428+ sigma = Prior ("HalfNormal" , sigma = 1 , dims = ["treated_units" ]),
429+ dims = ["obs_ind" , "treated_units" ],
430+ ),
431+ }
432+
325433 def build_model (self , X , y , coords ):
326434 """
327435 Defines the PyMC model
@@ -335,12 +443,11 @@ def build_model(self, X, y, coords):
335443 self .add_coords (coords )
336444 X = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
337445 y = pm .Data ("y" , y , dims = ["obs_ind" , "treated_units" ])
338- beta = pm .Normal ("beta" , 0 , 50 , dims = ["treated_units" , "coeffs" ])
339- sigma = pm .HalfNormal ("sigma" , 1 , dims = "treated_units" )
446+ beta = self .priors ["beta" ].create_variable ("beta" )
340447 mu = pm .Deterministic (
341448 "mu" , pt .dot (X , beta .T ), dims = ["obs_ind" , "treated_units" ]
342449 )
343- pm . Normal ("y_hat" , mu , sigma , observed = y , dims = [ "obs_ind" , "treated_units" ] )
450+ self . priors [ "y_hat" ]. create_likelihood_variable ("y_hat" , mu = mu , observed = y )
344451
345452
346453class WeightedSumFitter (PyMCModel ):
@@ -383,23 +490,56 @@ class WeightedSumFitter(PyMCModel):
383490 Inference data...
384491 """ # noqa: W605
385492
493+ default_priors = {
494+ "y_hat" : Prior (
495+ "Normal" ,
496+ sigma = Prior ("HalfNormal" , sigma = 1 , dims = ["treated_units" ]),
497+ dims = ["obs_ind" , "treated_units" ],
498+ ),
499+ }
500+
501+ def priors_from_data (self , X , y ) -> Dict [str , Any ]:
502+ """
503+ Set Dirichlet prior for weights based on number of control units.
504+
505+ For synthetic control models, this method sets the shape parameter of the
506+ Dirichlet prior on the control unit weights (`beta`) to be uniform across
507+ all available control units. This ensures that all control units have
508+ equal prior probability of contributing to the synthetic control.
509+
510+ Parameters
511+ ----------
512+ X : xarray.DataArray
513+ Control unit data with shape (n_obs, n_control_units).
514+ y : xarray.DataArray
515+ Treated unit outcome data.
516+
517+ Returns
518+ -------
519+ Dict[str, Prior]
520+ Dictionary containing:
521+ - "beta": Dirichlet prior with shape=(1,...,1) for n_control_units
522+ """
523+ n_predictors = X .shape [1 ]
524+ return {
525+ "beta" : Prior (
526+ "Dirichlet" , a = np .ones (n_predictors ), dims = ["treated_units" , "coeffs" ]
527+ ),
528+ }
529+
386530 def build_model (self , X , y , coords ):
387531 """
388532 Defines the PyMC model
389533 """
390534 with self :
391535 self .add_coords (coords )
392- n_predictors = X .sizes ["coeffs" ]
393536 X = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
394537 y = pm .Data ("y" , y , dims = ["obs_ind" , "treated_units" ])
395- beta = pm .Dirichlet (
396- "beta" , a = np .ones (n_predictors ), dims = ["treated_units" , "coeffs" ]
397- )
398- sigma = pm .HalfNormal ("sigma" , 1 , dims = "treated_units" )
538+ beta = self .priors ["beta" ].create_variable ("beta" )
399539 mu = pm .Deterministic (
400540 "mu" , pt .dot (X , beta .T ), dims = ["obs_ind" , "treated_units" ]
401541 )
402- pm . Normal ("y_hat" , mu , sigma , observed = y , dims = [ "obs_ind" , "treated_units" ] )
542+ self . priors [ "y_hat" ]. create_likelihood_variable ("y_hat" , mu = mu , observed = y )
403543
404544
405545class InstrumentalVariableRegression (PyMCModel ):
@@ -589,21 +729,18 @@ class PropensityScore(PyMCModel):
589729 Inference...
590730 """ # noqa: W605
591731
592- def build_model (self , X , t , coords , prior , noncentred ):
732+ default_priors = {
733+ "b" : Prior ("Normal" , mu = 0 , sigma = 1 , dims = "coeffs" ),
734+ }
735+
736+ def build_model (self , X , t , coords , prior = None , noncentred = True ):
593737 "Defines the PyMC propensity model"
594738 with self :
595739 self .add_coords (coords )
596740 X_data = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
597741 t_data = pm .Data ("t" , t .flatten (), dims = "obs_ind" )
598- if noncentred :
599- mu_beta , sigma_beta = prior ["b" ]
600- beta_std = pm .Normal ("beta_std" , 0 , 1 , dims = "coeffs" )
601- b = pm .Deterministic (
602- "beta_" , mu_beta + sigma_beta * beta_std , dims = "coeffs"
603- )
604- else :
605- b = pm .Normal ("b" , mu = prior ["b" ][0 ], sigma = prior ["b" ][1 ], dims = "coeffs" )
606- mu = pm .math .dot (X_data , b )
742+ b = self .priors ["b" ].create_variable ("b" )
743+ mu = pt .dot (X_data , b )
607744 p = pm .Deterministic ("p" , pm .math .invlogit (mu ))
608745 pm .Bernoulli ("t_pred" , p = p , observed = t_data , dims = "obs_ind" )
609746
0 commit comments