Skip to content

Commit d2eb992

Browse files
Add jax dispatch for searchsorted
1 parent 6d3a2a4 commit d2eb992

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

pytensor/link/jax/dispatch/extra_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
FillDiagonalOffset,
1111
RavelMultiIndex,
1212
Repeat,
13+
SearchsortedOp,
1314
Unique,
1415
UnravelIndex,
1516
)
@@ -130,3 +131,13 @@ def jax_funcify_FillDiagonalOffset(op, **kwargs):
130131
# return filldiagonaloffset
131132

132133
raise NotImplementedError("flatiter not implemented in JAX")
134+
135+
136+
@jax_funcify.register(SearchsortedOp)
137+
def jax_funcify_SearchsortedOp(op, **kwargs):
138+
side = op.side
139+
140+
def searchsorted(x, v, side=side, sorter=None):
141+
return jnp.searchsorted(x=x, v=v, side=side, sorter=sorter)
142+
143+
return searchsorted

tests/link/jax/test_extra_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ def test_extra_ops():
5555
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
5656
)
5757

58+
values = np.arange(10)
59+
query = np.array(6)
60+
out = pt_extra_ops.searchsorted(values, query)
61+
fgraph = FunctionGraph([], out)
62+
compare_jax_and_py(
63+
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
64+
)
65+
5866

5967
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
6068
def test_bartlett_dynamic_shape():

tests/tensor/test_interpolate.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
InterpolationMethod,
99
interp,
1010
interpolate1d,
11+
polynomial_interpolate1d,
1112
valid_methods,
1213
)
1314

@@ -105,3 +106,20 @@ def test_interpolate_scalar_extrapolate(method: InterpolationMethod):
105106
# and last should take the right.
106107
interior_point = x[3] + 0.1
107108
assert f(interior_point) == (y[4] if method == "last" else y[3])
109+
110+
111+
def test_polynomial_interpolate1d():
112+
x = np.linspace(-2, 6, 10)
113+
y = np.sin(x)
114+
115+
f_op = polynomial_interpolate1d(x, y)
116+
x_hat_pt = pt.dvector("x_hat")
117+
degree = pt.iscalar("degree")
118+
119+
f = pytensor.function(
120+
[x_hat_pt, degree], f_op(x_hat_pt, degree, True), mode="FAST_RUN"
121+
)
122+
x_grid = np.linspace(-2, 6, 100)
123+
y_hat = f(x_grid, 0)
124+
125+
assert_allclose(y_hat, np.mean(y))

0 commit comments

Comments
 (0)