Skip to content

Commit a1eea92

Browse files
committed
update to use Prior class from pymc-extras + fix failing doctests
1 parent e10d585 commit a1eea92

File tree

2 files changed

+227
-65
lines changed

2 files changed

+227
-65
lines changed

causalpy/pymc_models.py

Lines changed: 223 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -962,11 +962,28 @@ class TransferFunctionLinearRegression(PyMCModel):
962962
The current implementation uses independent Normal errors. Future versions may
963963
include AR(1) autocorrelation modeling for residuals.
964964
965-
**Priors**: The model uses data-informed priors that scale with the outcome variable:
965+
**Prior Customization**:
966966
967-
- Baseline coefficients: ``Normal(0, 5 * std(y))``
968-
- Treatment coefficients: ``Normal(0, 2 * std(y))`` or ``HalfNormal(2 * std(y))``
969-
- Error std: ``HalfNormal(2 * std(y))``
967+
Priors are managed using the ``Prior`` class from ``pymc_extras`` and can be
968+
customized via the ``priors`` parameter:
969+
970+
>>> from pymc_extras.prior import Prior
971+
>>> model = cp.pymc_models.TransferFunctionLinearRegression(
972+
... saturation_type=None,
973+
... adstock_config={...},
974+
... priors={
975+
... "beta": Prior(
976+
... "Normal", mu=0, sigma=100, dims=["treated_units", "coeffs"]
977+
... ),
978+
... "sigma": Prior("HalfNormal", sigma=50, dims=["treated_units"]),
979+
... },
980+
... )
981+
982+
By default, data-informed priors are set automatically via ``priors_from_data()``:
983+
984+
- Baseline coefficients (``beta``): ``Normal(0, 5 * std(y))``
985+
- Treatment coefficients (``theta_treatment``): ``Normal(0, 2 * std(y))``, or ``HalfNormal(2 * std(y))`` when ``coef_constraint="nonnegative"``
986+
- Error std (``sigma``): ``HalfNormal(2 * std(y))``
970987
971988
This adaptive approach ensures priors are reasonable regardless of data scale.
972989
@@ -1011,6 +1028,78 @@ def __init__(
10111028
self.treatment_names = None
10121029
self.n_treatments = None
10131030

1031+
def priors_from_data(self, X, y) -> Dict[str, Any]:
1032+
"""
1033+
Generate data-informed priors that scale with outcome variable.
1034+
1035+
Computes priors for baseline coefficients, treatment coefficients,
1036+
and error standard deviation based on the standard deviation of y.
1037+
This ensures priors are reasonable regardless of data scale.
1038+
1039+
Parameters
1040+
----------
1041+
X : xr.DataArray
1042+
Baseline design matrix (n_obs, n_baseline_features).
1043+
y : xr.DataArray
1044+
Outcome variable (n_obs, 1).
1045+
1046+
Returns
1047+
-------
1048+
Dict[str, Prior]
1049+
Dictionary with Prior objects for beta, theta_treatment, and sigma.
1050+
1051+
Notes
1052+
-----
1053+
The returned dictionary contains Prior objects with the following structure::
1054+
1055+
{
1056+
"beta": Prior(
1057+
"Normal", mu=0, sigma=5 * y_scale, dims=["treated_units", "coeffs"]
1058+
),
1059+
"theta_treatment": Prior(
1060+
"Normal",
1061+
mu=0,
1062+
sigma=2 * y_scale,
1063+
dims=["treated_units", "treatment_names"],
1064+
),
1065+
"sigma": Prior("HalfNormal", sigma=2 * y_scale, dims=["treated_units"]),
1066+
}
1067+
1068+
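where ``y_scale = std(y)``. When ``coef_constraint == "nonnegative"``, ``theta_treatment`` is instead ``Prior("HalfNormal", sigma=2 * y_scale, dims=["treated_units", "treatment_names"])``.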
1069+
"""
1070+
y_scale = float(np.std(y))
1071+
1072+
priors = {
1073+
"beta": Prior(
1074+
"Normal",
1075+
mu=0,
1076+
sigma=5 * y_scale,
1077+
dims=["treated_units", "coeffs"],
1078+
),
1079+
"sigma": Prior(
1080+
"HalfNormal",
1081+
sigma=2 * y_scale,
1082+
dims=["treated_units"],
1083+
),
1084+
}
1085+
1086+
# Treatment coefficient prior depends on constraint
1087+
if self.coef_constraint == "nonnegative":
1088+
priors["theta_treatment"] = Prior(
1089+
"HalfNormal",
1090+
sigma=2 * y_scale,
1091+
dims=["treated_units", "treatment_names"],
1092+
)
1093+
else:
1094+
priors["theta_treatment"] = Prior(
1095+
"Normal",
1096+
mu=0,
1097+
sigma=2 * y_scale,
1098+
dims=["treated_units", "treatment_names"],
1099+
)
1100+
1101+
return priors
1102+
10141103
def build_model(self, X, y, coords, treatment_data):
10151104
"""
10161105
Build the PyMC model with transforms.
@@ -1039,9 +1128,6 @@ def build_model(self, X, y, coords, treatment_data):
10391128
).values.tolist()
10401129
self.n_treatments = treatment_data.shape[1]
10411130

1042-
# Compute data scale BEFORE entering model context (for data-informed priors)
1043-
y_scale = float(np.std(y))
1044-
10451131
with self:
10461132
self.add_coords(coords)
10471133

@@ -1188,27 +1274,13 @@ def build_model(self, X, y, coords, treatment_data):
11881274
# ==================================================================
11891275
# Regression Coefficients (with data-informed priors)
11901276
# ==================================================================
1191-
# Baseline coefficients: prior std = 5 * outcome scale
1192-
# This allows intercept to range widely while keeping some regularization
1193-
beta = pm.Normal(
1194-
"beta", mu=0, sigma=5 * y_scale, dims=["treated_units", "coeffs"]
1195-
)
1277+
# Baseline coefficients: data-informed priors set via priors_from_data()
1278+
beta = self.priors["beta"].create_variable("beta")
11961279

1197-
# Treatment coefficients: prior std = 2 * outcome scale
1198-
# Treatments typically have smaller effects than baseline level
1199-
if self.coef_constraint == "nonnegative":
1200-
theta_treatment = pm.HalfNormal(
1201-
"theta_treatment",
1202-
sigma=2 * y_scale,
1203-
dims=["treated_units", "treatment_names"],
1204-
)
1205-
else:
1206-
theta_treatment = pm.Normal(
1207-
"theta_treatment",
1208-
mu=0,
1209-
sigma=2 * y_scale,
1210-
dims=["treated_units", "treatment_names"],
1211-
)
1280+
# Treatment coefficients: data-informed priors set via priors_from_data()
1281+
theta_treatment = self.priors["theta_treatment"].create_variable(
1282+
"theta_treatment"
1283+
)
12121284

12131285
# ==================================================================
12141286
# Mean Function
@@ -1229,8 +1301,8 @@ def build_model(self, X, y, coords, treatment_data):
12291301
# ==================================================================
12301302
# Likelihood
12311303
# ==================================================================
1232-
# Error std: prior centered on outcome scale with wide support
1233-
sigma = pm.HalfNormal("sigma", sigma=2 * y_scale, dims=["treated_units"])
1304+
# Error std: data-informed prior set via priors_from_data()
1305+
sigma = self.priors["sigma"].create_variable("sigma")
12341306

12351307
# For now, use independent Normal errors
12361308
# Note: AR(1) errors in regression context require more complex implementation
@@ -1270,6 +1342,10 @@ def fit(self, X, y, coords, treatment_data):
12701342
# sample_posterior_predictive() if provided in sample_kwargs.
12711343
random_seed = self.sample_kwargs.get("random_seed", None)
12721344

1345+
# Merge priors with precedence: user-specified > data-driven > defaults
1346+
# Data-driven priors are computed first, then user-specified priors override them
1347+
self.priors = {**self.priors_from_data(X, y), **self.priors}
1348+
12731349
# Build the model with treatment data
12741350
self.build_model(X, y, coords, treatment_data)
12751351

@@ -1367,12 +1443,31 @@ class TransferFunctionARRegression(PyMCModel):
13671443
- Posterior predictive sampling requires forward simulation of the AR process
13681444
- Convergence can be slower than the independent errors model; consider increasing tune/draws
13691445
1370-
**Priors**: The model uses data-informed priors that scale with the outcome variable:
1446+
**Prior Customization**:
1447+
1448+
Priors are managed using the ``Prior`` class from ``pymc_extras`` and can be
1449+
customized via the ``priors`` parameter:
1450+
1451+
>>> from pymc_extras.prior import Prior
1452+
>>> model = cp.pymc_models.TransferFunctionARRegression(
1453+
... saturation_type=None,
1454+
... adstock_config={...},
1455+
... priors={
1456+
... "beta": Prior(
1457+
... "Normal", mu=0, sigma=100, dims=["treated_units", "coeffs"]
1458+
... ),
1459+
... "rho": Prior(
1460+
... "Uniform", lower=-0.95, upper=0.95, dims=["treated_units"]
1461+
... ),
1462+
... },
1463+
... )
1464+
1465+
By default, data-informed priors are set automatically via ``priors_from_data()``:
13711466
1372-
- Baseline coefficients: ``Normal(0, 5 * std(y))``
1373-
- Treatment coefficients: ``Normal(0, 2 * std(y))`` or ``HalfNormal(2 * std(y))``
1374-
- Error std: ``HalfNormal(2 * std(y))``
1375-
- AR(1) coefficient: ``Uniform(-0.99, 0.99)``
1467+
- Baseline coefficients (``beta``): ``Normal(0, 5 * std(y))``
1468+
- Treatment coefficients (``theta_treatment``): ``Normal(0, 2 * std(y))``, or ``HalfNormal(2 * std(y))`` when ``coef_constraint="nonnegative"``
1469+
- Error std (``sigma``): ``HalfNormal(2 * std(y))``
1470+
- AR(1) coefficient (``rho``): ``Uniform(-0.99, 0.99)`` (fixed bounds; not scaled by the data)
13761471
13771472
This adaptive approach ensures priors are reasonable regardless of data scale.
13781473
@@ -1429,6 +1524,86 @@ def __init__(
14291524
self.treatment_names = None
14301525
self.n_treatments = None
14311526

1527+
def priors_from_data(self, X, y) -> Dict[str, Any]:
1528+
"""
1529+
Generate default priors: data-scaled priors for the coefficients and error std, plus a fixed-range prior for the AR(1) coefficient.
1530+
1531+
Similar to TransferFunctionLinearRegression but also includes
1532+
a prior for the AR(1) coefficient rho.
1533+
1534+
Parameters
1535+
----------
1536+
X : xr.DataArray
1537+
Baseline design matrix.
1538+
y : xr.DataArray
1539+
Outcome variable.
1540+
1541+
Returns
1542+
-------
1543+
Dict[str, Prior]
1544+
Dictionary with Prior objects for beta, theta_treatment, sigma, and rho.
1545+
1546+
Notes
1547+
-----
1548+
The returned dictionary contains Prior objects with the following structure::
1549+
1550+
{
1551+
"beta": Prior(
1552+
"Normal", mu=0, sigma=5 * y_scale, dims=["treated_units", "coeffs"]
1553+
),
1554+
"theta_treatment": Prior(
1555+
"Normal",
1556+
mu=0,
1557+
sigma=2 * y_scale,
1558+
dims=["treated_units", "treatment_names"],
1559+
),
1560+
"sigma": Prior("HalfNormal", sigma=2 * y_scale, dims=["treated_units"]),
1561+
"rho": Prior(
1562+
"Uniform", lower=-0.99, upper=0.99, dims=["treated_units"]
1563+
),
1564+
}
1565+
1566+
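where ``y_scale = std(y)``. When ``coef_constraint == "nonnegative"``, ``theta_treatment`` is instead ``Prior("HalfNormal", sigma=2 * y_scale, dims=["treated_units", "treatment_names"])``.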
1567+
"""
1568+
y_scale = float(np.std(y))
1569+
1570+
priors = {
1571+
"beta": Prior(
1572+
"Normal",
1573+
mu=0,
1574+
sigma=5 * y_scale,
1575+
dims=["treated_units", "coeffs"],
1576+
),
1577+
"sigma": Prior(
1578+
"HalfNormal",
1579+
sigma=2 * y_scale,
1580+
dims=["treated_units"],
1581+
),
1582+
"rho": Prior(
1583+
"Uniform",
1584+
lower=-0.99,
1585+
upper=0.99,
1586+
dims=["treated_units"],
1587+
),
1588+
}
1589+
1590+
# Treatment coefficient prior depends on constraint
1591+
if self.coef_constraint == "nonnegative":
1592+
priors["theta_treatment"] = Prior(
1593+
"HalfNormal",
1594+
sigma=2 * y_scale,
1595+
dims=["treated_units", "treatment_names"],
1596+
)
1597+
else:
1598+
priors["theta_treatment"] = Prior(
1599+
"Normal",
1600+
mu=0,
1601+
sigma=2 * y_scale,
1602+
dims=["treated_units", "treatment_names"],
1603+
)
1604+
1605+
return priors
1606+
14321607
def build_model(self, X, y, coords, treatment_data):
14331608
"""
14341609
Build the PyMC model with transforms and AR(1) errors using quasi-differencing.
@@ -1457,9 +1632,6 @@ def build_model(self, X, y, coords, treatment_data):
14571632
).values.tolist()
14581633
self.n_treatments = treatment_data.shape[1]
14591634

1460-
# Compute data scale BEFORE entering model context (for data-informed priors)
1461-
y_scale = float(np.std(y))
1462-
14631635
with self:
14641636
self.add_coords(coords)
14651637

@@ -1611,27 +1783,13 @@ def build_model(self, X, y, coords, treatment_data):
16111783
# ==================================================================
16121784
# Regression Coefficients (with data-informed priors)
16131785
# ==================================================================
1614-
# Baseline coefficients: prior std = 5 * outcome scale
1615-
# This allows intercept to range widely while keeping some regularization
1616-
beta = pm.Normal(
1617-
"beta", mu=0, sigma=5 * y_scale, dims=["treated_units", "coeffs"]
1618-
)
1786+
# Baseline coefficients: data-informed priors set via priors_from_data()
1787+
beta = self.priors["beta"].create_variable("beta")
16191788

1620-
# Treatment coefficients: prior std = 2 * outcome scale
1621-
# Treatments typically have smaller effects than baseline level
1622-
if self.coef_constraint == "nonnegative":
1623-
theta_treatment = pm.HalfNormal(
1624-
"theta_treatment",
1625-
sigma=2 * y_scale,
1626-
dims=["treated_units", "treatment_names"],
1627-
)
1628-
else:
1629-
theta_treatment = pm.Normal(
1630-
"theta_treatment",
1631-
mu=0,
1632-
sigma=2 * y_scale,
1633-
dims=["treated_units", "treatment_names"],
1634-
)
1789+
# Treatment coefficients: data-informed priors set via priors_from_data()
1790+
theta_treatment = self.priors["theta_treatment"].create_variable(
1791+
"theta_treatment"
1792+
)
16351793

16361794
# ==================================================================
16371795
# Mean Function
@@ -1652,11 +1810,11 @@ def build_model(self, X, y, coords, treatment_data):
16521810
# ==================================================================
16531811
# AR(1) Likelihood via Quasi-Differencing
16541812
# ==================================================================
1655-
# AR(1) parameter: rho (constrained to ensure stationarity)
1656-
rho = pm.Uniform("rho", lower=-0.99, upper=0.99, dims=["treated_units"])
1813+
# AR(1) parameter: prior taken from self.priors; the priors_from_data() default is a fixed Uniform(-0.99, 0.99) (keeps |rho| < 1), not data-scaled
1814+
rho = self.priors["rho"].create_variable("rho")
16571815

1658-
# Innovation standard deviation: prior centered on outcome scale
1659-
sigma = pm.HalfNormal("sigma", sigma=2 * y_scale, dims=["treated_units"])
1816+
# Innovation standard deviation: data-informed prior set via priors_from_data()
1817+
sigma = self.priors["sigma"].create_variable("sigma")
16601818

16611819
# Quasi-differencing approach using manual log-likelihood
16621820
# We can't use y_diff as observed data because it depends on rho
@@ -1717,6 +1875,10 @@ def fit(self, X, y, coords, treatment_data):
17171875
# sample_posterior_predictive() if provided in sample_kwargs.
17181876
random_seed = self.sample_kwargs.get("random_seed", None)
17191877

1878+
# Merge priors with precedence: user-specified > data-driven > defaults
1879+
# Data-driven priors are computed first, then user-specified priors override them
1880+
self.priors = {**self.priors_from_data(X, y), **self.priors}
1881+
17201882
# Build the model with treatment data
17211883
self.build_model(X, y, coords, treatment_data)
17221884

0 commit comments

Comments
 (0)