@@ -96,6 +96,77 @@ def default_priors(self):
9696 return {}
9797
9898 def priors_from_data (self , X , y ) -> Dict [str , Any ]:
99+ """
100+ Generate priors dynamically based on the input data.
101+
102+ This method allows models to set sensible priors that adapt to the scale
103+ and characteristics of the actual data being analyzed. It's called during
104+ the `fit()` method before model building, allowing data-driven prior
105+ specification that can improve model performance and convergence.
106+
107+ The priors returned by this method are merged with any user-specified
108+ priors (passed via the `priors` parameter in `__init__`), with
109+ user-specified priors taking precedence in case of conflicts.
110+
111+ Parameters
112+ ----------
113+ X : xarray.DataArray
114+ Input features/covariates with dimensions ["obs_ind", "coeffs"].
115+ Used to understand the scale and structure of predictors.
116+ y : xarray.DataArray
117+ Target variable with dimensions ["obs_ind", "treated_units"].
118+ Used to understand the scale and structure of the outcome.
119+
120+ Returns
121+ -------
122+ Dict[str, Prior]
123+ Dictionary mapping parameter names to Prior objects. The keys should
124+ match parameter names used in the model's `build_model()` method.
125+
126+ Notes
127+ -----
128+ The base implementation returns an empty dictionary, meaning no
129+ data-driven priors are set by default. Subclasses should override
130+ this method to implement data-adaptive prior specification.
131+
132+ **Priority Order for Priors:**
133+ 1. User-specified priors (passed to `__init__`)
134+ 2. Data-driven priors (from this method)
135+ 3. Default priors (from `default_priors` property)
136+
137+ Examples
138+ --------
139+ A typical implementation might scale priors based on data variance:
140+
141+ >>> def priors_from_data(self, X, y):
142+ ... y_std = float(y.std())
143+ ... return {
144+ ... "sigma": Prior("HalfNormal", sigma=y_std, dims="treated_units"),
145+ ... "beta": Prior(
146+ ... "Normal",
147+ ... mu=0,
148+ ... sigma=2 * y_std,
149+ ... dims=["treated_units", "coeffs"],
150+ ... ),
151+ ... }
152+
153+ Or set shape parameters based on data dimensions:
154+
155+ >>> def priors_from_data(self, X, y):
156+ ... n_predictors = X.shape[1]
157+ ... return {
158+ ... "beta": Prior(
159+ ... "Dirichlet",
160+ ... a=np.ones(n_predictors),
161+ ... dims=["treated_units", "coeffs"],
162+ ... )
163+ ... }
164+
165+ See Also
166+ --------
167+ WeightedSumFitter.priors_from_data : Example implementation that sets
168+ Dirichlet prior shape based on number of control units.
169+ """
99170 return {}
100171
101172 def __init__ (
@@ -160,6 +231,8 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
160231 # sample_posterior_predictive() if provided in sample_kwargs.
161232 random_seed = self .sample_kwargs .get ("random_seed" , None )
162233
234+ # Merge priors with precedence: user-specified > data-driven > defaults
235+ # Data-driven priors are computed first, then user-specified priors override them
163236 self .priors = {** self .priors_from_data (X , y ), ** self .priors }
164237
165238 self .build_model (X , y , coords )
@@ -407,6 +480,27 @@ class WeightedSumFitter(PyMCModel):
407480 }
408481
409482 def priors_from_data (self , X , y ) -> Dict [str , Any ]:
483+ """
484+ Set Dirichlet prior for weights based on number of control units.
485+
486+ For synthetic control models, this method sets the shape parameter of the
487+ Dirichlet prior on the control unit weights (`beta`) to be uniform across
488+ all available control units. This ensures that all control units have
489+ equal prior probability of contributing to the synthetic control.
490+
491+ Parameters
492+ ----------
493+ X : xarray.DataArray
494+ Control unit data with shape (n_obs, n_control_units).
495+ y : xarray.DataArray
496+ Treated unit outcome data.
497+
498+ Returns
499+ -------
500+ Dict[str, Prior]
501+ Dictionary containing:
502+ - "beta": Dirichlet prior with shape=(1,...,1) for n_control_units
503+ """
410504 n_predictors = X .shape [1 ]
411505 return {
412506 "beta" : Prior (
0 commit comments