Selecting arrays by index in batched computations #21010
I have an application (a sparse mixture-of-experts model) where I need to select a matrix from a stack of matrices to use in a batched matrix multiplication. For example: I have a batch of inputs xs and, for each input, an index into a stack of weight matrices. Basically my question is: is there some way to convert the vmapped implementation below so that it does not instantiate the large gathered weight array?

Example code:

import jax
import jax.numpy as jnp
weights = jnp.zeros((8, 128, 512))
xs = jnp.zeros((256, 128))
idx = jnp.arange(len(xs)) % len(weights)
def broadcast(weights, xs, idx):
    # Gather one weight matrix per input, then do a single batched matmul.
    return jnp.squeeze(xs[:, None] @ weights[idx], -2)

def vmapped(weights, xs, idx):
    # Map over the batch; each input indexes its own weight matrix.
    return jax.vmap(lambda xi, i: xi @ weights[i])(xs, idx)

def scanned(weights, xs, idx):
    # Sequential version via jax.lax.map (implemented with lax.scan).
    return jax.lax.map(lambda a: a[0] @ weights[a[1]], (xs, idx))
print("<===== broadcasted =====>")
print(jax.make_jaxpr(broadcast)(weights, xs, idx))
print()
print("<======= vmapped =======>")
print(jax.make_jaxpr(vmapped)(weights, xs, idx))
print()
print("<======= scanned =======>")
print(jax.make_jaxpr(scanned)(weights, xs, idx))
print()

Output jaxpr:
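For concreteness, the shapes involved look like this (the large intermediate in question would be the gathered per-input weight stack; this is an illustrative sketch using the arrays defined above):

print(weights[idx].shape)                  # (256, 128, 512) -- one weight matrix gathered per input
print(broadcast(weights, xs, idx).shape)   # (256, 512) -- the final result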
You may be able to rely on the compiler to fuse operations. Even if the jaxpr indicates an intermediate value of a particular shape, it doesn't necessarily mean that the compiled operation will instantiate that intermediate value. For example, here's the compiled HLO produced by your broadcast function on a T4 GPU:

print(jax.jit(broadcast).lower(weights, xs, idx).compile().as_text())

I believe this means the reduction will be fused, so that the implied large intermediate array is never actually instantiated.
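One way to check how this plays out in practice is a rough micro-benchmark of the three jitted variants (a sketch only; timings and fusion decisions depend on the backend and the JAX/XLA version):

import timeit

for name, fn in [("broadcast", broadcast), ("vmapped", vmapped), ("scanned", scanned)]:
    jitted = jax.jit(fn)
    jitted(weights, xs, idx).block_until_ready()  # compile and warm up outside the timing loop
    t = timeit.timeit(lambda: jitted(weights, xs, idx).block_until_ready(), number=100)
    print(f"{name}: {t / 100 * 1e6:.1f} us per call")

# Where available, jax.jit(fn).lower(weights, xs, idx).compile().memory_analysis()
# also reports buffer sizes, which can indicate whether the large gathered
# intermediate is ever materialized (support varies by backend).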