|
2 | 2 | from difflib import get_close_matches |
3 | 3 | from typing import Literal, get_args |
4 | 4 |
|
5 | | -from pytensor.compile.builders import OpFromGraph |
6 | 5 | from pytensor.tensor import TensorLike |
7 | 6 | from pytensor.tensor.basic import as_tensor_variable, switch |
8 | | -from pytensor.tensor.blockwise import Blockwise |
9 | 7 | from pytensor.tensor.extra_ops import searchsorted |
| 8 | +from pytensor.tensor.functional import vectorize |
10 | 9 | from pytensor.tensor.math import clip, eq, le |
11 | 10 | from pytensor.tensor.sort import argsort |
12 | | -from pytensor.tensor.type import scalar |
13 | 11 |
|
14 | 12 |
|
15 | 13 | InterpolationMethod = Literal["linear", "nearest", "first", "last", "mean"] |
@@ -122,41 +120,41 @@ def interpolate1d( |
122 | 120 | else: |
123 | 121 | right_pad = as_tensor_variable(right_pad) |
124 | 122 |
|
125 | | - x_hat = scalar("x_hat", dtype=x.dtype) |
126 | | - idx = searchsorted(x, x_hat) |
127 | | - |
128 | | - if x.ndim != 1 or y.ndim != 1: |
129 | | - raise ValueError("Inputs must be 1d") |
130 | | - |
131 | | - if method == "linear": |
132 | | - y_hat = _linear_interp1d( |
133 | | - x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate |
134 | | - ) |
135 | | - elif method == "nearest": |
136 | | - y_hat = _nearest_neighbor_interp1d( |
137 | | - x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate |
138 | | - ) |
139 | | - elif method == "first": |
140 | | - y_hat = _stepwise_first_interp1d( |
141 | | - x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate |
142 | | - ) |
143 | | - elif method == "mean": |
144 | | - y_hat = _stepwise_mean_interp1d( |
145 | | - x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate |
146 | | - ) |
147 | | - elif method == "last": |
148 | | - y_hat = _stepwise_last_interp1d( |
149 | | - x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate |
150 | | - ) |
151 | | - else: |
152 | | - raise NotImplementedError( |
153 | | - f"Unknown interpolation method: {method}. " |
154 | | - f"Did you mean {get_close_matches(method, valid_methods)}?" |
155 | | - ) |
156 | | - |
157 | | - return Blockwise( |
158 | | - OpFromGraph(inputs=[x_hat], outputs=[y_hat], inline=False), signature="()->()" |
159 | | - ) |
| 123 | + def _scalar_interpolate1d(x_hat): |
| 124 | + idx = searchsorted(x, x_hat) |
| 125 | + |
| 126 | + if x.ndim != 1 or y.ndim != 1: |
| 127 | + raise ValueError("Inputs must be 1d") |
| 128 | + |
| 129 | + if method == "linear": |
| 130 | + y_hat = _linear_interp1d( |
| 131 | + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate |
| 132 | + ) |
| 133 | + elif method == "nearest": |
| 134 | + y_hat = _nearest_neighbor_interp1d( |
| 135 | + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate |
| 136 | + ) |
| 137 | + elif method == "first": |
| 138 | + y_hat = _stepwise_first_interp1d( |
| 139 | + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate |
| 140 | + ) |
| 141 | + elif method == "mean": |
| 142 | + y_hat = _stepwise_mean_interp1d( |
| 143 | + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate |
| 144 | + ) |
| 145 | + elif method == "last": |
| 146 | + y_hat = _stepwise_last_interp1d( |
| 147 | + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate |
| 148 | + ) |
| 149 | + else: |
| 150 | + raise NotImplementedError( |
| 151 | + f"Unknown interpolation method: {method}. " |
| 152 | + f"Did you mean {get_close_matches(method, valid_methods)}?" |
| 153 | + ) |
| 154 | + |
| 155 | + return y_hat |
| 156 | + |
| 157 | + return vectorize(_scalar_interpolate1d, signature="()->()") |
160 | 158 |
|
161 | 159 |
|
162 | 160 | def interp(x, xp, fp, left=None, right=None, period=None): |
@@ -191,7 +189,12 @@ def interp(x, xp, fp, left=None, right=None, period=None): |
191 | 189 | The interpolated values, same shape as `x`. |
192 | 190 | """ |
193 | 191 |
|
| 192 | + xp = as_tensor_variable(xp) |
| 193 | + fp = as_tensor_variable(fp) |
| 194 | + x = as_tensor_variable(x) |
| 195 | + |
194 | 196 | f = interpolate1d( |
195 | 197 | xp, fp, method="linear", left_pad=left, right_pad=right, extrapolate=False |
196 | 198 | ) |
| 199 | + |
197 | 200 | return f(x) |
0 commit comments