No, `lax.scan` does not accept generator expressions, mainly because JAX code is lowered to XLA, and XLA has no concept of generator expressions. The `xs` argument to `scan` must be an array (or pytree of arrays) that gets sliced along its leading axis. You can accomplish roughly the same thing by moving the generator-expression logic into the update function:

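```python
import jax
import jax.numpy as jnp

X = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
batch_size = 2

def inner_update(carry, t):
  # carry holds a batch counter i and the full data array X
  i, X = carry
  # slice rows [i * batch_size, (i + 1) * batch_size) out of X
  batch = jax.lax.dynamic_slice(X, (i * batch_size, 0), (batch_size, X.shape[1]))
  # advance the counter and emit this batch's max as the per-step output
  return (i + 1, X), jnp.max(batch)

# one scan step per batch; out has shape (len(X) // batch_size,)
_, out = jax.lax.scan(inner_update, (0, X), jnp.arange(len(X) // batch_size))
```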
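In the snippet above the scanned-over value `t` is unused, and the counter is carried separately. A minimal sketch of two variations follows, assuming the same `X` and `batch_size`; the names `n_batches`, `step_by_index`, and `step_by_batch` are hypothetical. The first uses `t` itself as the batch index, so nothing but `None` needs to be carried. The second reshapes `X` so that each leading-axis slice is already one batch, which is often the closest analogue to a generator that yields contiguous, equal-sized batches. Both are checked against an eager Python generator-expression reference.

```python
import jax
import jax.numpy as jnp

X = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
batch_size = 2
n_batches = X.shape[0] // batch_size  # hypothetical helper name

# Reference: the generator-expression version, evaluated eagerly in Python.
batches = (X[i:i + batch_size] for i in range(0, n_batches * batch_size, batch_size))
expected = jnp.stack([jnp.max(b) for b in batches])

# Variant 1: the scanned-over value t is the batch index, so the carry is
# just None and X is closed over instead of carried.
def step_by_index(carry, t):
    batch = jax.lax.dynamic_slice(X, (t * batch_size, 0), (batch_size, X.shape[1]))
    return carry, jnp.max(batch)

_, out_index = jax.lax.scan(step_by_index, None, jnp.arange(n_batches))

# Variant 2: reshape to (n_batches, batch_size, features); scan slices xs
# along axis 0, so each step receives one (batch_size, features) batch.
def step_by_batch(carry, batch):
    return carry, jnp.max(batch)

X_batched = X[: n_batches * batch_size].reshape(n_batches, batch_size, -1)
_, out_reshape = jax.lax.scan(step_by_batch, None, X_batched)
```

One design note: `jax.lax.dynamic_slice` clamps out-of-range start indices instead of raising, so an off-by-one in the index arithmetic silently yields an overlapping window. The reshape variant avoids index arithmetic entirely, at the cost of dropping any trailing rows that do not fill a whole batch.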

Answer selected by pharringtonp19

shoyer Nov 10, 2021
Collaborator

@cgarciae

shoyer Nov 10, 2021
Collaborator

@cgarciae

shoyer Nov 10, 2021
Collaborator

Category
Q&A
4 participants