Skip to content

Commit 0650644

Browse files
committed
add docstrings to the priors_from_data methods
1 parent bcba49f commit 0650644

File tree

2 files changed

+98
-4
lines changed

2 files changed

+98
-4
lines changed

causalpy/pymc_models.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,77 @@ def default_priors(self):
9696
return {}
9797

9898
def priors_from_data(self, X, y) -> Dict[str, Any]:
99+
"""
100+
Generate priors dynamically based on the input data.
101+
102+
This method allows models to set sensible priors that adapt to the scale
103+
and characteristics of the actual data being analyzed. It's called during
104+
the `fit()` method before model building, allowing data-driven prior
105+
specification that can improve model performance and convergence.
106+
107+
The priors returned by this method are merged with any user-specified
108+
priors (passed via the `priors` parameter in `__init__`), with
109+
user-specified priors taking precedence in case of conflicts.
110+
111+
Parameters
112+
----------
113+
X : xarray.DataArray
114+
Input features/covariates with dimensions ["obs_ind", "coeffs"].
115+
Used to understand the scale and structure of predictors.
116+
y : xarray.DataArray
117+
Target variable with dimensions ["obs_ind", "treated_units"].
118+
Used to understand the scale and structure of the outcome.
119+
120+
Returns
121+
-------
122+
Dict[str, Prior]
123+
Dictionary mapping parameter names to Prior objects. The keys should
124+
match parameter names used in the model's `build_model()` method.
125+
126+
Notes
127+
-----
128+
The base implementation returns an empty dictionary, meaning no
129+
data-driven priors are set by default. Subclasses should override
130+
this method to implement data-adaptive prior specification.
131+
132+
**Priority Order for Priors:**
133+
1. User-specified priors (passed to `__init__`)
134+
2. Data-driven priors (from this method)
135+
3. Default priors (from `default_priors` property)
136+
137+
Examples
138+
--------
139+
A typical implementation might scale priors based on data variance:
140+
141+
>>> def priors_from_data(self, X, y):
142+
... y_std = float(y.std())
143+
... return {
144+
... "sigma": Prior("HalfNormal", sigma=y_std, dims="treated_units"),
145+
... "beta": Prior(
146+
... "Normal",
147+
... mu=0,
148+
... sigma=2 * y_std,
149+
... dims=["treated_units", "coeffs"],
150+
... ),
151+
... }
152+
153+
Or set shape parameters based on data dimensions:
154+
155+
>>> def priors_from_data(self, X, y):
156+
... n_predictors = X.shape[1]
157+
... return {
158+
... "beta": Prior(
159+
... "Dirichlet",
160+
... a=np.ones(n_predictors),
161+
... dims=["treated_units", "coeffs"],
162+
... )
163+
... }
164+
165+
See Also
166+
--------
167+
WeightedSumFitter.priors_from_data : Example implementation that sets
168+
Dirichlet prior shape based on number of control units.
169+
"""
99170
return {}
100171

101172
def __init__(
@@ -160,6 +231,8 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
160231
# sample_posterior_predictive() if provided in sample_kwargs.
161232
random_seed = self.sample_kwargs.get("random_seed", None)
162233

234+
# Merge priors with precedence: user-specified > data-driven > defaults
235+
# Data-driven priors are computed first; entries already in self.priors override them (NOTE: this includes priors stored by an earlier fit() call)
163236
self.priors = {**self.priors_from_data(X, y), **self.priors}
164237

165238
self.build_model(X, y, coords)
@@ -407,6 +480,27 @@ class WeightedSumFitter(PyMCModel):
407480
}
408481

409482
def priors_from_data(self, X, y) -> Dict[str, Any]:
483+
"""
484+
Set Dirichlet prior for weights based on number of control units.
485+
486+
For synthetic control models, this method sets the shape parameter of the
487+
Dirichlet prior on the control unit weights (`beta`) to be uniform across
488+
all available control units. This ensures that all control units have
489+
equal prior probability of contributing to the synthetic control.
490+
491+
Parameters
492+
----------
493+
X : xarray.DataArray
494+
Control unit data with shape (n_obs, n_control_units).
495+
y : xarray.DataArray
496+
Treated unit outcome data.
497+
498+
Returns
499+
-------
500+
Dict[str, Prior]
501+
Dictionary containing:
502+
- "beta": Dirichlet prior with concentration parameters (1, ..., 1), one per control unit
503+
"""
410504
n_predictors = X.shape[1]
411505
return {
412506
"beta": Prior(

docs/source/_static/interrogate_badge.svg

Lines changed: 4 additions & 4 deletions
Loading

0 commit comments

Comments
 (0)