Skip to content
Discussion options

You must be logged in to vote

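You can accomplish this by nesting two scans: an inner scan runs `store_every` steps and keeps only the output of its last step, and an outer scan iterates over those blocks. Here's a simple example: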

import jax.numpy as jnp
from jax import lax
import operator

def scan(f, init, xs=None, reverse=False, unroll=1, store_every=1, length=None):
    """Like ``lax.scan``, but keep only every ``store_every``-th output.

    The scan is split into ``n_blocks = n_steps // store_every`` blocks of
    ``store_every`` steps each. An inner ``lax.scan`` runs one block and keeps
    only the output at the block's last index. An outer ``lax.scan`` iterates
    over the blocks.

    The result matches ``carry, ys = lax.scan(...)`` followed by
    ``ys[store_every - 1::store_every]`` on every leaf of ``ys``. This also
    holds for ``reverse=True``. The difference is that the discarded
    intermediate outputs are never stacked.

    Args:
      f: Step function ``(carry, x) -> (carry, y)``, as for ``lax.scan``.
      init: Initial carry.
      xs: Pytree of arrays scanned along their leading axis, or ``None``.
      reverse: Run the scan from the end, as for ``lax.scan``.
      unroll: Forwarded to both the inner and the outer ``lax.scan``.
      store_every: Keep one output out of this many; must be a positive int.
      length: Number of scan steps. Required when ``xs`` is ``None`` and
        ``store_every > 1``; otherwise it must agree with ``xs``.

    Returns:
      ``(carry, ys)``, where each leaf of ``ys`` has leading dimension
      ``n_steps // store_every``.

    Raises:
      ValueError: If ``store_every`` is not positive; if the number of steps
        cannot be determined or is inconsistent between ``xs`` and
        ``length``; or if ``store_every`` does not evenly divide it.
    """
    from jax import tree_util

    store_every = operator.index(store_every)
    # A real exception rather than `assert`, so the check survives `python -O`.
    if store_every <= 0:
        raise ValueError(f"store_every must be positive, got {store_every}")

    kwds = dict(reverse=reverse, unroll=unroll)

    if store_every == 1:
        return lax.scan(f, init, xs=xs, length=length, **kwds)

    # xs may be an arbitrary pytree (or None), so read the step count from
    # the leading dimension of its leaves instead of calling len(xs).
    leaf_lengths = {jnp.shape(leaf)[0] for leaf in tree_util.tree_leaves(xs)}
    if len(leaf_lengths) > 1:
        raise ValueError(
            f"xs leaves have different leading dimensions: {sorted(leaf_lengths)}"
        )
    if leaf_lengths:
        (n_steps,) = leaf_lengths
        if length is not None and operator.index(length) != n_steps:
            raise ValueError(
                f"length={length} does not match the leading dimension of xs ({n_steps})"
            )
    elif length is None:
        raise ValueError("length must be given when xs is None and store_every > 1")
    else:
        n_steps = operator.index(length)

    n_blocks, rem = divmod(n_steps, store_every)
    if rem:
        raise ValueError("store_every must evenly divide len(xs)")

    # Reshape every leaf to (n_blocks, store_every, ...) so the outer scan hands
    # the inner scan one block at a time. A None xs stays None.
    blocks = tree_util.tree_map(
        lambda x: jnp.reshape(x, (n_blocks, store_every) + jnp.shape(x)[1:]), xs
    )

    def f_outer(carry, block):
        carry, ys = lax.scan(f, carry, xs=block, length=store_every, **kwds)
        # lax.scan aligns ys[i] with block[i] even when reverse=True, so [-1]
        # is the output at the block's last index. tree_map supports pytree
        # (or None) outputs from f.
        return carry, tree_util.tree_map(lambda y: y[-1], ys)

    return lax.scan(f_outer, init, xs=blocks, length=n_blocks, **kwds)

Replies: 1 comment 1 reply

Comment options

You must be logged in to vote
1 reply
@gerdm
Comment options

Answer selected by gerdm
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
None yet
2 participants