Skip to content

Commit 0d1f65f

Browse files
rlouftwiecki
authored andcommitted
Raise when the RandomVariable will not compile
1 parent a110e82 commit 0d1f65f

File tree

2 files changed

+58
-17
lines changed

2 files changed

+58
-17
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,39 @@
88

99
import pytensor.tensor.random.basic as aer
1010
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
11-
from pytensor.tensor.shape import Shape
11+
from pytensor.tensor.shape import Shape, Shape_i
1212

1313

1414
numpy_bit_gens = {"MT19937": 0, "PCG64": 1, "Philox": 2, "SFC64": 3}
1515

1616

17+
SIZE_NOT_COMPATIBLE = """JAX random variables require concrete values for the `size` parameter of the distributions.
18+
Concrete values are either constants:
19+
20+
>>> import pytensor.tensor as at
21+
>>> x_rv = at.random.normal(0, 1, size=(3, 2))
22+
23+
or the shape of an array:
24+
25+
>>> m = at.matrix()
26+
>>> x_rv = at.random.normal(0, 1, size=m.shape)
27+
"""
28+
29+
30+
def assert_size_argument_jax_compatible(node):
31+
"""Assert whether the current node can be compiled.
32+
33+
JAX can JIT-compile `jax.random` functions when the `size` argument
34+
is a concrete value, i.e. either a constant or the shape of any
35+
traced value.
36+
37+
"""
38+
size = node.inputs[1]
39+
size_op = size.owner.op
40+
if not isinstance(size_op, (Shape, Shape_i)):
41+
raise NotImplementedError(SIZE_NOT_COMPATIBLE)
42+
43+
1744
@jax_typify.register(RandomState)
1845
def jax_typify_RandomState(state, **kwargs):
1946
state = state.get_state(legacy=False)
@@ -65,12 +92,7 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
6592
# by a `Shape` operator in which case JAX will compile, or it is
6693
# not and we fail gracefully.
6794
if None in out_size:
68-
if not isinstance(node.inputs[1].owner.op, Shape):
69-
raise NotImplementedError(
70-
"""JAX random variables require concrete values for the `size` parameter of the distributions.
71-
Concrete values are either constants, or the shape of an array.
72-
"""
73-
)
95+
assert_size_argument_jax_compatible(node)
7496

7597
def sample_fn(rng, size, dtype, *parameters):
7698
return jax_sample_fn(op)(rng, size, out_dtype, *parameters)

tests/link/jax/test_random.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -449,14 +449,33 @@ def test_random_concrete_shape():
449449
"""
450450
rng = shared(np.random.RandomState(123))
451451
x_at = at.dmatrix()
452-
f = at.random.normal(0, 1, size=(3,), rng=rng)
453-
g = at.random.normal(f, 1, size=x_at.shape, rng=rng)
454-
g_fn = function([x_at], g, mode=jax_mode)
455-
_ = g_fn(np.ones((2, 3)))
452+
out = at.random.normal(0, 1, size=x_at.shape, rng=rng)
453+
jax_fn = function([x_at], out, mode=jax_mode)
454+
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
456455

457-
# This should compile, and `size_at` be passed to the list of `static_argnums`.
458-
with pytest.raises(NotImplementedError):
459-
size_at = at.scalar()
460-
g = at.random.normal(f, 1, size=size_at, rng=rng)
461-
g_fn = function([size_at], g, mode=jax_mode)
462-
_ = g_fn(10)
456+
457+
@pytest.mark.xfail(reason="size argument specified as a tuple is a `DimShuffle` node")
458+
def test_random_concrete_shape_subtensor():
459+
rng = shared(np.random.RandomState(123))
460+
x_at = at.dmatrix()
461+
out = at.random.normal(0, 1, size=x_at.shape[1], rng=rng)
462+
jax_fn = function([x_at], out, mode=jax_mode)
463+
assert jax_fn(np.ones((2, 3))).shape == (3,)
464+
465+
466+
@pytest.mark.xfail(reason="size argument specified as a tuple is a `MakeVector` node")
467+
def test_random_concrete_shape_subtensor_tuple():
468+
rng = shared(np.random.RandomState(123))
469+
x_at = at.dmatrix()
470+
out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng)
471+
jax_fn = function([x_at], out, mode=jax_mode)
472+
assert jax_fn(np.ones((2, 3))).shape == (2,)
473+
474+
475+
@pytest.mark.xfail(reason="`size_at` should be specified as a static argument")
476+
def test_random_concrete_shape_graph_input():
477+
rng = shared(np.random.RandomState(123))
478+
size_at = at.scalar()
479+
out = at.random.normal(0, 1, size=size_at, rng=rng)
480+
jax_fn = function([size_at], out, mode=jax_mode)
481+
assert jax_fn(10).shape == (10,)

0 commit comments

Comments
 (0)