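To add on, I think it would be ergonomic to include an `in_axes` argument, e.g.:

```python
import jax
import jax.numpy as jnp

theta = jnp.ones((3, 3))  # parameters, shared across all steps
x = jnp.ones((10, 3))     # sequence of 10 inputs
h = jnp.zeros((3,))       # initial carry

def f(carry, inputs):
    theta, x = inputs
    h = carry  # the carry is the single array passed as init
    return h + theta @ x, None

# Proposed API: sequence-map (smap?) over x but not params theta
scanf = jax.scan(f, in_axes=(None, 0))
scanf(init=h, xs=(theta, x))
```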
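For comparison, the same recurrence can already be expressed with the existing `jax.lax.scan` by closing over the unscanned parameters instead of passing them through `xs`. A minimal sketch, assuming the same shapes as above (`theta` fixed, `x` scanned along its leading axis):

```python
import jax
import jax.numpy as jnp

theta = jnp.ones((3, 3))  # fixed parameters, captured by closure (not scanned)
x = jnp.ones((10, 3))     # 10 inputs, scanned along axis 0
h = jnp.zeros((3,))       # initial carry

def f(h, x_t):
    # One recurrence step; theta comes from the enclosing scope.
    return h + theta @ x_t, None

h_final, _ = jax.lax.scan(f, h, x)
```

Closure works here; an `in_axes`-style argument would mainly make the split between scanned and unscanned inputs explicit at the call site.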
Is there a reason why `jax.lax.scan` does not transform functions? Having `jit` and `vmap` operate on functions is super useful and very readable. It is a shame that it doesn't work for `scan`. For example, ideally one could do something like:
Or even use a decorator.
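A decorator-style interface can be approximated in user code today. The sketch below assumes a hypothetical helper `scan_over` (not part of JAX) that wraps a step function `f(carry, x)` with `jax.lax.scan` and returns a function of `(init, xs)`:

```python
import jax
import jax.numpy as jnp

def scan_over(f):
    """Hypothetical decorator (not a JAX API): turn a step function
    f(carry, x) -> (carry, y) into a function (init, xs) -> (carry, ys)."""
    def scanned(init, xs, **scan_kwargs):
        # Extra keyword arguments (e.g. reverse=True, length=...) go to jax.lax.scan.
        return jax.lax.scan(f, init, xs, **scan_kwargs)
    return scanned

@scan_over
def cumsum_step(carry, x):
    # Running sum as the carry; emit each partial sum.
    carry = carry + x
    return carry, carry

total, running = cumsum_step(jnp.zeros(()), jnp.arange(5.0))
```

Because `scan_over` returns an ordinary Python function, the result can be composed with other transforms such as `jax.jit` like any other function.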
I suppose you can use `partial`, but it's less clean.
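The `partial` workaround binds the step function as the first argument of `jax.lax.scan`, leaving a callable that takes `(init, xs)`. A minimal sketch, with `step` as an illustrative running-product function:

```python
from functools import partial

import jax
import jax.numpy as jnp

def step(carry, x):
    # Running product as the carry; emit each intermediate value.
    carry = carry * x
    return carry, carry

# Bind the step function once; the result is called like a transformed function.
scan_step = partial(jax.lax.scan, step)
final, history = scan_step(jnp.ones(()), jnp.array([1.0, 2.0, 3.0, 4.0]))
```

This behaves the same as a dedicated transform; the difference is purely at the call site, where the binding is spelled out rather than hidden behind a decorator.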
This is somewhat related to #23487.