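Thanks for the question! You can pass a tuple of arrays as the `xs` argument to `lax.scan`, which slices both arrays along their leading axis in lockstep:

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.array([1, 2, 3])
ys = jnp.array([3, 1, 4])

def scanned_fun(_, pair):
    # at step i, pair is (xs[i], ys[i])
    x, y = pair
    return None, x + y

_, zs = lax.scan(scanned_fun, None, (xs, ys))
print(zs)  # [4 3 7]
```

We're ignoring the carry here and in effect just mapping the scalar addition `x + y` over the paired elements. WDYT?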
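If you do need the carry, the same tuple-of-arrays pattern still works. Here is a minimal sketch, assuming a running sum of the elementwise products as the example computation (chosen only for illustration). The initial carry is given an explicit dtype so that it matches the carry returned by the body:

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.array([1, 2, 3])
ys = jnp.array([3, 1, 4])

def step(carry, pair):
    # pair is one slice from each array: (xs[i], ys[i])
    x, y = pair
    new_carry = carry + x * y
    # emit the running total at every step
    return new_carry, new_carry

init = jnp.zeros((), dtype=xs.dtype)
final, running = lax.scan(step, init, (xs, ys))
# running totals: 3, 5, 17; final carry: 17
print(final, running)
```

The result matches `jnp.cumsum(xs * ys)`. `scan` only pays off over a vectorized op like that when each step genuinely depends on the previous carry.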
With `jax.lax.scan`, I call `scan(f, init, xs)`. What I'd like to do is have a function of the form `scan(f, init, xs, ys)`, where `f` is a two-argument function. So instead of scanning across a single list, I'd scan across the same index in two lists. It feels like this should be possible and jit-able in JAX, but I'm not sure how. Any ideas?
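The requested `scan(f, init, xs, ys)` form can be sketched as a thin wrapper around `lax.scan`. This assumes `f` takes `(carry, x, y)` and returns `(new_carry, out)` like an ordinary scan body. `scan2` and `pairwise_sums` are hypothetical names, not JAX APIs. The wrapper just bundles the two arrays into a tuple, so it traces and jits like any other `scan`. Both arrays must share the same leading-axis length:

```python
import jax
import jax.numpy as jnp
from jax import lax

def scan2(f, init, xs, ys):
    """Hypothetical helper: scan f(carry, x, y) over xs and ys in lockstep."""
    def body(carry, pair):
        x, y = pair
        return f(carry, x, y)
    # lax.scan accepts a pytree (here a tuple) for xs and slices every leaf
    # along axis 0, so step i sees (xs[i], ys[i]).
    return lax.scan(body, init, (xs, ys))

@jax.jit
def pairwise_sums(xs, ys):
    # carry is unused (None); emit x + y at each index
    _, zs = scan2(lambda c, x, y: (c, x + y), None, xs, ys)
    return zs

xs = jnp.array([1, 2, 3])
ys = jnp.array([3, 1, 4])
print(pairwise_sums(xs, ys))
```

Because `lax.scan` already accepts any pytree for `xs`, the same wrapper idea extends to three or more sequences by passing a longer tuple.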