@@ -95,6 +95,9 @@ class PyMCModel(pm.Model):
def default_priors(self):
    """Return the priors this model uses when none are supplied.

    The base implementation declares no priors and hands back a new,
    empty dict on every call.
    """
    return dict()
9797
def priors_from_data(self, X, y) -> Dict[str, Any]:
    """Return priors whose definition depends on the data being fitted.

    The base implementation ignores ``X`` and ``y`` and contributes no
    priors: it returns a new, empty dict on every call.
    """
    return dict()
100+
98101 def __init__ (
99102 self ,
100103 sample_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -155,6 +158,8 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
155158 # sample_posterior_predictive() if provided in sample_kwargs.
156159 random_seed = self .sample_kwargs .get ("random_seed" , None )
157160
161+ self .priors = {** self .priors_from_data (X , y ), ** self .priors }
162+
158163 self .build_model (X , y , coords )
159164 with self :
160165 self .idata = pm .sample (** self .sample_kwargs )
@@ -250,26 +255,34 @@ def print_coefficients_for_unit(
250255 ) -> None :
251256 """Print coefficients for a single unit"""
252257 # Determine the width of the longest label
253- max_label_length = max (len (name ) for name in labels + ["sigma " ])
258+ max_label_length = max (len (name ) for name in labels + ["y_hat_sigma " ])
254259
255260 for name in labels :
256261 coeff_samples = unit_coeffs .sel (coeffs = name )
257262 print_row (max_label_length , name , coeff_samples , round_to )
258263
259264 # Add coefficient for measurement std
260- print_row (max_label_length , "sigma " , unit_sigma , round_to )
265+ print_row (max_label_length , "y_hat_sigma " , unit_sigma , round_to )
261266
262267 print ("Model coefficients:" )
263268 coeffs = az .extract (self .idata .posterior , var_names = "beta" )
264269
265- # Always has treated_units dimension - no branching needed!
270+ # Check if sigma or y_hat_sigma variable exists
271+ sigma_var_name = None
272+ if "sigma" in self .idata .posterior :
273+ sigma_var_name = "sigma"
274+ elif "y_hat_sigma" in self .idata .posterior :
275+ sigma_var_name = "y_hat_sigma"
276+ else :
277+ raise ValueError ("Neither 'sigma' nor 'y_hat_sigma' found in posterior" )
278+
266279 treated_units = coeffs .coords ["treated_units" ].values
267280 for unit in treated_units :
268281 if len (treated_units ) > 1 :
269282 print (f"\n Treated unit: { unit } " )
270283
271284 unit_coeffs = coeffs .sel (treated_units = unit )
272- unit_sigma = az .extract (self .idata .posterior , var_names = "sigma" ).sel (
285+ unit_sigma = az .extract (self .idata .posterior , var_names = sigma_var_name ).sel (
273286 treated_units = unit
274287 )
275288 print_coefficients_for_unit (unit_coeffs , unit_sigma , labels , round_to or 2 )
@@ -314,7 +327,11 @@ class LinearRegression(PyMCModel):
314327
# Priors used when the caller supplies none.
#   beta  - Normal(mu=0, sigma=50), declared over (treated_units, coeffs).
#   y_hat - Normal likelihood over (obs_ind, treated_units); its sigma is a
#           HalfNormal(sigma=1) declared over treated_units.
default_priors = dict(
    beta=Prior("Normal", mu=0, sigma=50, dims=["treated_units", "coeffs"]),
    y_hat=Prior(
        "Normal",
        sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
        dims=["obs_ind", "treated_units"],
    ),
)
319336
320337 def build_model (self , X , y , coords ):
@@ -331,11 +348,10 @@ def build_model(self, X, y, coords):
331348 X = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
332349 y = pm .Data ("y" , y , dims = ["obs_ind" , "treated_units" ])
333350 beta = self .priors ["beta" ].create_variable ("beta" )
334- sigma = pm .HalfNormal ("sigma" , 1 , dims = "treated_units" )
335351 mu = pm .Deterministic (
336352 "mu" , pt .dot (X , beta .T ), dims = ["obs_ind" , "treated_units" ]
337353 )
338- pm . Normal ("y_hat" , mu , sigma , observed = y , dims = [ "obs_ind" , "treated_units" ] )
354+ self . priors [ "y_hat" ]. create_likelihood_variable ("y_hat" , mu = mu , observed = y )
339355
340356
341357class WeightedSumFitter (PyMCModel ):
@@ -379,26 +395,34 @@ class WeightedSumFitter(PyMCModel):
379395 """ # noqa: W605
380396
# Priors used when the caller supplies none.
#   y_hat - Normal likelihood over (obs_ind, treated_units); its sigma is a
#           HalfNormal(sigma=1) declared over treated_units.
default_priors = dict(
    y_hat=Prior(
        "Normal",
        sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
        dims=["obs_ind", "treated_units"],
    ),
)
384404
def priors_from_data(self, X, y) -> Dict[str, Any]:
    """Build the ``beta`` prior, whose size depends on the predictors.

    The prior is a Dirichlet whose concentration vector ``a`` is all
    ones, with one entry per column of ``X`` (``X.shape[1]``), declared
    over the ``treated_units`` and ``coeffs`` dims. ``y`` is not used.

    Returns
    -------
    Dict[str, Any]
        A single-entry dict mapping ``"beta"`` to the Dirichlet prior.
    """
    # The concentration vector needs the predictor count, which is only
    # known once the data are available.
    concentration = np.ones(X.shape[1])
    beta_prior = Prior(
        "Dirichlet", a=concentration, dims=["treated_units", "coeffs"]
    )
    return {"beta": beta_prior}
412+
def build_model(self, X, y, coords):
    """
    Define the PyMC model: data containers, ``beta``, ``mu`` and ``y_hat``.

    ``beta`` and the ``y_hat`` likelihood are both created from the
    entries of ``self.priors``; ``mu`` is the dot product of ``X`` with
    the transpose of ``beta``.
    """
    with self:
        self.add_coords(coords)

        # Register the inputs as named data containers on the model.
        predictors = pm.Data("X", X, dims=["obs_ind", "coeffs"])
        outcomes = pm.Data("y", y, dims=["obs_ind", "treated_units"])

        weights = self.priors["beta"].create_variable("beta")

        # beta is transposed so the product is laid out as
        # (obs_ind, treated_units), matching the dims declared for mu.
        weighted_sum = pt.dot(predictors, weights.T)
        expected = pm.Deterministic(
            "mu", weighted_sum, dims=["obs_ind", "treated_units"]
        )

        self.priors["y_hat"].create_likelihood_variable(
            "y_hat", mu=expected, observed=outcomes
        )
402426
403427
404428class InstrumentalVariableRegression (PyMCModel ):
@@ -598,24 +622,8 @@ def build_model(self, X, t, coords, prior=None, noncentred=True):
598622 self .add_coords (coords )
599623 X_data = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
600624 t_data = pm .Data ("t" , t .flatten (), dims = "obs_ind" )
601-
602- if prior is not None :
603- # Use legacy interface for backward compatibility
604- if noncentred :
605- mu_beta , sigma_beta = prior ["b" ]
606- beta_std = pm .Normal ("beta_std" , 0 , 1 , dims = "coeffs" )
607- b = pm .Deterministic (
608- "beta_" , mu_beta + sigma_beta * beta_std , dims = "coeffs"
609- )
610- else :
611- b = pm .Normal (
612- "b" , mu = prior ["b" ][0 ], sigma = prior ["b" ][1 ], dims = "coeffs"
613- )
614- else :
615- # Use Prior class
616- b = self .priors ["b" ].create_variable ("b" )
617-
618- mu = pm .math .dot (X_data , b )
625+ b = self .priors ["b" ].create_variable ("b" )
626+ mu = pt .dot (X_data , b )
619627 p = pm .Deterministic ("p" , pm .math .invlogit (mu ))
620628 pm .Bernoulli ("t_pred" , p = p , observed = t_data , dims = "obs_ind" )
621629
0 commit comments