Skip to content

Commit 4af40c4

Browse files
Add jax dispatch for searchsorted
1 parent 6d3a2a4 commit 4af40c4

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

pytensor/link/jax/dispatch/extra_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
FillDiagonalOffset,
1111
RavelMultiIndex,
1212
Repeat,
13+
SearchsortedOp,
1314
Unique,
1415
UnravelIndex,
1516
)
@@ -130,3 +131,13 @@ def jax_funcify_FillDiagonalOffset(op, **kwargs):
130131
# return filldiagonaloffset
131132

132133
raise NotImplementedError("flatiter not implemented in JAX")
134+
135+
136+
@jax_funcify.register(SearchsortedOp)
137+
def jax_funcify_SearchsortedOp(op, **kwargs):
138+
side = op.side
139+
140+
def searchsorted(x, v, side=side, sorter=None):
141+
return jnp.searchsorted(x=x, v=v, side=side, sorter=sorter)
142+
143+
return searchsorted

tests/link/jax/test_extra_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ def test_extra_ops():
5555
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
5656
)
5757

58+
values = np.arange(10)
59+
query = np.array(6)
60+
out = pt_extra_ops.searchsorted(values, query)
61+
fgraph = FunctionGraph([], out)
62+
compare_jax_and_py(
63+
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
64+
)
65+
5866

5967
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
6068
def test_bartlett_dynamic_shape():

0 commit comments

Comments
 (0)