@@ -73,6 +73,9 @@ class PyMCModel(pm.Model):
7373 def default_priors (self ):
7474 return {}
7575
76+ def priors_from_data (self , X , y ) -> Dict [str , Any ]:
77+ return {}
78+
7679 def __init__ (
7780 self ,
7881 sample_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -122,6 +125,8 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
122125 # sample_posterior_predictive() if provided in sample_kwargs.
123126 random_seed = self .sample_kwargs .get ("random_seed" , None )
124127
128+ self .priors = {** self .priors_from_data (X , y ), ** self .priors }
129+
125130 self .build_model (X , y , coords )
126131 with self :
127132 self .idata = pm .sample (** self .sample_kwargs )
@@ -295,16 +300,22 @@ class WeightedSumFitter(PyMCModel):
295300 "y_hat" : Prior ("Normal" , sigma = Prior ("HalfNormal" , sigma = 1 ), dims = "obs_ind" ),
296301 }
297302
303+ def priors_from_data (self , X , y ) -> Dict [str , Any ]:
304+ n_predictors = X .shape [1 ]
305+
306+ return {
307+ "beta" : Prior ("Dirichlet" , a = np .ones (n_predictors ), dims = "coeffs" ),
308+ }
309+
298310 def build_model (self , X , y , coords ):
299311 """
300312 Defines the PyMC model
301313 """
302314 with self :
303315 self .add_coords (coords )
304- n_predictors = X .shape [1 ]
305316 X = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
306317 y = pm .Data ("y" , y [:, 0 ], dims = "obs_ind" )
307- beta = pm . Dirichlet ( "beta" , a = np . ones ( n_predictors ), dims = "coeffs " )
318+ beta = self . priors [ "beta" ]. create_variable ( "beta " )
308319 mu = pm .Deterministic ("mu" , pm .math .dot (X , beta ), dims = "obs_ind" )
309320 self .priors ["y_hat" ].create_likelihood_variable ("y_hat" , mu = mu , observed = y )
310321
0 commit comments