Skip to content

Commit b9792d8

Browse files
Add JAX support for pt.tri (#302)
1 parent 6b43b43 commit b9792d8

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ScalarFromTensor,
1919
Split,
2020
TensorFromScalar,
21+
Tri,
2122
get_underlying_scalar_constant_value,
2223
)
2324
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -26,7 +27,6 @@
2627
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
2728
to be constants. The graph that you defined thus cannot be JIT-compiled
2829
by JAX. An example of a graph that can be compiled to JAX:
29-
3030
>>> import pytensor.tensor basic
3131
>>> at.arange(1, 10, 2)
3232
"""
@@ -193,3 +193,18 @@ def scalar_from_tensor(x):
193193
return jnp.array(x).flatten()[0]
194194

195195
return scalar_from_tensor
196+
197+
198+
@jax_funcify.register(Tri)
199+
def jax_funcify_Tri(op, node, **kwargs):
200+
# node.inputs is N, M, k
201+
const_args = [getattr(x, "data", None) for x in node.inputs]
202+
203+
def tri(*args):
204+
# args is N, M, k
205+
args = [
206+
x if const_x is None else const_x for x, const_x in zip(args, const_args)
207+
]
208+
return jnp.tri(*args, dtype=op.dtype)
209+
210+
return tri

tests/link/jax/test_tensor_basic.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,30 @@ def test_jax_eye():
191191
out_fg = FunctionGraph([], [out])
192192

193193
compare_jax_and_py(out_fg, [])
194+
195+
196+
def test_tri():
197+
out = at.tri(10, 10, 0)
198+
fgraph = FunctionGraph([], [out])
199+
compare_jax_and_py(fgraph, [])
200+
201+
202+
def test_tri_nonconcrete():
203+
"""JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values."""
204+
205+
m, n, k = (
206+
scalar("a", dtype="int64"),
207+
scalar("n", dtype="int64"),
208+
scalar("k", dtype="int64"),
209+
)
210+
m.tag.test_value = 10
211+
n.tag.test_value = 10
212+
k.tag.test_value = 0
213+
214+
out = at.tri(m, n, k)
215+
216+
# The actual error the user will see should be jax.errors.ConcretizationTypeError, but
217+
# the error handler raises an Attribute error first, so that's what this test needs to pass
218+
with pytest.raises(AttributeError):
219+
fgraph = FunctionGraph([m, n, k], [out])
220+
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

0 commit comments

Comments
 (0)