@@ -929,12 +929,29 @@ def trace_backend(request):
929929 return trace
930930
931931
932- def test_random_deterministics (trace_backend ):
932+ @pytest .fixture (scope = "function" , params = ["FAST_COMPILE" , "NUMBA" , "JAX" ])
933+ def pytensor_mode (request ):
934+ return request .param
935+
936+
937+ def test_random_deterministics (trace_backend , pytensor_mode ):
933938 with pm .Model () as m :
934939 x = pm .Bernoulli ("x" , p = 0.5 ) * 0 # Force it to be zero
935940 pm .Deterministic ("y" , x + pm .Normal .dist ())
936941
937- idata1 = pm .sample (tune = 0 , draws = 1 , random_seed = 1 , trace = trace_backend )
938- idata2 = pm .sample (tune = 0 , draws = 1 , random_seed = 1 , trace = trace_backend )
939-
940- assert idata1 .posterior .equals (idata2 .posterior )
942+ if pytensor_mode == "JAX" :
943+ expected_warning = (
944+ "At the moment, it is not possible to set the random generator's key for "
945+ "JAX linked functions. This means that the draws yielded by the random "
946+ "variables that are requested by 'Deterministic' will not be reproducible."
947+ )
948+ with pytest .warns (UserWarning , match = expected_warning ):
949+ with pytensor .config .change_flags (mode = pytensor_mode ):
950+ idata1 = pm .sample (tune = 0 , draws = 1 , random_seed = 1 , trace = trace_backend )
951+ idata2 = pm .sample (tune = 0 , draws = 1 , random_seed = 1 , trace = trace_backend )
952+ assert not idata1 .posterior .equals (idata2 .posterior )
953+ else :
954+ with pytensor .config .change_flags (mode = pytensor_mode ):
955+ idata1 = pm .sample (tune = 0 , draws = 1 , random_seed = 1 , trace = trace_backend )
956+ idata2 = pm .sample (tune = 0 , draws = 1 , random_seed = 1 , trace = trace_backend )
957+ assert idata1 .posterior .equals (idata2 .posterior )
0 commit comments