@@ -1017,6 +1017,55 @@ def test_logging_sampled_basic_rvs_posterior_mutable(self, mock_sample_results,
1017
1017
]
1018
1018
caplog .clear ()
1019
1019
1020
+ def test_observed_data_needed_in_pp (self ):
1021
+ # Model where y_data is not part of the generative graph.
1022
+ # It shouldn't be needed to set a dummy value for posterior predictive sampling
1023
+
1024
+ with pm .Model (coords = {"trial" : range (5 ), "feature" : range (3 )}) as m :
1025
+ x_data = pm .Data ("x_data" , np .random .normal (size = (5 , 3 )), dims = ("trial" , "feat" ))
1026
+ y_data = pm .Data ("y_data" , np .random .normal (size = (5 ,)), dims = ("trial" ,))
1027
+ sigma = pm .HalfNormal ("sigma" )
1028
+ mu = x_data .sum (- 1 )
1029
+ pm .Normal ("y" , mu = mu , sigma = sigma , observed = y_data , shape = mu .shape , dims = ("trial" ,))
1030
+
1031
+ prior = pm .sample_prior_predictive (samples = 25 ).prior
1032
+
1033
+ fake_idata = InferenceData (posterior = prior )
1034
+
1035
+ new_coords = {"trial" : range (2 ), "feature" : range (3 )}
1036
+ new_x_data = np .random .normal (size = (2 , 3 ))
1037
+ with m :
1038
+ pm .set_data (
1039
+ {
1040
+ "x_data" : new_x_data ,
1041
+ },
1042
+ coords = new_coords ,
1043
+ )
1044
+ pp = pm .sample_posterior_predictive (fake_idata , predictions = True , progressbar = False )
1045
+ assert pp .predictions ["y" ].shape == (1 , 25 , 2 )
1046
+
1047
+ # In this case y_data is part of the generative graph, so we must set it to a compatible value
1048
+ with pm .Model (coords = {"trial" : range (5 ), "feature" : range (3 )}) as m :
1049
+ x_data = pm .Data ("x_data" , np .random .normal (size = (5 , 3 )), dims = ("trial" , "feat" ))
1050
+ y_data = pm .Data ("y_data" , np .random .normal (size = (5 ,)), dims = ("trial" ,))
1051
+ sigma = pm .HalfNormal ("sigma" )
1052
+ mu = (y_data .sum () * x_data ).sum (- 1 )
1053
+ pm .Normal ("y" , mu = mu , sigma = sigma , observed = y_data , shape = mu .shape , dims = ("trial" ,))
1054
+
1055
+ prior = pm .sample_prior_predictive (samples = 25 ).prior
1056
+
1057
+ fake_idata = InferenceData (posterior = prior )
1058
+
1059
+ with m :
1060
+ pm .set_data ({"x_data" : new_x_data }, coords = new_coords )
1061
+ with pytest .raises (ValueError , match = "conflicting sizes for dimension 'trial'" ):
1062
+ pm .sample_posterior_predictive (fake_idata , predictions = True , progressbar = False )
1063
+
1064
+ new_y_data = np .random .normal (size = (2 ,))
1065
+ with m :
1066
+ pm .set_data ({"y_data" : new_y_data })
1067
+ assert pp .predictions ["y" ].shape == (1 , 25 , 2 )
1068
+
1020
1069
1021
1070
@pytest .fixture (scope = "class" )
1022
1071
def point_list_arg_bug_fixture () -> tuple [pm .Model , pm .backends .base .MultiTrace ]:
0 commit comments