Hi! I'm trying to write a function that takes an array and reshapes a single axis of that array into a new shape. E.g., splitting axis 1 of an array of shape (5, 6, 7) into (2, 3) should give an array of shape (5, 2, 3, 7). I want to do this in a jitted function, but am encountering an error with this:

import jax
import jax.numpy as jnp
@jax.jit
def reshape_axis(a: jnp.ndarray, axis: int, new_shape: tuple):
    "Reshape a single axis into a new shape."
    new_shape = a.shape[:axis] + new_shape + a.shape[axis + 1 :]
    return a.reshape(new_shape)
a = jnp.zeros((5, 6, 7))
new_a = reshape_axis(a, 1, (2, 3))
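assert new_a.shape == (5, 2, 3, 7)

The problem is that I've tried to make the axis and new_shape arguments static, but then calling reshape_axis from inside another jitted function fails:

import jax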
import jax.numpy as jnp
from functools import partial
@partial(jax.jit, static_argnums=(1, 2), inline=True)
def reshape_axis(a: jnp.ndarray, axis: int, new_shape: tuple):
    "Reshape a single axis into a new shape."
    new_shape = a.shape[:axis] + new_shape + a.shape[axis + 1 :]
    return a.reshape(new_shape)
@jax.jit
def do_thing(a, axis, new_shape):
    return reshape_axis(a, axis, new_shape)
a = jnp.zeros((5, 6, 7))
assert do_thing(a, 1, (2, 3)).shape == (5, 2, 3, 7)

I could of course make the arguments static in do_thing as well.

BTW, I've been using JAX at work for over 6 months now, and it's truly a pleasure. Thank you!
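For concreteness, this is what splitting a single axis means in terms of indices, checked eagerly (no jit) on the (5, 6, 7) to (5, 2, 3, 7) example from the question. It is a small sketch; jnp.arange is used only so that every element has a distinct value. Because reshape uses row-major order, index j along the old axis 1 becomes the index pair (j // 3, j % 3) in the new axes 1 and 2.

```python
import jax.numpy as jnp

# Distinct values so we can see where each element ends up.
a = jnp.arange(5 * 6 * 7).reshape(5, 6, 7)

# Split axis 1 (length 6) into (2, 3); axes 0 and 2 are untouched.
new_a = a.reshape(a.shape[:1] + (2, 3) + a.shape[2:])
assert new_a.shape == (5, 2, 3, 7)

# Row-major order: old index j on axis 1 maps to (j // 3, j % 3).
for j in range(6):
    assert (new_a[:, j // 3, j % 3, :] == a[:, j, :]).all()
```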
Hi - I think you've already landed on the correct solution: you need the static arguments to be marked static at the call-site. You could do this by having the same static_argnums annotation for do_thing that you have for reshape_axis.
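A minimal sketch of that suggestion, reusing reshape_axis from the question. The only change is that do_thing also gets static_argnums=(1, 2), so axis and new_shape reach reshape_axis as ordinary Python values instead of tracers. The question's inline=True is dropped here; it only controls whether the inner call is inlined into the enclosing jaxpr and does not affect how static arguments are handled.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1, 2))
def reshape_axis(a: jnp.ndarray, axis: int, new_shape: tuple):
    "Reshape a single axis into a new shape."
    new_shape = a.shape[:axis] + new_shape + a.shape[axis + 1 :]
    return a.reshape(new_shape)


# Same static_argnums as reshape_axis: axis and new_shape stay concrete
# Python values inside do_thing, so they can be forwarded as static args.
@partial(jax.jit, static_argnums=(1, 2))
def do_thing(a, axis, new_shape):
    return reshape_axis(a, axis, new_shape)


a = jnp.zeros((5, 6, 7))
assert do_thing(a, 1, (2, 3)).shape == (5, 2, 3, 7)
```

Two things to keep in mind with this approach: static arguments must be hashable (so new_shape should be a tuple, not a list), and each new combination of static values triggers a fresh compilation of do_thing.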