File tree Expand file tree Collapse file tree 2 files changed +43
-1
lines changed
pytensor/link/jax/dispatch Expand file tree Collapse file tree 2 files changed +43
-1
lines changed Original file line number Diff line number Diff line change 18
18
ScalarFromTensor ,
19
19
Split ,
20
20
TensorFromScalar ,
21
+ Tri ,
21
22
get_underlying_scalar_constant_value ,
22
23
)
23
24
from pytensor .tensor .exceptions import NotScalarConstantError
26
27
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
27
28
to be constants. The graph that you defined thus cannot be JIT-compiled
28
29
by JAX. An example of a graph that can be compiled to JAX:
29
-
30
30
>>> import pytensor.tensor basic
31
31
>>> at.arange(1, 10, 2)
32
32
"""
@@ -193,3 +193,18 @@ def scalar_from_tensor(x):
193
193
return jnp .array (x ).flatten ()[0 ]
194
194
195
195
return scalar_from_tensor
196
+
197
+
198
+ @jax_funcify .register (Tri )
199
+ def jax_funcify_Tri (op , node , ** kwargs ):
200
+ # node.inputs is N, M, k
201
+ const_args = [getattr (x , "data" , None ) for x in node .inputs ]
202
+
203
+ def tri (* args ):
204
+ # args is N, M, k
205
+ args = [
206
+ x if const_x is None else const_x for x , const_x in zip (args , const_args )
207
+ ]
208
+ return jnp .tri (* args , dtype = op .dtype )
209
+
210
+ return tri
Original file line number Diff line number Diff line change @@ -191,3 +191,30 @@ def test_jax_eye():
191
191
out_fg = FunctionGraph ([], [out ])
192
192
193
193
compare_jax_and_py (out_fg , [])
194
+
195
+
196
+ def test_tri ():
197
+ out = at .tri (10 , 10 , 0 )
198
+ fgraph = FunctionGraph ([], [out ])
199
+ compare_jax_and_py (fgraph , [])
200
+
201
+
202
+ def test_tri_nonconcrete ():
203
+ """JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values."""
204
+
205
+ m , n , k = (
206
+ scalar ("a" , dtype = "int64" ),
207
+ scalar ("n" , dtype = "int64" ),
208
+ scalar ("k" , dtype = "int64" ),
209
+ )
210
+ m .tag .test_value = 10
211
+ n .tag .test_value = 10
212
+ k .tag .test_value = 0
213
+
214
+ out = at .tri (m , n , k )
215
+
216
+ # The actual error the user will see should be jax.errors.ConcretizationTypeError, but
217
+ # the error handler raises an Attribute error first, so that's what this test needs to pass
218
+ with pytest .raises (AttributeError ):
219
+ fgraph = FunctionGraph ([m , n , k ], [out ])
220
+ compare_jax_and_py (fgraph , [get_test_value (i ) for i in fgraph .inputs ])
You can’t perform that action at this time.
0 commit comments