One way to do this is with `jax.lax.dynamic_slice`, which can be JIT-compiled as long as the shape of the output array is static. Here `newshape` is passed as a static argument, so it must be hashable (e.g. a tuple):

```python
from functools import partial
import jax
import jax.numpy as jnp

# newshape is static, so the output shape is known at trace time
@partial(jax.jit, static_argnames=['newshape'])
def _centered(arr, newshape):
  assert len(newshape) == arr.ndim
  # Offset of the centered window along each axis
  startind = [(s1 - s2) // 2 for s1, s2 in zip(arr.shape, newshape)]
  return jax.lax.dynamic_slice(arr, startind, newshape)

x = jnp.arange(24).reshape(2, 3, 4)
print(_centered(x, (2, 2, 2)))
# [[[ 1  2]
#   [ 5  6]]
#
#  [[13 14]
#   [17 18]]]
```
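The start indices given to `dynamic_slice` do not have to be static; only the slice size does. A minimal sketch of that, using a hypothetical `crop_at` helper (not part of the answer above) whose `start` offsets are a runtime array, so different offsets reuse the same compiled function for a given input shape and `size`. It also shows that `dynamic_slice` clamps out-of-range start indices so the window stays inside the array:

```python
from functools import partial
import jax
import jax.numpy as jnp

# Hypothetical helper: crop a window of static size `size` starting at a
# traced index vector `start`. Only `size` is static; `start` is a runtime array.
@partial(jax.jit, static_argnames=['size'])
def crop_at(arr, start, size):
  assert len(size) == arr.ndim
  # Split the traced index vector into one scalar index per axis
  start_indices = [start[d] for d in range(arr.ndim)]
  return jax.lax.dynamic_slice(arr, start_indices, size)

x = jnp.arange(24).reshape(2, 3, 4)
a = crop_at(x, jnp.array([0, 1, 2]), (2, 2, 2))
# Out-of-range starts are clamped so the window fits: [0, 2, 3] -> [0, 1, 2]
b = crop_at(x, jnp.array([0, 2, 3]), (2, 2, 2))
print(a)
print(bool((a == b).all()))  # True
```

Because of this clamping, a window requested partly outside the array silently shifts back inside rather than raising an error, which is worth keeping in mind if the offsets come from data.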

Answer selected by smartalecH