JIT dynamic indexing for "reverse padding" #15917
-
Years ago I discussed the ability to do large-scale filtering using FFTs (#5227). I'm finally getting back around to that, but now I'm running into a separate issue when trying to `jit` my filter function. To avoid artifacts from the circular convolution, I need to pad my input arrays (easy!). Then I perform my convolution, and then I need to "unpad" the result to preserve my input shape (just like SciPy does). This is where I'm stuck: indexing down into a smaller array requires a … Am I overthinking this? Is there already a JAX-canonical way to "unpad" an array? Thanks!
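To illustrate why the padding step matters, here is a minimal sketch assuming 1-D real-valued inputs. The helper name `fft_conv` is hypothetical. Without padding, the FFT product gives a circular convolution in which the tail of the result wraps onto the start. Padding to `len(x) + len(k) - 1` makes the circular result equal the linear ("full") convolution.

```python
import jax.numpy as jnp


def fft_conv(x, k, n):
    # rfft(..., n=n) zero-pads (or truncates) each input to length n before
    # transforming, so the product corresponds to a length-n circular convolution.
    return jnp.fft.irfft(jnp.fft.rfft(x, n=n) * jnp.fft.rfft(k, n=n), n=n)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
k = jnp.array([0.0, 1.0, 0.5])

# No extra padding: the last two outputs of the linear convolution wrap around.
circular = fft_conv(x, k, n=x.shape[0])
# Pad to len(x) + len(k) - 1: the circular convolution equals the linear one.
linear = fft_conv(x, k, n=x.shape[0] + k.shape[0] - 1)

print(circular)  # approx. [5.5, 3.0, 2.5, 4.0]
print(linear)    # approx. [0.0, 1.0, 2.5, 4.0, 5.5, 2.0]
```

The padded result is longer than the input, which is why it has to be cropped back ("unpadded") afterwards.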
-
The way to do this would be to use `jax.lax.dynamic_slice`, which is JIT-compatible as long as the shape of the output array is static. For example:

```python
from functools import partial

import jax
import jax.numpy as jnp


# newshape must be static (hashable) so the output shape is known at trace time.
@partial(jax.jit, static_argnames=['newshape'])
def _centered(arr, newshape):
    assert len(newshape) == arr.ndim
    # Start index of the centered window along each axis.
    startind = [(s1 - s2) // 2 for s1, s2 in zip(arr.shape, newshape)]
    # Slice sizes are static; only the start indices may be traced values.
    return jax.lax.dynamic_slice(arr, startind, newshape)


x = jnp.arange(24).reshape(2, 3, 4)
print(_centered(x, (2, 2, 2)))
# [[[ 1  2]
#   [ 5  6]]
#
#  [[13 14]
#   [17 18]]]
```
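Putting the pad, convolve, and unpad steps together, here is a sketch of a jitted FFT convolution whose output has the same shape as `x` (in the spirit of SciPy's `mode="same"`). It assumes 1-D real-valued inputs. The name `fft_convolve_same` is hypothetical, and the crop uses the same `(s1 - s2) // 2` centering rule as `_centered`, redefined here so the block is self-contained.

```python
import jax
import jax.numpy as jnp


def _centered(arr, newshape):
    # Start index of the centered window along each axis. These are Python
    # ints computed from static shapes, so nothing here is traced.
    startind = [(s1 - s2) // 2 for s1, s2 in zip(arr.shape, newshape)]
    # Slice sizes must be static; start indices could also be traced values.
    return jax.lax.dynamic_slice(arr, startind, newshape)


@jax.jit
def fft_convolve_same(x, k):
    # Full linear-convolution length; a static int under jit.
    n = x.shape[0] + k.shape[0] - 1
    # Zero-pad both inputs so the FFT product is a linear, not circular, convolution.
    x_p = jnp.pad(x, (0, n - x.shape[0]))
    k_p = jnp.pad(k, (0, n - k.shape[0]))
    full = jnp.fft.irfft(jnp.fft.rfft(x_p) * jnp.fft.rfft(k_p), n=n)
    # "Unpad": crop the length-n result back to the input shape.
    return _centered(full, x.shape)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
k = jnp.array([0.0, 1.0, 0.5])
out = fft_convolve_same(x, k)
print(out)  # same shape as x
```

Because the start indices here come from static shapes, plain basic slicing (`full[start:start + n]` with Python ints) would also trace fine under `jit`. `dynamic_slice` becomes necessary when the start index is itself a traced value.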