@@ -962,11 +962,28 @@ class TransferFunctionLinearRegression(PyMCModel):
962962 The current implementation uses independent Normal errors. Future versions may
963963 include AR(1) autocorrelation modeling for residuals.
964964
965- **Priors**: The model uses data-informed priors that scale with the outcome variable :
965+ **Prior Customization** :
966966
967- - Baseline coefficients: ``Normal(0, 5 * std(y))``
968- - Treatment coefficients: ``Normal(0, 2 * std(y))`` or ``HalfNormal(2 * std(y))``
969- - Error std: ``HalfNormal(2 * std(y))``
967+ Priors are managed using the ``Prior`` class from ``pymc_extras`` and can be
968+ customized via the ``priors`` parameter:
969+
970+ >>> from pymc_extras.prior import Prior
971+ >>> model = cp.pymc_models.TransferFunctionLinearRegression(
972+ ... saturation_type=None,
973+ ... adstock_config={...},
974+ ... priors={
975+ ... "beta": Prior(
976+ ... "Normal", mu=0, sigma=100, dims=["treated_units", "coeffs"]
977+ ... ),
978+ ... "sigma": Prior("HalfNormal", sigma=50, dims=["treated_units"]),
979+ ... },
980+ ... )
981+
982+ By default, data-informed priors are set automatically via ``priors_from_data()``:
983+
984+ - Baseline coefficients (``beta``): ``Normal(0, 5 * std(y))``
985+ - Treatment coefficients (``theta_treatment``): ``Normal(0, 2 * std(y))`` or ``HalfNormal(2 * std(y))``
986+ - Error std (``sigma``): ``HalfNormal(2 * std(y))``
970987
971988 This adaptive approach ensures priors are reasonable regardless of data scale.
972989
@@ -1011,6 +1028,78 @@ def __init__(
10111028 self .treatment_names = None
10121029 self .n_treatments = None
10131030
1031+ def priors_from_data (self , X , y ) -> Dict [str , Any ]:
1032+ """
1033+ Generate data-informed priors that scale with outcome variable.
1034+
1035+ Computes priors for baseline coefficients, treatment coefficients,
1036+ and error standard deviation based on the standard deviation of y.
1037+ This ensures priors are reasonable regardless of data scale.
1038+
1039+ Parameters
1040+ ----------
1041+ X : xr.DataArray
1042+ Baseline design matrix (n_obs, n_baseline_features).
1043+ y : xr.DataArray
1044+ Outcome variable (n_obs, 1).
1045+
1046+ Returns
1047+ -------
1048+ Dict[str, Prior]
1049+ Dictionary with Prior objects for beta, theta_treatment, and sigma.
1050+
1051+ Notes
1052+ -----
1053+ The returned dictionary contains Prior objects with the following structure::
1054+
1055+ {
1056+ "beta": Prior(
1057+ "Normal", mu=0, sigma=5 * y_scale, dims=["treated_units", "coeffs"]
1058+ ),
1059+ "theta_treatment": Prior(
1060+ "Normal",
1061+ mu=0,
1062+ sigma=2 * y_scale,
1063+ dims=["treated_units", "treatment_names"],
1064+ ),
1065+ "sigma": Prior("HalfNormal", sigma=2 * y_scale, dims=["treated_units"]),
1066+ }
1067+
1068+ where ``y_scale = std(y)``.
1069+ """
1070+ y_scale = float (np .std (y ))
1071+
1072+ priors = {
1073+ "beta" : Prior (
1074+ "Normal" ,
1075+ mu = 0 ,
1076+ sigma = 5 * y_scale ,
1077+ dims = ["treated_units" , "coeffs" ],
1078+ ),
1079+ "sigma" : Prior (
1080+ "HalfNormal" ,
1081+ sigma = 2 * y_scale ,
1082+ dims = ["treated_units" ],
1083+ ),
1084+ }
1085+
1086+ # Treatment coefficient prior depends on constraint
1087+ if self .coef_constraint == "nonnegative" :
1088+ priors ["theta_treatment" ] = Prior (
1089+ "HalfNormal" ,
1090+ sigma = 2 * y_scale ,
1091+ dims = ["treated_units" , "treatment_names" ],
1092+ )
1093+ else :
1094+ priors ["theta_treatment" ] = Prior (
1095+ "Normal" ,
1096+ mu = 0 ,
1097+ sigma = 2 * y_scale ,
1098+ dims = ["treated_units" , "treatment_names" ],
1099+ )
1100+
1101+ return priors
1102+
10141103 def build_model (self , X , y , coords , treatment_data ):
10151104 """
10161105 Build the PyMC model with transforms.
@@ -1039,9 +1128,6 @@ def build_model(self, X, y, coords, treatment_data):
10391128 ).values .tolist ()
10401129 self .n_treatments = treatment_data .shape [1 ]
10411130
1042- # Compute data scale BEFORE entering model context (for data-informed priors)
1043- y_scale = float (np .std (y ))
1044-
10451131 with self :
10461132 self .add_coords (coords )
10471133
@@ -1188,27 +1274,13 @@ def build_model(self, X, y, coords, treatment_data):
11881274 # ==================================================================
11891275 # Regression Coefficients (with data-informed priors)
11901276 # ==================================================================
1191- # Baseline coefficients: prior std = 5 * outcome scale
1192- # This allows intercept to range widely while keeping some regularization
1193- beta = pm .Normal (
1194- "beta" , mu = 0 , sigma = 5 * y_scale , dims = ["treated_units" , "coeffs" ]
1195- )
1277+ # Baseline coefficients: data-informed priors set via priors_from_data()
1278+ beta = self .priors ["beta" ].create_variable ("beta" )
11961279
1197- # Treatment coefficients: prior std = 2 * outcome scale
1198- # Treatments typically have smaller effects than baseline level
1199- if self .coef_constraint == "nonnegative" :
1200- theta_treatment = pm .HalfNormal (
1201- "theta_treatment" ,
1202- sigma = 2 * y_scale ,
1203- dims = ["treated_units" , "treatment_names" ],
1204- )
1205- else :
1206- theta_treatment = pm .Normal (
1207- "theta_treatment" ,
1208- mu = 0 ,
1209- sigma = 2 * y_scale ,
1210- dims = ["treated_units" , "treatment_names" ],
1211- )
1280+ # Treatment coefficients: data-informed priors set via priors_from_data()
1281+ theta_treatment = self .priors ["theta_treatment" ].create_variable (
1282+ "theta_treatment"
1283+ )
12121284
12131285 # ==================================================================
12141286 # Mean Function
@@ -1229,8 +1301,8 @@ def build_model(self, X, y, coords, treatment_data):
12291301 # ==================================================================
12301302 # Likelihood
12311303 # ==================================================================
1232- # Error std: prior centered on outcome scale with wide support
1233- sigma = pm . HalfNormal ( "sigma" , sigma = 2 * y_scale , dims = [ "treated_units" ] )
1304+ # Error std: data-informed prior set via priors_from_data()
1305+ sigma = self . priors [ "sigma" ]. create_variable ( " sigma" )
12341306
12351307 # For now, use independent Normal errors
12361308 # Note: AR(1) errors in regression context require more complex implementation
@@ -1270,6 +1342,10 @@ def fit(self, X, y, coords, treatment_data):
12701342 # sample_posterior_predictive() if provided in sample_kwargs.
12711343 random_seed = self .sample_kwargs .get ("random_seed" , None )
12721344
1345+ # Merge priors with precedence: user-specified > data-driven > defaults
1346+ # Data-driven priors are computed first, then user-specified priors override them
1347+ self .priors = {** self .priors_from_data (X , y ), ** self .priors }
1348+
12731349 # Build the model with treatment data
12741350 self .build_model (X , y , coords , treatment_data )
12751351
@@ -1367,12 +1443,31 @@ class TransferFunctionARRegression(PyMCModel):
13671443 - Posterior predictive sampling requires forward simulation of the AR process
13681444 - Convergence can be slower than the independent errors model; consider increasing tune/draws
13691445
1370- **Priors**: The model uses data-informed priors that scale with the outcome variable:
1446+ **Prior Customization**:
1447+
1448+ Priors are managed using the ``Prior`` class from ``pymc_extras`` and can be
1449+ customized via the ``priors`` parameter:
1450+
1451+ >>> from pymc_extras.prior import Prior
1452+ >>> model = cp.pymc_models.TransferFunctionARRegression(
1453+ ... saturation_type=None,
1454+ ... adstock_config={...},
1455+ ... priors={
1456+ ... "beta": Prior(
1457+ ... "Normal", mu=0, sigma=100, dims=["treated_units", "coeffs"]
1458+ ... ),
1459+ ... "rho": Prior(
1460+ ... "Uniform", lower=-0.95, upper=0.95, dims=["treated_units"]
1461+ ... ),
1462+ ... },
1463+ ... )
1464+
1465+ By default, data-informed priors are set automatically via ``priors_from_data()``:
13711466
1372- - Baseline coefficients: ``Normal(0, 5 * std(y))``
1373- - Treatment coefficients: ``Normal(0, 2 * std(y))`` or ``HalfNormal(2 * std(y))``
1374- - Error std: ``HalfNormal(2 * std(y))``
1375- - AR(1) coefficient: ``Uniform(-0.99, 0.99)``
1467+ - Baseline coefficients (``beta``) : ``Normal(0, 5 * std(y))``
1468+ - Treatment coefficients (``theta_treatment``) : ``Normal(0, 2 * std(y))`` or ``HalfNormal(2 * std(y))``
1469+ - Error std (``sigma``) : ``HalfNormal(2 * std(y))``
1470+ - AR(1) coefficient (``rho``) : ``Uniform(-0.99, 0.99)``
13761471
13771472 This adaptive approach ensures priors are reasonable regardless of data scale.
13781473
@@ -1429,6 +1524,86 @@ def __init__(
14291524 self .treatment_names = None
14301525 self .n_treatments = None
14311526
1527+ def priors_from_data (self , X , y ) -> Dict [str , Any ]:
1528+ """
1529+ Generate data-informed priors including AR(1) coefficient.
1530+
1531+ Similar to TransferFunctionLinearRegression but also includes
1532+ a prior for the AR(1) coefficient rho.
1533+
1534+ Parameters
1535+ ----------
1536+ X : xr.DataArray
1537+ Baseline design matrix.
1538+ y : xr.DataArray
1539+ Outcome variable.
1540+
1541+ Returns
1542+ -------
1543+ Dict[str, Prior]
1544+ Dictionary with Prior objects for beta, theta_treatment, sigma, and rho.
1545+
1546+ Notes
1547+ -----
1548+ The returned dictionary contains Prior objects with the following structure::
1549+
1550+ {
1551+ "beta": Prior(
1552+ "Normal", mu=0, sigma=5 * y_scale, dims=["treated_units", "coeffs"]
1553+ ),
1554+ "theta_treatment": Prior(
1555+ "Normal",
1556+ mu=0,
1557+ sigma=2 * y_scale,
1558+ dims=["treated_units", "treatment_names"],
1559+ ),
1560+ "sigma": Prior("HalfNormal", sigma=2 * y_scale, dims=["treated_units"]),
1561+ "rho": Prior(
1562+ "Uniform", lower=-0.99, upper=0.99, dims=["treated_units"]
1563+ ),
1564+ }
1565+
1566+ where ``y_scale = std(y)``.
1567+ """
1568+ y_scale = float (np .std (y ))
1569+
1570+ priors = {
1571+ "beta" : Prior (
1572+ "Normal" ,
1573+ mu = 0 ,
1574+ sigma = 5 * y_scale ,
1575+ dims = ["treated_units" , "coeffs" ],
1576+ ),
1577+ "sigma" : Prior (
1578+ "HalfNormal" ,
1579+ sigma = 2 * y_scale ,
1580+ dims = ["treated_units" ],
1581+ ),
1582+ "rho" : Prior (
1583+ "Uniform" ,
1584+ lower = - 0.99 ,
1585+ upper = 0.99 ,
1586+ dims = ["treated_units" ],
1587+ ),
1588+ }
1589+
1590+ # Treatment coefficient prior depends on constraint
1591+ if self .coef_constraint == "nonnegative" :
1592+ priors ["theta_treatment" ] = Prior (
1593+ "HalfNormal" ,
1594+ sigma = 2 * y_scale ,
1595+ dims = ["treated_units" , "treatment_names" ],
1596+ )
1597+ else :
1598+ priors ["theta_treatment" ] = Prior (
1599+ "Normal" ,
1600+ mu = 0 ,
1601+ sigma = 2 * y_scale ,
1602+ dims = ["treated_units" , "treatment_names" ],
1603+ )
1604+
1605+ return priors
1606+
14321607 def build_model (self , X , y , coords , treatment_data ):
14331608 """
14341609 Build the PyMC model with transforms and AR(1) errors using quasi-differencing.
@@ -1457,9 +1632,6 @@ def build_model(self, X, y, coords, treatment_data):
14571632 ).values .tolist ()
14581633 self .n_treatments = treatment_data .shape [1 ]
14591634
1460- # Compute data scale BEFORE entering model context (for data-informed priors)
1461- y_scale = float (np .std (y ))
1462-
14631635 with self :
14641636 self .add_coords (coords )
14651637
@@ -1611,27 +1783,13 @@ def build_model(self, X, y, coords, treatment_data):
16111783 # ==================================================================
16121784 # Regression Coefficients (with data-informed priors)
16131785 # ==================================================================
1614- # Baseline coefficients: prior std = 5 * outcome scale
1615- # This allows intercept to range widely while keeping some regularization
1616- beta = pm .Normal (
1617- "beta" , mu = 0 , sigma = 5 * y_scale , dims = ["treated_units" , "coeffs" ]
1618- )
1786+ # Baseline coefficients: data-informed priors set via priors_from_data()
1787+ beta = self .priors ["beta" ].create_variable ("beta" )
16191788
1620- # Treatment coefficients: prior std = 2 * outcome scale
1621- # Treatments typically have smaller effects than baseline level
1622- if self .coef_constraint == "nonnegative" :
1623- theta_treatment = pm .HalfNormal (
1624- "theta_treatment" ,
1625- sigma = 2 * y_scale ,
1626- dims = ["treated_units" , "treatment_names" ],
1627- )
1628- else :
1629- theta_treatment = pm .Normal (
1630- "theta_treatment" ,
1631- mu = 0 ,
1632- sigma = 2 * y_scale ,
1633- dims = ["treated_units" , "treatment_names" ],
1634- )
1789+ # Treatment coefficients: data-informed priors set via priors_from_data()
1790+ theta_treatment = self .priors ["theta_treatment" ].create_variable (
1791+ "theta_treatment"
1792+ )
16351793
16361794 # ==================================================================
16371795 # Mean Function
@@ -1652,11 +1810,11 @@ def build_model(self, X, y, coords, treatment_data):
16521810 # ==================================================================
16531811 # AR(1) Likelihood via Quasi-Differencing
16541812 # ==================================================================
1655- # AR(1) parameter: rho (constrained to ensure stationarity )
1656- rho = pm . Uniform ( "rho" , lower = - 0.99 , upper = 0.99 , dims = [ "treated_units" ] )
1813+ # AR(1) parameter: data-informed prior set via priors_from_data( )
1814+ rho = self . priors [ "rho" ]. create_variable ( "rho" )
16571815
1658- # Innovation standard deviation: prior centered on outcome scale
1659- sigma = pm . HalfNormal ( "sigma" , sigma = 2 * y_scale , dims = [ "treated_units" ] )
1816+ # Innovation standard deviation: data-informed prior set via priors_from_data()
1817+ sigma = self . priors [ "sigma" ]. create_variable ( " sigma" )
16601818
16611819 # Quasi-differencing approach using manual log-likelihood
16621820 # We can't use y_diff as observed data because it depends on rho
@@ -1717,6 +1875,10 @@ def fit(self, X, y, coords, treatment_data):
17171875 # sample_posterior_predictive() if provided in sample_kwargs.
17181876 random_seed = self .sample_kwargs .get ("random_seed" , None )
17191877
1878+ # Merge priors with precedence: user-specified > data-driven > defaults
1879+ # Data-driven priors are computed first, then user-specified priors override them
1880+ self .priors = {** self .priors_from_data (X , y ), ** self .priors }
1881+
17201882 # Build the model with treatment data
17211883 self .build_model (X , y , coords , treatment_data )
17221884
0 commit comments