Thanks for the help Jake! I benchmarked all the solutions and the results are pretty interesting. It seems that flip2 in my initial question is the fastest by a decent margin across TPU, CPU, and GPU for the size of arrays I am interested in (roughly 10,000 x 100). Here is the script I used in case it is useful to anyone:

import jax
import jax.numpy as jnp
from jax import jit, vmap
import numpy as onp

# The two lines below are only needed when running in a Colab TPU runtime;
# skip them when benchmarking on CPU or GPU.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()


def flip1(x, n):
  # Reverse the first n rows and concatenate the untouched tail.
  # Note: x[:n] requires n to be concrete (e.g. a static argument) under jit.
  return jnp.concatenate([jnp.flip(x[:n], axis=0), x[n:]], axis=0)

def flip2(x, n):
  # Gather with a shifted, descending index vector: position i reads x[n - 1 - i].
  # For i < n this reverses the first n rows; for i >= n the indices go negative
  # and wrap around, so the tail rows come back reversed too (fine if they are
  # padding, but not identical to flip1's untouched tail).
  xlen = x.shape[0]
  inds = jnp.arange(xlen - 1, stop=-1, step=-1)
  inds -= (xlen - n)
  return x[inds]

def flip3(x, n):
  # Body cut off in the original post; a roll-based reconstruction (an
  # assumption, not the original code) that matches flip2's output:
  return jnp.roll(jnp.flip(x, axis=0), n - x.shape[0], axis=0)
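The timing portion of the script did not survive in the post above, so here is a minimal sketch of how the variants could be compared, assuming n is passed as a static argument so that flip1's x[:n] slice also compiles under jit. The array shape and repeat count are placeholders, not the original benchmark settings.

import timeit

x = jnp.asarray(onp.random.randn(10_000, 100))
n = 7_000  # arbitrary example value, not from the original benchmark

for name, fn in [("flip1", flip1), ("flip2", flip2), ("flip3", flip3)]:
  jitted = jit(fn, static_argnums=1)
  jitted(x, n).block_until_ready()  # compile once before timing
  t = timeit.timeit(lambda: jitted(x, n).block_until_ready(), number=100)
  print(f"{name}: {t / 100 * 1e3:.3f} ms per call")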
