Great question! The general advice for something like this is that, under JIT, what you do at the high level shouldn't matter: the XLA compiler should be able to find the optimal route to computing what your code expresses.

That's the ideal, but in practice you can sometimes improve things by choosing a different high-level approach. In cases like this, microbenchmarks can be revealing, and it looks like a simple lax.concatenate, besides being shorter and easier to read, is 40-50% faster than a loop over index updates:

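import jax.numpy as jnp
from jax import jit, lax

# 90 one-dimensional arrays of increasing length (10 to 99 elements).
arrays = [jnp.arange(i, 2 * i) for i in range(10, 100)]

# Concatenate all inputs along axis 0 in a single XLA operation.
@jit
def f1(*arrays):
  return lax.concatenate(arrays, 0)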

@jit
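The second function is cut off above, so the following is only a minimal sketch of how such a comparison might be set up, not the original code. It assumes the "loop over index updates" means writing each input into a preallocated buffer with the `.at[...].set(...)` syntax; the names `concat_direct`, `concat_by_updates`, and `bench` are hypothetical and chosen just for this illustration.

```python
import timeit

import jax.numpy as jnp
from jax import jit, lax

arrays = [jnp.arange(i, 2 * i) for i in range(10, 100)]

@jit
def concat_direct(*arrays):
  # Single concatenate along axis 0, as in f1 above.
  return lax.concatenate(arrays, 0)

@jit
def concat_by_updates(*arrays):
  # Assumed stand-in for the truncated function: preallocate the output,
  # then copy each input into its slice with an index update.
  total = sum(a.shape[0] for a in arrays)  # shapes are static under jit
  out = jnp.zeros(total, dtype=arrays[0].dtype)
  start = 0
  for a in arrays:
    stop = start + a.shape[0]
    out = out.at[start:stop].set(a)
    start = stop
  return out

# Both approaches should produce the same array.
expected = concat_direct(*arrays)
result = concat_by_updates(*arrays)
assert bool(jnp.array_equal(expected, result))

def bench(fn, repeats=100):
  # Hypothetical helper: the first call triggers compilation, so it is
  # excluded; block_until_ready() waits for asynchronous dispatch to finish.
  fn(*arrays).block_until_ready()
  total_time = timeit.timeit(
      lambda: fn(*arrays).block_until_ready(), number=repeats)
  return total_time / repeats

print(f"concatenate:   {bench(concat_direct) * 1e6:.1f} us per call")
print(f"index updates: {bench(concat_by_updates) * 1e6:.1f} us per call")
```

The timings will depend on the backend, the JAX version, and the array sizes, so the numbers are worth re-measuring on your own setup rather than taking any fixed ratio for granted.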

Answer selected by BugQualia