Skip to content

Commit 61c40a8

Browse files
rlouftwiecki
authored andcommitted
Rewrite size input of RandomVariables in JAX backend
1 parent 0d1f65f commit 61c40a8

File tree

7 files changed

+110
-7
lines changed

7 files changed

+110
-7
lines changed

pytensor/compile/mode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
449449

450450
JAX = Mode(
451451
JAXLinker(),
452-
RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
452+
RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
453453
)
454454
NUMBA = Mode(
455455
NumbaLinker(),

pytensor/link/jax/dispatch/random.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import pytensor.tensor.random.basic as aer
1010
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
11+
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
1112
from pytensor.tensor.shape import Shape, Shape_i
1213

1314

@@ -28,7 +29,7 @@
2829

2930

3031
def assert_size_argument_jax_compatible(node):
31-
"""Assert whether the current node can be compiled.
32+
"""Assert whether the current node can be JIT-compiled by JAX.
3233
3334
JAX can JIT-compile `jax.random` functions when the `size` argument
3435
is a concrete value, i.e. either a constant or the shape of any
@@ -37,7 +38,7 @@ def assert_size_argument_jax_compatible(node):
3738
"""
3839
size = node.inputs[1]
3940
size_op = size.owner.op
40-
if not isinstance(size_op, (Shape, Shape_i)):
41+
if not isinstance(size_op, (Shape, Shape_i, JAXShapeTuple)):
4142
raise NotImplementedError(SIZE_NOT_COMPATIBLE)
4243

4344

pytensor/link/jax/dispatch/shape.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,31 @@
11
import jax.numpy as jnp
22

33
from pytensor.graph import Constant
4+
from pytensor.graph.basic import Apply
5+
from pytensor.graph.op import Op
46
from pytensor.link.jax.dispatch.basic import jax_funcify
57
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
8+
from pytensor.tensor.type import TensorType
9+
10+
11+
class JAXShapeTuple(Op):
12+
"""Dummy Op that represents a `size` specified as a tuple."""
13+
14+
def make_node(self, *inputs):
15+
dtype = inputs[0].type.dtype
16+
otype = TensorType(dtype, shape=(len(inputs),))
17+
return Apply(self, inputs, [otype()])
18+
19+
def perform(self, *inputs):
20+
return tuple(inputs)
21+
22+
23+
@jax_funcify.register(JAXShapeTuple)
24+
def jax_funcify_JAXShapeTuple(op, **kwargs):
25+
def shape_tuple_fn(*x):
26+
return tuple(x)
27+
28+
return shape_tuple_fn
629

730

831
@jax_funcify.register(Reshape)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,10 @@
11
# TODO: This is for backward-compatibility; remove when reasonable.
22
from pytensor.tensor.random.rewriting.basic import *
3+
4+
5+
# isort: off
6+
7+
# Register JAX specializations
8+
import pytensor.tensor.random.rewriting.jax
9+
10+
# isort: on
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from pytensor.compile import optdb
2+
from pytensor.graph.rewriting.basic import in2out, node_rewriter
3+
from pytensor.tensor.basic import MakeVector
4+
from pytensor.tensor.elemwise import DimShuffle
5+
from pytensor.tensor.random.op import RandomVariable
6+
7+
8+
@node_rewriter([RandomVariable])
9+
def size_parameter_as_tuple(fgraph, node):
10+
"""Replace `MakeVector` and `DimShuffle` (when used to transform a scalar
11+
into a 1d vector) when they are found as the input of a `size` or `shape`
12+
parameter by `JAXShapeTuple` during transpilation.
13+
14+
The JAX implementations of `MakeVector` and `DimShuffle` always return JAX
15+
`TracedArrays`, but JAX only accepts concrete values as inputs for the `size`
16+
or `shape` parameter. When these `Op`s are used to convert scalar or tuple
17+
inputs, however, we can avoid tracing by making them return a tuple of their
18+
inputs instead.
19+
20+
Note that JAX does not accept scalar inputs for the `size` or `shape`
21+
parameters, and this rewrite also ensures that scalar inputs are turned into
22+
tuples during transpilation.
23+
24+
"""
25+
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
26+
27+
size_arg = node.inputs[1]
28+
size_node = size_arg.owner
29+
30+
if size_node is None:
31+
return
32+
33+
if isinstance(size_node.op, JAXShapeTuple):
34+
return
35+
36+
if isinstance(size_node.op, MakeVector) or (
37+
isinstance(size_node.op, DimShuffle)
38+
and size_node.op.input_broadcastable == ()
39+
and size_node.op.new_order == ("x",)
40+
):
41+
# Here PyTensor converted a tuple or list to a tensor
42+
new_size_args = JAXShapeTuple()(*size_node.inputs)
43+
new_inputs = list(node.inputs)
44+
new_inputs[1] = new_size_args
45+
46+
new_node = node.clone_with_new_inputs(new_inputs)
47+
return new_node.outputs
48+
49+
50+
optdb.register(
51+
"jax_size_parameter_as_tuple", in2out(size_parameter_as_tuple), "jax", position=100
52+
)

tests/link/jax/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def set_pytensor_flags():
2727
jax = pytest.importorskip("jax")
2828

2929

30-
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
30+
opts = RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"])
3131
jax_mode = Mode(JAXLinker(), opts)
3232
py_mode = Mode("py", opts)
3333

tests/link/jax/test_random.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,25 +454,44 @@ def test_random_concrete_shape():
454454
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
455455

456456

457-
@pytest.mark.xfail(reason="size argument specified as a tuple is a `DimShuffle` node")
458457
def test_random_concrete_shape_subtensor():
458+
"""JAX should compile when a concrete value is passed for the `size` parameter.
459+
460+
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
461+
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
462+
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
463+
rewrite.
464+
465+
JAX does not accept scalars as `size` or `shape` arguments, so this is a
466+
slight improvement over their API.
467+
468+
"""
459469
rng = shared(np.random.RandomState(123))
460470
x_at = at.dmatrix()
461471
out = at.random.normal(0, 1, size=x_at.shape[1], rng=rng)
462472
jax_fn = function([x_at], out, mode=jax_mode)
463473
assert jax_fn(np.ones((2, 3))).shape == (3,)
464474

465475

466-
@pytest.mark.xfail(reason="size argument specified as a tuple is a `MakeVector` node")
467476
def test_random_concrete_shape_subtensor_tuple():
477+
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
478+
479+
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
480+
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
481+
scalar inputs into tuples of concrete values using the
482+
`jax_size_parameter_as_tuple` rewrite.
483+
484+
"""
468485
rng = shared(np.random.RandomState(123))
469486
x_at = at.dmatrix()
470487
out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng)
471488
jax_fn = function([x_at], out, mode=jax_mode)
472489
assert jax_fn(np.ones((2, 3))).shape == (2,)
473490

474491

475-
@pytest.mark.xfail(reason="`size_at` should be specified as a static argument")
492+
@pytest.mark.xfail(
493+
reason="`size_at` should be specified as a static argument", strict=True
494+
)
476495
def test_random_concrete_shape_graph_input():
477496
rng = shared(np.random.RandomState(123))
478497
size_at = at.scalar()

0 commit comments

Comments
 (0)