You just need to jit it, and XLA will then optimize the computation for you. (However, this optimization cannot be observed with jax.xla_computation, which shows the computation before XLA's optimization passes run.)

import jax
import jax.numpy as jnp

n = 200
d = 5000

a = jnp.ones((n, d), jnp.float32)

def f(x):
    # Sums the n per-row (d, d) outer products; naively this materializes
    # an (n, d, d) intermediate before the reduction.
    return jnp.sum(jax.vmap(jnp.outer)(x, x), axis=0)

print(f(a))           # fails (OOM) with n = 5000, d = 1000 or n = 200, d = 5000 (on a 16 GB V100)
print(jax.jit(f)(a))  # succeeds with n = 2000000, d = 1000 or n = 200, d = 60000 (on a 16 GB V100)
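
For intuition about what the fusion buys: the sum of per-row outer products is mathematically equivalent to a single matrix product, which needs no (n, d, d) buffer at all. This is a sketch of that equivalence, not a claim about the exact HLO XLA emits:

def f_matmul(x):
    # sum_i outer(x[i], x[i]) == x.T @ x, so one (d, d)-sized matmul
    # reproduces f without any (n, d, d) intermediate.
    return x.T @ x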

jax.lax.reduce's semantics are just a reduction over its operands, so without optimization the large intermediate array will not be eliminated.
jax.lax.scan has map-reduce semantics, but optimization may break that in favor of higher parallelism when the intermediate array has to be stored (for examp…
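
As a minimal sketch of that map-reduce formulation (the function names are my own, not from the original), the same sum of outer products can be written with jax.lax.scan so that only the (d, d) accumulator and one row's outer product are live at each step:

def f_scan(x):
    d = x.shape[1]
    def step(acc, row):
        # Fold one row's outer product into the running (d, d) sum;
        # no per-step output is collected, hence the None.
        return acc + jnp.outer(row, row), None
    acc, _ = jax.lax.scan(step, jnp.zeros((d, d), x.dtype), x)
    return acc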
