Commit cdb4986

Merge branch 'main' into cetagostini/adding_bsts_to_causalpy

2 parents a2e5e40 + 673e805

File tree

12 files changed: +748 −132 lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.10", "3.11", "3.12"]
+        python-version: ["3.11", "3.12", "3.13"]

     steps:
       - uses: actions/checkout@v4

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions

@@ -13,7 +13,7 @@ repos:
       - --exclude=docs/
       - --exclude=scripts/
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v5.0.0
+    rev: v6.0.0
     hooks:
       - id: debug-statements
      - id: trailing-whitespace
@@ -25,7 +25,7 @@ repos:
         exclude: &exclude_pattern 'iv_weak_instruments.ipynb'
         args: ["--maxkb=1500"]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.12.5
+    rev: v0.13.2
     hooks:
       # Run the linter
       - id: ruff

causalpy/experiments/prepostnegd.py

Lines changed: 1 addition & 1 deletion

@@ -82,7 +82,7 @@ class PrePostNEGD(BaseExperiment):
     Intercept     -0.5, 94% HDI [-1, 0.2]
     C(group)[T.1] 2, 94% HDI [2, 2]
     pre           1, 94% HDI [1, 1]
-    sigma         0.5, 94% HDI [0.5, 0.6]
+    y_hat_sigma   0.5, 94% HDI [0.5, 0.6]
     """

     supports_ols = False

causalpy/pymc_models.py

Lines changed: 162 additions & 25 deletions

@@ -23,6 +23,7 @@
 import xarray as xr
 from arviz import r2_score
 from patsy import dmatrix
+from pymc_extras.prior import Prior

 from causalpy.utils import round_num
@@ -90,7 +91,87 @@ class PyMCModel(pm.Model):
     Inference data...
     """

-    def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
+    default_priors = {}
+
+    def priors_from_data(self, X, y) -> Dict[str, Any]:
+        """
+        Generate priors dynamically based on the input data.
+
+        This method allows models to set sensible priors that adapt to the scale
+        and characteristics of the actual data being analyzed. It is called during
+        the `fit()` method before model building, allowing data-driven prior
+        specification that can improve model performance and convergence.
+
+        The priors returned by this method are merged with any user-specified
+        priors (passed via the `priors` parameter in `__init__`), with
+        user-specified priors taking precedence in case of conflicts.
+
+        Parameters
+        ----------
+        X : xarray.DataArray
+            Input features/covariates with dimensions ["obs_ind", "coeffs"].
+            Used to understand the scale and structure of predictors.
+        y : xarray.DataArray
+            Target variable with dimensions ["obs_ind", "treated_units"].
+            Used to understand the scale and structure of the outcome.
+
+        Returns
+        -------
+        Dict[str, Prior]
+            Dictionary mapping parameter names to Prior objects. The keys should
+            match parameter names used in the model's `build_model()` method.
+
+        Notes
+        -----
+        The base implementation returns an empty dictionary, meaning no
+        data-driven priors are set by default. Subclasses should override
+        this method to implement data-adaptive prior specification.
+
+        **Priority order for priors:**
+
+        1. User-specified priors (passed to `__init__`)
+        2. Data-driven priors (from this method)
+        3. Default priors (from the `default_priors` class attribute)
+
+        Examples
+        --------
+        A typical implementation might scale priors based on data variance:
+
+        >>> def priors_from_data(self, X, y):
+        ...     y_std = float(y.std())
+        ...     return {
+        ...         "sigma": Prior("HalfNormal", sigma=y_std, dims="treated_units"),
+        ...         "beta": Prior(
+        ...             "Normal",
+        ...             mu=0,
+        ...             sigma=2 * y_std,
+        ...             dims=["treated_units", "coeffs"],
+        ...         ),
+        ...     }
+
+        Or set shape parameters based on data dimensions:
+
+        >>> def priors_from_data(self, X, y):
+        ...     n_predictors = X.shape[1]
+        ...     return {
+        ...         "beta": Prior(
+        ...             "Dirichlet",
+        ...             a=np.ones(n_predictors),
+        ...             dims=["treated_units", "coeffs"],
+        ...         )
+        ...     }
+
+        See Also
+        --------
+        WeightedSumFitter.priors_from_data : Example implementation that sets
+            the Dirichlet prior shape based on the number of control units.
+        """
+        return {}
+
+    def __init__(
+        self,
+        sample_kwargs: Optional[Dict[str, Any]] = None,
+        priors: dict[str, Any] | None = None,
+    ):
         """
         :param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the
             :func:`pymc.sample` function. Defaults to an empty dictionary.
@@ -99,9 +180,13 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
         self.idata = None
         self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {}

+        self.priors = {**self.default_priors, **(priors or {})}
+
     def build_model(self, X, y, coords) -> None:
         """Build the model, must be implemented by subclass."""
-        raise NotImplementedError("This method must be implemented by a subclass")
+        raise NotImplementedError(
+            "This method must be implemented by a subclass"
+        )  # pragma: no cover

     def _data_setter(self, X: xr.DataArray) -> None:
         """
@@ -144,6 +229,10 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
         # sample_posterior_predictive() if provided in sample_kwargs.
         random_seed = self.sample_kwargs.get("random_seed", None)

+        # Merge priors with precedence: user-specified > data-driven > defaults
+        # Data-driven priors are computed first, then user-specified priors override them
+        self.priors = {**self.priors_from_data(X, y), **self.priors}
+
         self.build_model(X, y, coords)
         with self:
             self.idata = pm.sample(**self.sample_kwargs)
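The two merge expressions above (in `__init__` and `fit()`) fix the precedence of the three prior sources. A minimal sketch of the same dict-merge semantics in plain Python (placeholder strings stand in for real Prior objects):

# Later entries in a dict merge win on key collisions, so priors resolve
# as: defaults < data-driven < user-specified.
default_priors = {"beta": "default", "y_hat": "default"}  # class attribute
user_priors = {"beta": "user"}                            # passed to __init__
data_priors = {"beta": "data", "sigma": "data"}           # from priors_from_data()

# __init__: user-specified priors override the class defaults
priors = {**default_priors, **user_priors}
# fit(): data-driven priors fill gaps but never override what is already set
priors = {**data_priors, **priors}

assert priors == {"beta": "user", "y_hat": "default", "sigma": "data"}

Because `{**a, **b}` lets `b` win, user-specified priors survive both merges, while data-driven priors only fill keys the user left unset.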
@@ -260,26 +349,36 @@ def print_coefficients_for_unit(
         ) -> None:
             """Print coefficients for a single unit"""
             # Determine the width of the longest label
-            max_label_length = max(len(name) for name in labels + ["sigma"])
+            max_label_length = max(len(name) for name in labels + ["y_hat_sigma"])

             for name in labels:
                 coeff_samples = unit_coeffs.sel(coeffs=name)
                 print_row(max_label_length, name, coeff_samples, round_to)

             # Add coefficient for measurement std
-            print_row(max_label_length, "sigma", unit_sigma, round_to)
+            print_row(max_label_length, "y_hat_sigma", unit_sigma, round_to)

         print("Model coefficients:")
         coeffs = az.extract(self.idata.posterior, var_names="beta")

-        # Always has treated_units dimension - no branching needed!
+        # Check if sigma or y_hat_sigma variable exists
+        sigma_var_name = None
+        if "sigma" in self.idata.posterior:
+            sigma_var_name = "sigma"
+        elif "y_hat_sigma" in self.idata.posterior:
+            sigma_var_name = "y_hat_sigma"
+        else:
+            raise ValueError(
+                "Neither 'sigma' nor 'y_hat_sigma' found in posterior"
+            )  # pragma: no cover
+
         treated_units = coeffs.coords["treated_units"].values
         for unit in treated_units:
             if len(treated_units) > 1:
                 print(f"\nTreated unit: {unit}")

             unit_coeffs = coeffs.sel(treated_units=unit)
-            unit_sigma = az.extract(self.idata.posterior, var_names="sigma").sel(
+            unit_sigma = az.extract(self.idata.posterior, var_names=sigma_var_name).sel(
                 treated_units=unit
             )
             print_coefficients_for_unit(unit_coeffs, unit_sigma, labels, round_to or 2)
@@ -322,6 +421,15 @@ class LinearRegression(PyMCModel):
     Inference data...
     """  # noqa: W605

+    default_priors = {
+        "beta": Prior("Normal", mu=0, sigma=50, dims=["treated_units", "coeffs"]),
+        "y_hat": Prior(
+            "Normal",
+            sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
+            dims=["obs_ind", "treated_units"],
+        ),
+    }
+
     def build_model(self, X, y, coords):
         """
         Defines the PyMC model
@@ -335,12 +443,11 @@ def build_model(self, X, y, coords):
             self.add_coords(coords)
             X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             y = pm.Data("y", y, dims=["obs_ind", "treated_units"])
-            beta = pm.Normal("beta", 0, 50, dims=["treated_units", "coeffs"])
-            sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
+            beta = self.priors["beta"].create_variable("beta")
             mu = pm.Deterministic(
                 "mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"]
             )
-            pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"])
+            self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)


 class WeightedSumFitter(PyMCModel):
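With `default_priors` and the `Prior`-based `build_model`, the slope and likelihood priors become user-swappable without subclassing. A hedged usage sketch (the `sigma=5` value is illustrative, not from this commit):

from pymc_extras.prior import Prior
from causalpy.pymc_models import LinearRegression

# Hypothetical override: tighten the default Normal(0, 50) slope prior.
# "y_hat" is not specified, so it keeps its entry from `default_priors`.
model = LinearRegression(
    priors={
        "beta": Prior("Normal", mu=0, sigma=5, dims=["treated_units", "coeffs"]),
    },
)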
@@ -383,23 +490,56 @@ class WeightedSumFitter(PyMCModel):
     Inference data...
     """  # noqa: W605

+    default_priors = {
+        "y_hat": Prior(
+            "Normal",
+            sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
+            dims=["obs_ind", "treated_units"],
+        ),
+    }
+
+    def priors_from_data(self, X, y) -> Dict[str, Any]:
+        """
+        Set Dirichlet prior for weights based on number of control units.
+
+        For synthetic control models, this method sets the shape parameter of the
+        Dirichlet prior on the control unit weights (`beta`) to be uniform across
+        all available control units. This ensures that all control units have
+        equal prior probability of contributing to the synthetic control.
+
+        Parameters
+        ----------
+        X : xarray.DataArray
+            Control unit data with shape (n_obs, n_control_units).
+        y : xarray.DataArray
+            Treated unit outcome data.
+
+        Returns
+        -------
+        Dict[str, Prior]
+            Dictionary containing:
+            - "beta": Dirichlet prior with shape=(1,...,1) for n_control_units
+        """
+        n_predictors = X.shape[1]
+        return {
+            "beta": Prior(
+                "Dirichlet", a=np.ones(n_predictors), dims=["treated_units", "coeffs"]
+            ),
+        }
+
     def build_model(self, X, y, coords):
         """
         Defines the PyMC model
         """
         with self:
             self.add_coords(coords)
-            n_predictors = X.sizes["coeffs"]
             X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             y = pm.Data("y", y, dims=["obs_ind", "treated_units"])
-            beta = pm.Dirichlet(
-                "beta", a=np.ones(n_predictors), dims=["treated_units", "coeffs"]
-            )
-            sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
+            beta = self.priors["beta"].create_variable("beta")
             mu = pm.Deterministic(
                 "mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"]
             )
-            pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"])
+            self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)


 class InstrumentalVariableRegression(PyMCModel):
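The new `priors_from_data` only needs the number of control-unit columns, so its behaviour can be checked without sampling. A small sketch with simulated xarray data (names are illustrative, not from the commit):

import numpy as np
import xarray as xr
from causalpy.pymc_models import WeightedSumFitter

# Three control units -> uniform Dirichlet(a=[1, 1, 1]) over the weights.
X = xr.DataArray(
    np.random.rand(50, 3),
    dims=["obs_ind", "coeffs"],
    coords={"coeffs": ["control_a", "control_b", "control_c"]},
)
y = xr.DataArray(
    np.random.rand(50, 1),
    dims=["obs_ind", "treated_units"],
    coords={"treated_units": ["treated"]},
)

wsf = WeightedSumFitter()
beta_prior = wsf.priors_from_data(X, y)["beta"]  # Dirichlet with a = np.ones(3)

During `fit()`, this data-driven "beta" is merged behind any user-specified "beta", per the precedence rules in `PyMCModel.fit()`.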
@@ -589,21 +729,18 @@ class PropensityScore(PyMCModel):
     Inference...
     """  # noqa: W605

-    def build_model(self, X, t, coords, prior, noncentred):
+    default_priors = {
+        "b": Prior("Normal", mu=0, sigma=1, dims="coeffs"),
+    }
+
+    def build_model(self, X, t, coords, prior=None, noncentred=True):
         "Defines the PyMC propensity model"
         with self:
             self.add_coords(coords)
             X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             t_data = pm.Data("t", t.flatten(), dims="obs_ind")
-            if noncentred:
-                mu_beta, sigma_beta = prior["b"]
-                beta_std = pm.Normal("beta_std", 0, 1, dims="coeffs")
-                b = pm.Deterministic(
-                    "beta_", mu_beta + sigma_beta * beta_std, dims="coeffs"
-                )
-            else:
-                b = pm.Normal("b", mu=prior["b"][0], sigma=prior["b"][1], dims="coeffs")
-            mu = pm.math.dot(X_data, b)
+            b = self.priors["b"].create_variable("b")
+            mu = pt.dot(X_data, b)
             p = pm.Deterministic("p", pm.math.invlogit(mu))
             pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind")
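With the non-centred branch removed, custom coefficient priors for PropensityScore go through the shared `priors` mechanism rather than the old `prior` dict. A hedged sketch (the `sigma=2` value is illustrative):

from pymc_extras.prior import Prior
from causalpy.pymc_models import PropensityScore

# Hypothetical override of the default Normal(0, 1) coefficient prior.
ps = PropensityScore(priors={"b": Prior("Normal", mu=0, sigma=2, dims="coeffs")})

Note that `prior` and `noncentred` remain in the signature with defaults but are no longer consulted inside `build_model`.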
