Skip to content

Commit 9313e6d

Browse files
typo in jax dispatch
1 parent 4af40c4 commit 9313e6d

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

tests/tensor/test_interpolate.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
InterpolationMethod,
99
interp,
1010
interpolate1d,
11+
polynomial_interpolate1d,
1112
valid_methods,
1213
)
1314

@@ -105,3 +106,20 @@ def test_interpolate_scalar_extrapolate(method: InterpolationMethod):
105106
# and last should take the right.
106107
interior_point = x[3] + 0.1
107108
assert f(interior_point) == (y[4] if method == "last" else y[3])
109+
110+
111+
def test_polynomial_interpolate1d():
112+
x = np.linspace(-2, 6, 10)
113+
y = np.sin(x)
114+
115+
f_op = polynomial_interpolate1d(x, y)
116+
x_hat_pt = pt.dvector("x_hat")
117+
degree = pt.iscalar("degree")
118+
119+
f = pytensor.function(
120+
[x_hat_pt, degree], f_op(x_hat_pt, degree, True), mode="FAST_RUN"
121+
)
122+
x_grid = np.linspace(-2, 6, 100)
123+
y_hat = f(x_grid, 0)
124+
125+
assert_allclose(y_hat, np.mean(y))

0 commit comments

Comments
 (0)