@@ -525,6 +525,7 @@ def fit_outcome_model(
525525 priors = {"b_outcome" : [0 , 1 ], "a_outcome" : [0 , 1 ], "sigma" : 1 },
526526 noncentred = True ,
527527 normal_outcome = True ,
528+ spline_component = False ,
528529 ):
529530 if not hasattr (self , "idata" ):
530531 raise AttributeError ("""Object is missing required attribute 'idata'
@@ -551,35 +552,40 @@ def fit_outcome_model(
551552 )
552553
553554 beta_ps_spline = pm .Normal ("beta_ps_spline" , 0 , 1 , size = 34 )
554- beta_ps = pm .Normal ("beta_ps" , 0 , 1 )
555+ beta_ps = pm .Normal ("beta_ps" , 0 , 1 , size = 2 )
555556
556557 chosen = np .random .choice (range (propensity_scores .shape [1 ]))
557558 p = propensity_scores [:, chosen ].values
558559
559- B = dmatrix (
560- "bs(ps, knots=knots, degree=3, include_intercept=True, lower_bound=0, upper_bound=1) - 1" ,
561- {"ps" : p , "knots" : np .linspace (0 , 1 , 30 )},
562- )
563- B_f = np .asarray (B , order = "F" )
564- splines_summed = pm .Deterministic (
565- "spline_features" , pm .math .dot (B_f , beta_ps_spline .T )
566- )
567-
568560 alpha_outcome = pm .Normal (
569561 "a_outcome" , priors ["a_outcome" ][0 ], priors ["a_outcome" ][1 ]
570562 )
563+
571564 mu_outcome = (
572565 alpha_outcome
573566 + pm .math .dot (X_data_outcome , beta )
574- + p * beta_ps
575- + splines_summed
567+ + beta_ps [ 0 ] * p
568+ + beta_ps [ 1 ] * ( p * self . t . flatten ())
576569 )
570+
571+ if spline_component :
572+ beta_ps_spline = pm .Normal ("beta_ps_spline" , 0 , 1 , size = 34 )
573+ B = dmatrix (
574+ "bs(ps, knots=knots, degree=3, include_intercept=True, lower_bound=0, upper_bound=1) - 1" ,
575+ {"ps" : p , "knots" : np .linspace (0 , 1 , 30 )},
576+ )
577+ B_f = np .asarray (B , order = "F" )
578+ splines_summed = pm .Deterministic (
579+ "spline_features" , pm .math .dot (B_f , beta_ps_spline .T )
580+ )
581+ mu_outcome = mu_outcome + splines_summed
582+
577583 sigma = pm .HalfNormal ("sigma" , priors ["sigma" ])
578584
579585 if normal_outcome :
580586 _ = pm .Normal ("like" , mu_outcome , sigma , observed = Y_data_ )
581587 else :
582- nu = pm .Exponential ("nu" , lam = 1 / 30 )
588+ nu = pm .Exponential ("nu" , lam = 1 / 10 )
583589 _ = pm .StudentT ("like" , nu = nu , mu = mu_outcome , sigma = sigma )
584590
585591 idata_outcome = pm .sample_prior_predictive (random_seed = random_seed )
0 commit comments