 import pytensor.tensor as pt
 import xarray as xr
 from arviz import r2_score
+from patsy import dmatrix
 
 from causalpy.utils import round_num
 
@@ -473,22 +474,30 @@ class PropensityScore(PyMCModel):
     ...                 'coeffs': ['age', 'race'],
     ...                 'obs_ind': np.arange(df.shape[0])
     ...                },
+    ...                prior={'b': [0, 1]},
     ... )
     Inference...
     """  # noqa: W605
 
-    def build_model(self, X, t, coords):
+    def build_model(self, X, t, coords, prior, noncentred):
         "Defines the PyMC propensity model"
         with self:
             self.add_coords(coords)
             X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             t_data = pm.Data("t", t.flatten(), dims="obs_ind")
-            b = pm.Normal("b", mu=0, sigma=1, dims="coeffs")
+            if noncentred:
+                mu_beta, sigma_beta = prior["b"]
+                beta_std = pm.Normal("beta_std", 0, 1, dims="coeffs")
+                b = pm.Deterministic(
+                    "beta_", mu_beta + sigma_beta * beta_std, dims="coeffs"
+                )
+            else:
+                b = pm.Normal("b", mu=prior["b"][0], sigma=prior["b"][1], dims="coeffs")
             mu = pm.math.dot(X_data, b)
             p = pm.Deterministic("p", pm.math.invlogit(mu))
             pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind")
 
-    def fit(self, X, t, coords):
+    def fit(self, X, t, coords, prior={"b": [0, 1]}, noncentred=True):
         """Draw samples from posterior, prior predictive, and posterior predictive
         distributions. We overwrite the base method because the base method assumes
         a variable y and we use t to indicate the treatment variable here.
@@ -497,7 +506,7 @@ def fit(self, X, t, coords):
         # sample_posterior_predictive() if provided in sample_kwargs.
         random_seed = self.sample_kwargs.get("random_seed", None)
 
-        self.build_model(X, t, coords)
+        self.build_model(X, t, coords, prior, noncentred)
         with self:
             self.idata = pm.sample(**self.sample_kwargs)
             self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
@@ -507,3 +516,73 @@ def fit(self, X, t, coords):
                 )
             )
         return self.idata
+
+    def fit_outcome_model(
+        self,
+        X_outcome,
+        y,
+        coords,
+        priors={"b_outcome": [0, 1], "a_outcome": [0, 1], "sigma": 1},
+        noncentred=True,
+        normal_outcome=True,
+    ):
+        """Fit a second-stage outcome regression that adjusts for the
+        estimated propensity score through a B-spline basis. Requires that
+        fit() has been called so posterior propensity scores are available."""
+        if not hasattr(self, "idata"):
+            raise AttributeError(
+                """Object is missing required attribute 'idata'
+                so cannot proceed. Call fit() first"""
+            )
+        propensity_scores = az.extract(self.idata)["p"]
+        random_seed = self.sample_kwargs.get("random_seed", None)
+
+        with pm.Model(coords=coords) as model_outcome:
+            X_data_outcome = pm.Data("X_outcome", X_outcome)
+            Y_data_ = pm.Data("Y", y)
+
+            if noncentred:
+                mu_beta, sigma_beta = priors["b_outcome"]
+                beta_std = pm.Normal("beta_std", 0, 1, dims="outcome_coeffs")
+                beta = pm.Deterministic(
+                    "beta_", mu_beta + sigma_beta * beta_std, dims="outcome_coeffs"
+                )
+            else:
+                beta = pm.Normal(
+                    "beta_",
+                    priors["b_outcome"][0],
+                    priors["b_outcome"][1],
+                    dims="outcome_coeffs",
+                )
+
+            # 34 spline coefficients: 30 knots + degree 3 + 1 intercept column
+            # in the B-spline basis built below.
+            beta_ps_spline = pm.Normal("beta_ps_spline", 0, 1, size=34)
+            beta_ps = pm.Normal("beta_ps", 0, 1)
+
+            # Condition the outcome model on a single randomly chosen posterior
+            # draw of the propensity scores.
+            chosen = np.random.choice(range(propensity_scores.shape[1]))
+            p = propensity_scores[:, chosen].values
+
+            B = dmatrix(
+                "bs(ps, knots=knots, degree=3, include_intercept=True, lower_bound=0, upper_bound=1) - 1",
+                {"ps": p, "knots": np.linspace(0, 1, 30)},
+            )
+            B_f = np.asarray(B, order="F")
+            splines_summed = pm.Deterministic(
+                "spline_features", pm.math.dot(B_f, beta_ps_spline.T)
+            )
+
+            alpha_outcome = pm.Normal(
+                "a_outcome", priors["a_outcome"][0], priors["a_outcome"][1]
+            )
+            mu_outcome = (
+                alpha_outcome
+                + pm.math.dot(X_data_outcome, beta)
+                + p * beta_ps
+                + splines_summed
+            )
+            sigma = pm.HalfNormal("sigma", priors["sigma"])
+
+            if normal_outcome:
+                _ = pm.Normal("like", mu_outcome, sigma, observed=Y_data_)
+            else:
+                nu = pm.Exponential("nu", lam=1 / 30)
+                _ = pm.StudentT(
+                    "like", nu=nu, mu=mu_outcome, sigma=sigma, observed=Y_data_
+                )
+
+            idata_outcome = pm.sample_prior_predictive(random_seed=random_seed)
+            idata_outcome.extend(pm.sample(**self.sample_kwargs))
+
+        return idata_outcome, model_outcome
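
A minimal usage sketch of the new API in this diff. The DataFrame df, its column names t and y, the sample_kwargs, the constructor call, and the import path are illustrative assumptions, not part of the change; the coords keys and prior dictionaries follow the signatures added above.

import numpy as np
import pandas as pd

from causalpy.pymc_models import PropensityScore  # assumed import path

# Hypothetical data: covariates, binary treatment t, continuous outcome y
rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "age": rng.normal(50, 10, 200),
        "race": rng.binomial(1, 0.3, 200),
        "t": rng.binomial(1, 0.5, 200),
        "y": rng.normal(0, 1, 200),
    }
)

ps_model = PropensityScore(sample_kwargs={"progressbar": False})  # assumed constructor signature

# First stage: propensity model with the new prior / noncentred arguments
ps_model.fit(
    X=df[["age", "race"]].values,
    t=df["t"].values.reshape(-1, 1),
    coords={"coeffs": ["age", "race"], "obs_ind": np.arange(df.shape[0])},
    prior={"b": [0, 1]},
    noncentred=True,
)

# Second stage: outcome regression adjusting for the estimated propensity score
idata_outcome, model_outcome = ps_model.fit_outcome_model(
    X_outcome=df[["age", "race", "t"]].values,
    y=df["y"].values,
    coords={"outcome_coeffs": ["age", "race", "t"]},
    priors={"b_outcome": [0, 1], "a_outcome": [0, 1], "sigma": 1},
    noncentred=True,
    normal_outcome=True,
)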