Using XLA to construct mini-batches for "ragged" batches? #7618
Coming from PyTorch, I am used to manually constructing a batch of inputs and feeding it into my model at each training step. I am wondering whether XLA's Dead Code Elimination could let me skip that step: instead of first constructing a minibatch, I would feed the entire dataset into the model together with an array of mask values marking which entries belong to the current batch. If I apply such a mask, will XLA avoid doing the work for the masked-out entries?
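
(A minimal sketch of the setup being described, for concreteness — the names loss_fn, full_data, full_labels, and batch_mask are illustrative, not from the discussion: the whole dataset is fed in with a boolean mask, and masked-out rows are zeroed in the loss. The question is whether XLA would then skip the work for the masked rows, or still compute over the full array.)

import jax
import jax.numpy as jnp

def loss_fn(params, full_data, full_labels, batch_mask):
    # Predictions are computed for every row of the dataset...
    preds = full_data @ params                  # hypothetical linear model
    per_example = (preds - full_labels) ** 2
    # ...and rows outside the current "batch" are zeroed out by the mask.
    masked = jnp.where(batch_mask, per_example, 0.0)
    return masked.sum() / batch_mask.sum()

# Toy data, just to make the sketch runnable.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
full_data = jax.random.normal(k1, (1000, 8))
full_labels = jax.random.normal(k2, (1000,))
params = jax.random.normal(k3, (8,))

# Mark the first 32 rows as the current "batch"; everything else is masked out.
batch_mask = jnp.arange(1000) < 32

print(jax.jit(loss_fn)(params, full_data, full_labels, batch_mask))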
Replies: 1 comment 1 reply
Hi - thanks for the question! When you mention Dead Code Elimination, are you talking about eliminating computations over masked sections of the array? I'm not sure whether this is possible.

To strip this down and be concrete, I think essentially what you're asking is whether you can write a program like this:

import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(1701))
x = jax.random.uniform(key1, (10,))
# a random boolean mask marking which entries are "active"
mask = jax.random.randint(key2, (10,), 0, 2).astype(bool)

def f(x, mask):
    return jnp.where(mask, jnp.sin(x), 0)

f(x, mask)

and when running it depend on XLA not computing jnp.sin on the entries that will eventually be zeroed out. Is that correct?

I took a look at the HLO generated for this, and it doesn't look like XLA eliminates operations within an array like this:

print(jax.xla_computation(f)(x, mask).as_hlo_module().to_string())

It certainly looks like sine computations are being done on the full length-10 array. There may be a way to do this that I'm not aware of (hopefully someone else will chime in!) but I don't think XLA will do this automatically in the general case.
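
(One way to check this empirically — a sketch, not part of the original reply; the helper bench and the sizes n and k below are arbitrary — is to time the masked computation against an explicit slice of the unmasked entries. If XLA skipped the work on masked entries, the two would be expected to take similar time; consistent with the HLO above, the masked version is instead expected to scale with the full array size.)

import time
import jax
import jax.numpy as jnp

n, k = 10_000_000, 1_000   # full array size vs. number of unmasked entries

key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (n,))
mask = jnp.arange(n) < k   # only the first k entries are "active"

masked_f = jax.jit(lambda x, mask: jnp.where(mask, jnp.sin(x), 0.0))
sliced_f = jax.jit(lambda x: jnp.sin(x[:k]))   # does the work on k entries only

def bench(fn, *args):
    fn(*args).block_until_ready()              # compile and warm up
    start = time.perf_counter()
    for _ in range(10):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / 10

print("masked over full array:    ", bench(masked_f, x, mask))
print("explicit slice of k entries:", bench(sliced_f, x))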