Fastest way to flip first n elements of a vector #9410
Given a vector, I thought of two ways to do this, one which is [...]

import jax.numpy as jnp
from jax import jit

def flip1(x, n):
    return jnp.concatenate([jnp.flip(x[:n], axis=0), x[n:]], axis=0)

@jit
def flip2(x, n):
    xlen = x.shape[0]
    inds = jnp.arange(xlen - 1, stop=-1, step=-1)
    inds -= (xlen - n)
    return x[inds]

Are there any better or faster ways to do this? I need them to be [...]
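A quick sanity check on a small vector can help before timing anything. The sketch below reruns the two functions from the question on a five-element input. On that input, the index-based version also returns the tail in reverse order, while the slice-based version leaves the tail untouched. The third function, `flip_prefix`, is a hypothetical name for an illustrative variant (not from the original post) that mirrors only the indices below `n`.

```python
import jax.numpy as jnp
from jax import jit


def flip1(x, n):
    # Slice-based version: n must be a concrete Python int here.
    return jnp.concatenate([jnp.flip(x[:n], axis=0), x[n:]], axis=0)


@jit
def flip2(x, n):
    # Index-based version from the question; n may be a traced value.
    xlen = x.shape[0]
    inds = jnp.arange(xlen - 1, stop=-1, step=-1)
    inds -= (xlen - n)
    return x[inds]


@jit
def flip_prefix(x, n):
    # Hypothetical variant: mirror indices below n, keep the rest unchanged.
    inds = jnp.arange(x.shape[0])
    inds = jnp.where(inds < n, n - 1 - inds, inds)
    return x[inds]


x = jnp.arange(5)
out1 = flip1(x, 2)
out2 = flip2(x, 2)
out3 = flip_prefix(x, 2)
print(out1)  # [1 0 2 3 4]
print(out2)  # [1 0 4 3 2]
print(out3)  # [1 0 2 3 4]
```

If only the first `n` elements should change, the difference in the tail matters when comparing timings, since the functions are then not doing the same work.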
Replies: 2 comments
How about using jnp.roll?

@jit
def flip(x, n):
    return jnp.roll(x[::-1], n)

I believe it uses [...]

Also, for the first non-jit-compatible one, I'd use

def flip(x, n):
    return x.at[:n].set(x[n-1::-1])

I haven't benchmarked either of these; I was just going for code conciseness and readability.
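To see the two suggestions side by side, here is a minimal sketch on a small vector. The names `flip_roll` and `flip_at` are placeholders chosen for this example. Marking `n` as static with `static_argnames` is one possible way to let the slice-based version run under `jit`, on the assumption that recompiling for each distinct `n` is acceptable.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit


@jit
def flip_roll(x, n):
    # Reverse the whole vector, then rotate the reversed prefix to the front.
    return jnp.roll(x[::-1], n)


@partial(jit, static_argnames="n")
def flip_at(x, n):
    # n is static, so both slices have shapes known at trace time.
    # Each new value of n triggers a fresh compilation.
    return x.at[:n].set(x[n - 1::-1])


x = jnp.arange(5)
rolled = flip_roll(x, 2)
updated = flip_at(x, 2)
print(rolled)   # [1 0 4 3 2]
print(updated)  # [1 0 2 3 4]
```

On this input the roll-based version leaves the remaining elements in reversed order, whereas the `.at[].set()` version only touches the first `n` entries.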
Thanks for the help Jake! I benchmarked all the solutions and the results are pretty interesting. It seems that flip2 in my initial question is the fastest by a decent margin across TPU, CPU, and GPU for the size of arrays I am interested in (roughly 10,000 x 100). Here is the script I used in case it is useful to anyone:

import jax
import jax.numpy as jnp
from jax import jit, vmap
import numpy as onp

# Colab-only TPU setup; skip these two lines on CPU or GPU.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

def flip1(x, n):
    return jnp.concatenate([jnp.flip(x[:n], axis=0), x[n:]], axis=0)

def flip2(x, n):
    xlen = x.shape[0]
    inds = jnp.arange(xlen-1, stop=-1, step=-1)
    inds -= (xlen - n)
    return x[inds]

def flip3(x, n):
    # Note: with no axis argument, jnp.roll shifts the flattened array,
    # so on 2-D input this is not a roll by n rows.
    return jnp.roll(x[::-1], n)

x_np = onp.arange(1_000_000).reshape([10_000, 100])
print("Transfer time")
%time x_jax = jax.device_put(x_np)  # measure JAX device transfer time

jit_flip2 = jax.jit(flip2)
print("Compile time")
%time jit_flip2(x_jax, 100).block_until_ready()  # measure JAX compilation time
print("Execution time")
%timeit jit_flip2(x_jax, 100).block_until_ready()  # measure JAX runtime

jit_flip3 = jax.jit(flip3)
print("Compile time")
%time jit_flip3(x_jax, 100).block_until_ready()  # measure JAX compilation time
print("Execution time")
%timeit jit_flip3(x_jax, 100).block_until_ready()  # measure JAX runtime
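The script relies on IPython magics (`%time`, `%timeit`), so it only runs in a notebook. As a rough sketch, the same comparison can be done in plain Python with the standard `timeit` module. The `bench` helper and its `repeats` parameter are hypothetical names introduced for this example, and the numbers it prints will depend on the backend and hardware.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as onp


def flip2(x, n):
    xlen = x.shape[0]
    inds = jnp.arange(xlen - 1, stop=-1, step=-1)
    inds -= (xlen - n)
    return x[inds]


def flip3(x, n):
    # As posted: with no axis argument, jnp.roll shifts the flattened array.
    return jnp.roll(x[::-1], n)


def bench(fn, x, n, repeats=20):
    """Return mean seconds per call of the jitted fn, excluding compilation."""
    jitted = jax.jit(fn)
    jitted(x, n).block_until_ready()  # warm-up call triggers compilation
    total = timeit.timeit(
        lambda: jitted(x, n).block_until_ready(), number=repeats
    )
    return total / repeats


x = jax.device_put(onp.arange(1_000_000).reshape([10_000, 100]))
timings = {fn.__name__: bench(fn, x, 100) for fn in (flip2, flip3)}
for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e6:.1f} us per call on {jax.default_backend()}")
```

Calling `block_until_ready()` inside the timed lambda matters because JAX dispatches work asynchronously; without it the timer would mostly measure dispatch overhead.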