-
Is there a way to save the output of the step function in `jax.lax.scan` only every `k` steps, e.g. something like `_, hist = jax.lax.scan(step_fn, init, xs, store_every=k)`? The only way I can think of doing this is to create an empty collection and append values to it by using […]. Thanks!
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
You could accomplish this by nesting two scans; here's a simple example: import jax.numpy as jnp
from jax import lax
import operator
def scan(f, init, xs=None, reverse=False, unroll=1, store_every=1, length=None):
    """Like ``lax.scan``, but keep only one stacked output every ``store_every`` steps.

    The steps are split into ``num_chunks = n_steps // store_every`` chunks.
    An outer ``lax.scan`` walks the chunks and an inner ``lax.scan`` walks the
    steps inside each chunk. Only one output per chunk is returned by the outer
    scan, so the returned history has ``n_steps // store_every`` entries.

    The kept output is the one produced by the last step of each chunk *in
    processing order*, so the output of the very last step is always kept
    (for ``reverse=True`` this is the step at the start of each chunk).

    Args:
      f: step function ``(carry, x) -> (carry, y)``, as for ``lax.scan``.
        ``y`` may be any pytree (including ``None``).
      init: initial carry.
      xs: pytree of arrays scanned along their leading axis, or ``None``.
      reverse: scan from the end of ``xs`` towards the start.
      unroll: forwarded to both the inner and the outer ``lax.scan``.
      store_every: keep one output every this many steps. Must be a positive
        integer that evenly divides the number of steps.
      length: number of steps. Required when ``xs`` is ``None``; if given
        together with ``xs``, it must match the leading dimension of ``xs``.

    Returns:
      ``(final_carry, ys)``. Every leaf of ``ys`` has leading dimension
      ``n_steps // store_every`` (or ``n_steps`` when ``store_every == 1``).

    Raises:
      TypeError: if ``store_every`` is not an integer.
      ValueError: if ``store_every`` is not positive, does not divide the
        number of steps, or the number of steps cannot be determined.
    """
    import jax  # local import: only jax.tree_util is needed here

    store_every = operator.index(store_every)
    # Raise instead of assert: asserts are stripped under ``python -O``.
    if store_every <= 0:
        raise ValueError(f"store_every must be a positive integer, got {store_every}")
    kwds = dict(reverse=reverse, unroll=unroll)
    if store_every == 1:
        return lax.scan(f, init, xs=xs, length=length, **kwds)

    # Work out the number of steps from xs (any pytree) or from ``length``.
    leaves = jax.tree_util.tree_leaves(xs)
    if leaves:
        n_steps = leaves[0].shape[0]
        if length is not None and length != n_steps:
            raise ValueError(
                f"length={length} does not match the leading dimension of xs ({n_steps})"
            )
    elif length is not None:
        n_steps = length
    else:
        raise ValueError("length must be given when xs is None or has no leaves")

    num_chunks, rem = divmod(n_steps, store_every)
    if rem:
        raise ValueError("store_every must evenly divide len(xs)")

    # Reshape every leaf from (n_steps, ...) to (num_chunks, store_every, ...).
    chunked_xs = jax.tree_util.tree_map(
        lambda x: x.reshape(num_chunks, store_every, *x.shape[1:]), xs
    )

    # lax.scan stacks outputs in xs order even when reverse=True, so the step
    # processed last within a chunk sits at index 0 when reversing, -1 otherwise.
    last = 0 if reverse else -1

    def f_outer(carry, chunk):
        carry, ys = lax.scan(f, carry, xs=chunk, length=store_every, **kwds)
        # Index the time axis of every output leaf (ys may be a tuple/dict/None).
        return carry, jax.tree_util.tree_map(lambda y: y[last], ys)

    return lax.scan(f_outer, init, xs=chunked_xs, length=num_chunks, **kwds)
# Test it with a simple accumulation:
def f(carry, x):
    """Running-sum step: the new total is both the next carry and the output."""
    running_total = x + carry
    return running_total, running_total
# With store_every=1 the wrapper defers straight to lax.scan, so every
# intermediate running sum is returned.
xs = jnp.arange(0, 50)
_, ys = scan(f, 0, xs, store_every=1)
print(ys)
# [   0    1    3    6   10   15   21   28   36   45   55   66   78   91
#   105  120  136  153  171  190  210  231  253  276  300  325  351  378
#   406  435  465  496  528  561  595  630  666  703  741  780  820  861
#   903  946  990 1035 1081 1128 1176 1225]

# With store_every=5, only the running sum at the end of each group of
# five steps is returned.
_, ys = scan(f, 0, xs, store_every=5)
print(ys)
# [  10   45  105  190  300  435  595  780  990 1225]

I'm not sure how efficient this approach is compared with, say, returning the full history and slicing it.
Beta Was this translation helpful? Give feedback.
You could accomplish this by nesting two scans; here's a simple example: