Skip to content

Commit e9b56ae

Browse files
committed
Add more tests for JAX implementation of ChoiceRV
1 parent 36b2ac9 commit e9b56ae

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

tests/link/jax/test_random.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -545,27 +545,49 @@ def test_random_dirichlet(parameter, size):
545545

546546

547547
def test_random_choice():
548-
# Elements are picked at equal frequency
549-
num_samples = 10000
548+
# `replace=True` and `p is None`
550549
rng = shared(np.random.RandomState(123))
551-
g = pt.random.choice(np.arange(4), size=num_samples, rng=rng)
550+
g = pt.random.choice(np.arange(4), size=10_000, rng=rng)
551+
g_fn = compile_random_function([], g, mode=jax_mode)
552+
samples = g_fn()
553+
assert samples.shape == (10_000,)
554+
# Elements are picked at equal frequency
555+
np.testing.assert_allclose(np.mean(samples == 3), 0.25, 2)
556+
557+
# `replace=True` and `p is not None`
558+
rng = shared(np.random.default_rng(123))
559+
g = pt.random.choice(4, p=np.array([0.0, 0.5, 0.0, 0.5]), size=(5, 2), rng=rng)
552560
g_fn = compile_random_function([], g, mode=jax_mode)
553561
samples = g_fn()
554-
np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)
562+
assert samples.shape == (5, 2)
563+
# Only odd numbers are picked
564+
assert np.all(samples % 2 == 1)
555565

556-
# `replace=False` produces unique results
566+
# `replace=False` and `p is None`
557567
rng = shared(np.random.RandomState(123))
558-
g = pt.random.choice(np.arange(100), replace=False, size=99, rng=rng)
568+
g = pt.random.choice(np.arange(100), replace=False, size=(2, 49), rng=rng)
559569
g_fn = compile_random_function([], g, mode=jax_mode)
560570
samples = g_fn()
561-
assert len(np.unique(samples)) == 99
571+
assert samples.shape == (2, 49)
572+
# Elements are unique
573+
assert len(np.unique(samples)) == 98
562574

563-
# We can pass an array with probabilities
575+
# `replace=False` and `p is not None`
564576
rng = shared(np.random.RandomState(123))
565-
g = pt.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
577+
g = pt.random.choice(
578+
8,
579+
p=np.array([0.25, 0, 0.25, 0, 0.25, 0, 0.25, 0]),
580+
size=3,
581+
rng=rng,
582+
replace=False,
583+
)
566584
g_fn = compile_random_function([], g, mode=jax_mode)
567585
samples = g_fn()
568-
np.testing.assert_allclose(samples, np.zeros(10))
586+
assert samples.shape == (3,)
587+
# Elements are unique
588+
assert len(np.unique(samples)) == 3
589+
# Only even numbers are picked
590+
assert np.all(samples % 2 == 0)
569591

570592

571593
def test_random_categorical():

0 commit comments

Comments
 (0)