Skip to content

Commit a0c02d0

Browse files
Refactor out OpFromGraph
1 parent 0e03119 commit a0c02d0

File tree

1 file changed

+41
-38
lines changed

1 file changed

+41
-38
lines changed

pytensor/tensor/interpolate.py

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@
22
from difflib import get_close_matches
33
from typing import Literal, get_args
44

5-
from pytensor.compile.builders import OpFromGraph
65
from pytensor.tensor import TensorLike
76
from pytensor.tensor.basic import as_tensor_variable, switch
8-
from pytensor.tensor.blockwise import Blockwise
97
from pytensor.tensor.extra_ops import searchsorted
8+
from pytensor.tensor.functional import vectorize
109
from pytensor.tensor.math import clip, eq, le
1110
from pytensor.tensor.sort import argsort
12-
from pytensor.tensor.type import scalar
1311

1412

1513
InterpolationMethod = Literal["linear", "nearest", "first", "last", "mean"]
@@ -122,41 +120,41 @@ def interpolate1d(
122120
else:
123121
right_pad = as_tensor_variable(right_pad)
124122

125-
x_hat = scalar("x_hat", dtype=x.dtype)
126-
idx = searchsorted(x, x_hat)
127-
128-
if x.ndim != 1 or y.ndim != 1:
129-
raise ValueError("Inputs must be 1d")
130-
131-
if method == "linear":
132-
y_hat = _linear_interp1d(
133-
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
134-
)
135-
elif method == "nearest":
136-
y_hat = _nearest_neighbor_interp1d(
137-
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
138-
)
139-
elif method == "first":
140-
y_hat = _stepwise_first_interp1d(
141-
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
142-
)
143-
elif method == "mean":
144-
y_hat = _stepwise_mean_interp1d(
145-
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
146-
)
147-
elif method == "last":
148-
y_hat = _stepwise_last_interp1d(
149-
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
150-
)
151-
else:
152-
raise NotImplementedError(
153-
f"Unknown interpolation method: {method}. "
154-
f"Did you mean {get_close_matches(method, valid_methods)}?"
155-
)
156-
157-
return Blockwise(
158-
OpFromGraph(inputs=[x_hat], outputs=[y_hat], inline=False), signature="()->()"
159-
)
123+
def _scalar_interpolate1d(x_hat):
124+
idx = searchsorted(x, x_hat)
125+
126+
if x.ndim != 1 or y.ndim != 1:
127+
raise ValueError("Inputs must be 1d")
128+
129+
if method == "linear":
130+
y_hat = _linear_interp1d(
131+
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
132+
)
133+
elif method == "nearest":
134+
y_hat = _nearest_neighbor_interp1d(
135+
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
136+
)
137+
elif method == "first":
138+
y_hat = _stepwise_first_interp1d(
139+
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
140+
)
141+
elif method == "mean":
142+
y_hat = _stepwise_mean_interp1d(
143+
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
144+
)
145+
elif method == "last":
146+
y_hat = _stepwise_last_interp1d(
147+
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
148+
)
149+
else:
150+
raise NotImplementedError(
151+
f"Unknown interpolation method: {method}. "
152+
f"Did you mean {get_close_matches(method, valid_methods)}?"
153+
)
154+
155+
return y_hat
156+
157+
return vectorize(_scalar_interpolate1d, signature="()->()")
160158

161159

162160
def interp(x, xp, fp, left=None, right=None, period=None):
@@ -191,7 +189,12 @@ def interp(x, xp, fp, left=None, right=None, period=None):
191189
The interpolated values, same shape as `x`.
192190
"""
193191

192+
xp = as_tensor_variable(xp)
193+
fp = as_tensor_variable(fp)
194+
x = as_tensor_variable(x)
195+
194196
f = interpolate1d(
195197
xp, fp, method="linear", left_pad=left, right_pad=right, extrapolate=False
196198
)
199+
197200
return f(x)

0 commit comments

Comments
 (0)