@@ -545,27 +545,49 @@ def test_random_dirichlet(parameter, size):
545
545
546
546
547
547
def test_random_choice ():
548
- # Elements are picked at equal frequency
549
- num_samples = 10000
548
+ # `replace=True` and `p is None`
550
549
rng = shared (np .random .RandomState (123 ))
551
- g = pt .random .choice (np .arange (4 ), size = num_samples , rng = rng )
550
+ g = pt .random .choice (np .arange (4 ), size = 10_000 , rng = rng )
551
+ g_fn = compile_random_function ([], g , mode = jax_mode )
552
+ samples = g_fn ()
553
+ assert samples .shape == (10_000 ,)
554
+ # Elements are picked at equal frequency
555
+ np .testing .assert_allclose (np .mean (samples == 3 ), 0.25 , 2 )
556
+
557
+ # `replace=True` and `p is not None`
558
+ rng = shared (np .random .default_rng (123 ))
559
+ g = pt .random .choice (4 , p = np .array ([0.0 , 0.5 , 0.0 , 0.5 ]), size = (5 , 2 ), rng = rng )
552
560
g_fn = compile_random_function ([], g , mode = jax_mode )
553
561
samples = g_fn ()
554
- np .testing .assert_allclose (np .sum (samples == 3 ) / num_samples , 0.25 , 2 )
562
+ assert samples .shape == (5 , 2 )
563
+ # Only odd numbers are picked
564
+ assert np .all (samples % 2 == 1 )
555
565
556
- # `replace=False` produces unique results
566
+ # `replace=False` and `p is None`
557
567
rng = shared (np .random .RandomState (123 ))
558
- g = pt .random .choice (np .arange (100 ), replace = False , size = 99 , rng = rng )
568
+ g = pt .random .choice (np .arange (100 ), replace = False , size = ( 2 , 49 ) , rng = rng )
559
569
g_fn = compile_random_function ([], g , mode = jax_mode )
560
570
samples = g_fn ()
561
- assert len (np .unique (samples )) == 99
571
+ assert samples .shape == (2 , 49 )
572
+ # Elements are unique
573
+ assert len (np .unique (samples )) == 98
562
574
563
- # We can pass an array with probabilities
575
+ # `replace=False` and `p is not None`
564
576
rng = shared (np .random .RandomState (123 ))
565
- g = pt .random .choice (np .arange (3 ), p = np .array ([1.0 , 0.0 , 0.0 ]), size = 10 , rng = rng )
577
+ g = pt .random .choice (
578
+ 8 ,
579
+ p = np .array ([0.25 , 0 , 0.25 , 0 , 0.25 , 0 , 0.25 , 0 ]),
580
+ size = 3 ,
581
+ rng = rng ,
582
+ replace = False ,
583
+ )
566
584
g_fn = compile_random_function ([], g , mode = jax_mode )
567
585
samples = g_fn ()
568
- np .testing .assert_allclose (samples , np .zeros (10 ))
586
+ assert samples .shape == (3 ,)
587
+ # Elements are unique
588
+ assert len (np .unique (samples )) == 3
589
+ # Only even numbers are picked
590
+ assert np .all (samples % 2 == 0 )
569
591
570
592
571
593
def test_random_categorical ():
0 commit comments