File tree Expand file tree Collapse file tree 2 files changed +19
-0
lines changed
pytensor/link/jax/dispatch Expand file tree Collapse file tree 2 files changed +19
-0
lines changed Original file line number Diff line number Diff line change 1010 FillDiagonalOffset ,
1111 RavelMultiIndex ,
1212 Repeat ,
13+ SearchsortedOp ,
1314 Unique ,
1415 UnravelIndex ,
1516)
@@ -130,3 +131,13 @@ def jax_funcify_FillDiagonalOffset(op, **kwargs):
130131 # return filldiagonaloffset
131132
132133 raise NotImplementedError ("flatiter not implemented in JAX" )
134+
135+
136+ @jax_funcify .register (SearchsortedOp )
137+ def jax_funcify_SearchsortedOp (op , ** kwargs ):
138+ side = op .side
139+
140+ def searchsorted (x , v , side = side , sorter = None ):
141+ return jnp .searchsorted (x = x , v = v , side = side , sorter = sorter )
142+
143+ return searchsorted
Original file line number Diff line number Diff line change @@ -55,6 +55,14 @@ def test_extra_ops():
5555 fgraph , [get_test_value (i ) for i in fgraph .inputs ], must_be_device_array = False
5656 )
5757
58+ values = np .arange (10 )
59+ query = np .array (6 )
60+ out = pt_extra_ops .searchsorted (values , query )
61+ fgraph = FunctionGraph ([], out )
62+ compare_jax_and_py (
63+ fgraph , [get_test_value (i ) for i in fgraph .inputs ], must_be_device_array = False
64+ )
65+
5866
5967@pytest .mark .xfail (reason = "Jitted JAX does not support dynamic shapes" )
6068def test_bartlett_dynamic_shape ():
You can’t perform that action at this time.
0 commit comments