lax.scan with generator #8503
-
I was naively hoping that we could pass a generator as the state to `lax.scan`. A minimal example which captures the error is reproduced below:

```python
import jax
import jax.numpy as jnp

X = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
batch_size = 2

def sampler(X, batch_size=2):
    for i in range(len(X)//batch_size):
        yield X[i*batch_size: (i+1)*batch_size, :]

def inner_update(carry, t):
    sampler = carry
    return sampler, jnp.max(next(sampler))

jax.lax.scan(inner_update, sampler(X, batch_size), jnp.arange(len(X)//batch_size))
```

```
TypeError: Value <generator object sampler at 0x7f0d1826ccd0> with type <class 'generator'> is not a valid JAX type
```
-
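No, `lax.scan` does not accept generator expressions, mainly because JAX code is lowered to XLA and XLA does not have any concept of generator expressions. You can accomplish roughly the same thing by putting the generator expression logic in the update function:

```python
import jax
import jax.numpy as jnp

X = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
batch_size = 2

def inner_update(carry, t):
    # The carry holds the current batch index and the full data array
    i, X = carry
    # Take rows [i*batch_size, (i+1)*batch_size); the slice size must be static
    batch = jax.lax.dynamic_slice(X, (i * batch_size, 0), (batch_size, X.shape[1]))
    return (i + 1, X), jnp.max(batch)

_, out = jax.lax.scan(inner_update, (0, X), jnp.arange(len(X)//batch_size))
```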
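When the whole dataset fits in memory, another option is to keep the question's `sampler` generator but drain it in plain Python *before* calling `lax.scan`, stacking the yielded batches along a new leading axis. `lax.scan` then slices that axis for you and hands one batch to each step, so no manual indexing is needed. This is a minimal sketch under that in-memory assumption; `step` is an illustrative name, and the stacked array is equivalent to `X.reshape(-1, batch_size, X.shape[1])` when `len(X)` is divisible by `batch_size`.

```python
import jax
import jax.numpy as jnp

X = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
batch_size = 2

def sampler(X, batch_size=2):
    # The generator from the question, unchanged
    for i in range(len(X) // batch_size):
        yield X[i * batch_size:(i + 1) * batch_size, :]

# Run the generator in ordinary Python, outside of any traced code,
# and stack its batches: shape (5, 2, 2)
batches = jnp.stack(list(sampler(X, batch_size)))

def step(carry, batch):
    # scan passes batches[k] as `batch` on step k; no carry is needed here
    return carry, jnp.max(batch)

_, out = jax.lax.scan(step, None, batches)
```

The trade-off is that every batch is materialized up front, which does not work for streaming or unbounded data sources.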
-
Hey @jakevdp! I've been thinking about this problem: if the generator is a TF Dataset or a PyTorch DataLoader, what would be the recommended solution?

You could accumulate N batches into a "super batch" and use scan on that, but it might not pay off; it depends a lot on how fast you can generate these super batches.
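A minimal sketch of the super-batch idea, assuming the loader yields NumPy arrays of a fixed shape. `python_batches` is a hypothetical stand-in for a TF Dataset or PyTorch DataLoader, and `super_batches` and `run_super_batch` are illustrative names, not part of any library. The Python loop pulls N batches from the iterator, stacks them, and hands the stack to a jitted `lax.scan`; the per-step "update" is a placeholder for a real training step.

```python
import numpy as np
import jax
import jax.numpy as jnp

def python_batches(n_total, batch_size, dim, seed=0):
    # Hypothetical stand-in for a TF Dataset / PyTorch DataLoader:
    # any Python iterator yielding arrays of shape (batch_size, dim)
    rng = np.random.default_rng(seed)
    for _ in range(n_total):
        yield rng.standard_normal((batch_size, dim)).astype(np.float32)

def super_batches(iterator, n):
    # Group n consecutive batches into one array of shape (n, batch_size, dim).
    # A trailing group with fewer than n batches is dropped, so every call to
    # the jitted function sees the same shape and it compiles only once.
    buf = []
    for b in iterator:
        buf.append(b)
        if len(buf) == n:
            yield np.stack(buf)
            buf = []

@jax.jit
def run_super_batch(params, stacked):
    def step(p, batch):
        # Placeholder update; a real loop would compute gradients here
        return p + jnp.mean(batch), jnp.max(batch)
    return jax.lax.scan(step, params, stacked)

params = jnp.zeros(())
maxes = []
for sb in super_batches(python_batches(12, 2, 3), n=4):
    params, m = run_super_batch(params, sb)
    maxes.append(m)
```

Larger N means fewer Python-to-device round trips per batch, but the loader must produce N batches before each scan can start, which is the trade-off the reply above points out.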