2727from pytensor .link .jax .dispatch .random import numpyro_available # noqa: E402
2828
2929
30- def compile_random_function (* args , mode = "JAX" , ** kwargs ):
30+ def compile_random_function (* args , mode = jax_mode , ** kwargs ):
3131 with pytest .warns (
3232 UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
3333 ):
@@ -42,7 +42,7 @@ def test_random_RandomStream():
4242 srng = RandomStream (seed = 123 )
4343 out = srng .normal () - srng .normal ()
4444
45- fn = compile_random_function ([], out , mode = jax_mode )
45+ fn = compile_random_function ([], out )
4646 jax_res_1 = fn ()
4747 jax_res_2 = fn ()
4848
@@ -55,7 +55,7 @@ def test_random_updates(rng_ctor):
5555 rng = shared (original_value , name = "original_rng" , borrow = False )
5656 next_rng , x = pt .random .normal (name = "x" , rng = rng ).owner .outputs
5757
58- f = compile_random_function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
58+ f = compile_random_function ([], [x ], updates = {rng : next_rng })
5959 assert f () != f ()
6060
6161 # Check that original rng variable content was not overwritten when calling jax_typify
@@ -479,7 +479,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
479479 """
480480 rng = shared (np .random .default_rng (29403 ))
481481 g = rv_op (* dist_params , size = (10000 , * base_size ), rng = rng )
482- g_fn = compile_random_function (dist_params , g , mode = jax_mode )
482+ g_fn = compile_random_function (dist_params , g )
483483 samples = g_fn (
484484 * [
485485 i .tag .test_value
@@ -521,7 +521,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
521521 param_that_implies_size = pt .matrix ("param_that_implies_size" , shape = (None , None ))
522522
523523 rv = rv_fn (param_that_implies_size )
524- draws = rv .eval ({param_that_implies_size : np .zeros ((2 , 2 ))}, mode = jax_mode )
524+ draws = rv .eval ({param_that_implies_size : np .zeros ((2 , 2 ))})
525525
526526 assert draws .shape == (2 , 2 )
527527 assert np .unique (draws ).size == 4
@@ -531,7 +531,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
531531def test_random_bernoulli (size ):
532532 rng = shared (np .random .default_rng (123 ))
533533 g = pt .random .bernoulli (0.5 , size = (1000 , * size ), rng = rng )
534- g_fn = compile_random_function ([], g , mode = jax_mode )
534+ g_fn = compile_random_function ([], g )
535535 samples = g_fn ()
536536 np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
537537
@@ -542,7 +542,7 @@ def test_random_mvnormal():
542542 mu = np .ones (4 )
543543 cov = np .eye (4 )
544544 g = pt .random .multivariate_normal (mu , cov , size = (10000 ,), rng = rng )
545- g_fn = compile_random_function ([], g , mode = jax_mode )
545+ g_fn = compile_random_function ([], g )
546546 samples = g_fn ()
547547 np .testing .assert_allclose (samples .mean (axis = 0 ), mu , atol = 0.1 )
548548
@@ -557,7 +557,7 @@ def test_random_mvnormal():
557557def test_random_dirichlet (parameter , size ):
558558 rng = shared (np .random .default_rng (123 ))
559559 g = pt .random .dirichlet (parameter , size = (1000 , * size ), rng = rng )
560- g_fn = compile_random_function ([], g , mode = jax_mode )
560+ g_fn = compile_random_function ([], g )
561561 samples = g_fn ()
562562 np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
563563
@@ -566,7 +566,7 @@ def test_random_choice():
566566 # `replace=True` and `p is None`
567567 rng = shared (np .random .default_rng (123 ))
568568 g = pt .random .choice (np .arange (4 ), size = 10_000 , rng = rng )
569- g_fn = compile_random_function ([], g , mode = jax_mode )
569+ g_fn = compile_random_function ([], g )
570570 samples = g_fn ()
571571 assert samples .shape == (10_000 ,)
572572 # Elements are picked at equal frequency
@@ -575,7 +575,7 @@ def test_random_choice():
575575 # `replace=True` and `p is not None`
576576 rng = shared (np .random .default_rng (123 ))
577577 g = pt .random .choice (4 , p = np .array ([0.0 , 0.5 , 0.0 , 0.5 ]), size = (5 , 2 ), rng = rng )
578- g_fn = compile_random_function ([], g , mode = jax_mode )
578+ g_fn = compile_random_function ([], g )
579579 samples = g_fn ()
580580 assert samples .shape == (5 , 2 )
581581 # Only odd numbers are picked
@@ -584,7 +584,7 @@ def test_random_choice():
584584 # `replace=False` and `p is None`
585585 rng = shared (np .random .default_rng (123 ))
586586 g = pt .random .choice (np .arange (100 ), replace = False , size = (2 , 49 ), rng = rng )
587- g_fn = compile_random_function ([], g , mode = jax_mode )
587+ g_fn = compile_random_function ([], g )
588588 samples = g_fn ()
589589 assert samples .shape == (2 , 49 )
590590 # Elements are unique
@@ -599,7 +599,7 @@ def test_random_choice():
599599 rng = rng ,
600600 replace = False ,
601601 )
602- g_fn = compile_random_function ([], g , mode = jax_mode )
602+ g_fn = compile_random_function ([], g )
603603 samples = g_fn ()
604604 assert samples .shape == (3 ,)
605605 # Elements are unique
@@ -611,14 +611,14 @@ def test_random_choice():
611611def test_random_categorical ():
612612 rng = shared (np .random .default_rng (123 ))
613613 g = pt .random .categorical (0.25 * np .ones (4 ), size = (10000 , 4 ), rng = rng )
614- g_fn = compile_random_function ([], g , mode = jax_mode )
614+ g_fn = compile_random_function ([], g )
615615 samples = g_fn ()
616616 assert samples .shape == (10000 , 4 )
617617 np .testing .assert_allclose (samples .mean (axis = 0 ), 6 / 4 , 1 )
618618
619619 # Test zero probabilities
620620 g = pt .random .categorical ([0 , 0.5 , 0 , 0.5 ], size = (1000 ,), rng = rng )
621- g_fn = compile_random_function ([], g , mode = jax_mode )
621+ g_fn = compile_random_function ([], g )
622622 samples = g_fn ()
623623 assert samples .shape == (1000 ,)
624624 assert np .all (samples % 2 == 1 )
@@ -628,7 +628,7 @@ def test_random_permutation():
628628 array = np .arange (4 )
629629 rng = shared (np .random .default_rng (123 ))
630630 g = pt .random .permutation (array , rng = rng )
631- g_fn = compile_random_function ([], g , mode = jax_mode )
631+ g_fn = compile_random_function ([], g )
632632 permuted = g_fn ()
633633 with pytest .raises (AssertionError ):
634634 np .testing .assert_allclose (array , permuted )
@@ -651,7 +651,7 @@ def test_random_geometric():
651651 rng = shared (np .random .default_rng (123 ))
652652 p = np .array ([0.3 , 0.7 ])
653653 g = pt .random .geometric (p , size = (10_000 , 2 ), rng = rng )
654- g_fn = compile_random_function ([], g , mode = jax_mode )
654+ g_fn = compile_random_function ([], g )
655655 samples = g_fn ()
656656 np .testing .assert_allclose (samples .mean (axis = 0 ), 1 / p , rtol = 0.1 )
657657 np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt ((1 - p ) / p ** 2 ), rtol = 0.1 )
@@ -662,7 +662,7 @@ def test_negative_binomial():
662662 n = np .array ([10 , 40 ])
663663 p = np .array ([0.3 , 0.7 ])
664664 g = pt .random .negative_binomial (n , p , size = (10_000 , 2 ), rng = rng )
665- g_fn = compile_random_function ([], g , mode = jax_mode )
665+ g_fn = compile_random_function ([], g )
666666 samples = g_fn ()
667667 np .testing .assert_allclose (samples .mean (axis = 0 ), n * (1 - p ) / p , rtol = 0.1 )
668668 np .testing .assert_allclose (
@@ -676,7 +676,7 @@ def test_binomial():
676676 n = np .array ([10 , 40 ])
677677 p = np .array ([0.3 , 0.7 ])
678678 g = pt .random .binomial (n , p , size = (10_000 , 2 ), rng = rng )
679- g_fn = compile_random_function ([], g , mode = jax_mode )
679+ g_fn = compile_random_function ([], g )
680680 samples = g_fn ()
681681 np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
682682 np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.1 )
@@ -691,7 +691,7 @@ def test_beta_binomial():
691691 a = np .array ([1.5 , 13 ])
692692 b = np .array ([0.5 , 9 ])
693693 g = pt .random .betabinom (n , a , b , size = (10_000 , 2 ), rng = rng )
694- g_fn = compile_random_function ([], g , mode = jax_mode )
694+ g_fn = compile_random_function ([], g )
695695 samples = g_fn ()
696696 np .testing .assert_allclose (samples .mean (axis = 0 ), n * a / (a + b ), rtol = 0.1 )
697697 np .testing .assert_allclose (
@@ -725,7 +725,7 @@ def test_vonmises_mu_outside_circle():
725725 mu = np .array ([- 30 , 40 ])
726726 kappa = np .array ([100 , 10 ])
727727 g = pt .random .vonmises (mu , kappa , size = (10_000 , 2 ), rng = rng )
728- g_fn = compile_random_function ([], g , mode = jax_mode )
728+ g_fn = compile_random_function ([], g )
729729 samples = g_fn ()
730730 np .testing .assert_allclose (
731731 samples .mean (axis = 0 ), (mu + np .pi ) % (2.0 * np .pi ) - np .pi , rtol = 0.1
@@ -823,15 +823,15 @@ def test_random_concrete_shape():
823823 rng = shared (np .random .default_rng (123 ))
824824 x_pt = pt .dmatrix ()
825825 out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
826- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
826+ jax_fn = compile_random_function ([x_pt ], out )
827827 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
828828
829829
830830def test_random_concrete_shape_from_param ():
831831 rng = shared (np .random .default_rng (123 ))
832832 x_pt = pt .dmatrix ()
833833 out = pt .random .normal (x_pt , 1 , rng = rng )
834- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
834+ jax_fn = compile_random_function ([x_pt ], out )
835835 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
836836
837837
@@ -850,7 +850,7 @@ def test_random_concrete_shape_subtensor():
850850 rng = shared (np .random .default_rng (123 ))
851851 x_pt = pt .dmatrix ()
852852 out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
853- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
853+ jax_fn = compile_random_function ([x_pt ], out )
854854 assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
855855
856856
@@ -866,7 +866,7 @@ def test_random_concrete_shape_subtensor_tuple():
866866 rng = shared (np .random .default_rng (123 ))
867867 x_pt = pt .dmatrix ()
868868 out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
869- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
869+ jax_fn = compile_random_function ([x_pt ], out )
870870 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
871871
872872
@@ -877,7 +877,7 @@ def test_random_concrete_shape_graph_input():
877877 rng = shared (np .random .default_rng (123 ))
878878 size_pt = pt .scalar ()
879879 out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
880- jax_fn = compile_random_function ([size_pt ], out , mode = jax_mode )
880+ jax_fn = compile_random_function ([size_pt ], out )
881881 assert jax_fn (10 ).shape == (10 ,)
882882
883883
0 commit comments