-
Great points. We should improve these error messages. (Actually this thread might work better as an Issue rather than a Discussion... not sure if we can move it somehow.) I think the key is that

import jax
import jax.numpy as jnp

init_val = (jnp.ones(2), jnp.zeros(2))

def f(i, a):
    return a

jax.lax.fori_loop(1, 5, f, init_val)

This snippet from the docstring is meant to convey the loop counter idea (notice how
Do you have suggestions for how we might improve the docstring? Maybe if we give an example of calling
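For what it's worth, one way a docstring could make the calling convention concrete is a plain-Python reference loop. This is only a sketch of the semantics, not the real implementation, and the name `fori_loop_reference` is made up for illustration. It shows that the body is always called with two arguments, the loop counter and the current carry:

```python
def fori_loop_reference(lower, upper, body_fun, init_val):
    """Plain-Python sketch of the loop semantics (hypothetical helper)."""
    val = init_val
    for i in range(lower, upper):
        # The body receives the counter i AND the carry, and returns the new carry.
        val = body_fun(i, val)
    return val


# The carry here is a single int; the body adds the counter to it.
total = fori_loop_reference(1, 5, lambda i, a: a + i, 0)
```

Seen this way, `init_val` is a single value (possibly a tuple), but the body still takes two positional arguments because the counter is passed separately.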
-
I'm trying to use fori_loop in a simple example, but running into a bunch of errors:
This gives me "TypeError: f() takes 1 positional argument but 2 were given", although the documentation says "init_val – initial loop carry value of type a.", implying a single input argument (here a tuple). I then changed f to take two input arguments.
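For reference, here is a minimal sketch of a two-argument body whose carry is a tuple. It reuses the `init_val` from this thread; the name `body_fun` and the arithmetic inside it are assumptions chosen only for illustration. The first argument is the loop counter, the second is the whole carry, and the return value has the same tuple structure as the carry:

```python
import jax
import jax.numpy as jnp

init_val = (jnp.ones(2), jnp.zeros(2))

def body_fun(i, carry):
    # i is the loop counter; carry is the full tuple passed as init_val.
    x, y = carry
    # Return a tuple with the same structure, shapes, and dtypes as the carry.
    return (x, y + i * x)

result = jax.lax.fori_loop(1, 5, body_fun, init_val)
```

The counter runs over 1, 2, 3, 4 here, so the second element of the carry accumulates `i * x` four times while the first element passes through unchanged.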
I get this error:
Changing it to return a tuple doesn't work either, nor does passing in init_val as (init_val,).
Extremely confused on how to properly use this function. I'm not sure how to interpret the body_fun type structure output in the error above. When I print the tree structure for init_val, I just have the following as expected.
Any help appreciated, thanks.
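One way to debug a structure mismatch like the one described above is to call the body once outside the loop and compare its output tree with the tree of `init_val`. This is a rough sketch, assuming the body can be called eagerly on concrete values; `good_body` and `bad_body` are hypothetical names, not code from the original post:

```python
import jax
import jax.numpy as jnp

init_val = (jnp.ones(2), jnp.zeros(2))

def good_body(i, carry):
    x, y = carry
    return (x, y + 1.0)  # same tuple structure as the carry

def bad_body(i, carry):
    x, y = carry
    return x + y  # a single array, not a 2-tuple

in_tree = jax.tree_util.tree_structure(init_val)

# Call each body once with a concrete counter and compare tree structures.
good_matches = jax.tree_util.tree_structure(good_body(0, init_val)) == in_tree
bad_matches = jax.tree_util.tree_structure(bad_body(0, init_val)) == in_tree
```

If the comparison is false for your body, the function is returning something whose structure differs from `init_val`, which is the kind of mismatch the loop complains about.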