float() argument TypeError in a scan
#19057
Sounds like an internal JAX bug! On that presumption, I'll convert this to a bug issue. Let's discuss it on #19059 rather than here.
I have a really weird JAX bug with scan. Has anyone else seen similar issues?
Here's a minimal reproducible example:
So far, so good. I've defined a function g that maps from ℂ^8 to ℝ (cast to complex for typing reasons), and used vjp to differentiate it (grad behaves a bit weirdly with these functions). Now I'd like to run many of these iterations with a scan.
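The original snippet did not survive on this page, so here is a minimal sketch of the pattern described. It is not the poster's code: the body of g, the starting point z0, the 0.1 step size, and the names grad_via_vjp and step are all hypothetical. The "weird" grad behavior is presumably that jax.grad rejects complex-valued outputs unless holomorphic=True, and a real-valued function cast to complex is not holomorphic. So the sketch pulls a complex cotangent of 1 back through jax.vjp instead, then repeats the update nine times with jax.lax.scan.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the poster's g: maps C^8 -> R (here, the squared
# norm), then casts the real result to complex64 so dtypes match the input.
def g(z):
    return jnp.sum(jnp.real(z) ** 2 + jnp.imag(z) ** 2).astype(jnp.complex64)

def grad_via_vjp(z):
    # vjp returns g(z) plus a function that pulls a cotangent back to z.
    _, g_vjp = jax.vjp(g, z)
    (grad_z,) = g_vjp(jnp.ones((), dtype=jnp.complex64))
    return grad_z

def step(z, _):
    # Hypothetical update: for a real-valued g, JAX returns the conjugate of
    # the ascent direction, so conj() it before taking a descent step.
    z_new = z - 0.1 * jnp.conj(grad_via_vjp(z))
    return z_new, z_new  # (next carry, per-iteration output to stack)

z0 = (jnp.arange(8) + 1j * jnp.arange(8)).astype(jnp.complex64)
grad0 = grad_via_vjp(z0)  # equals 2 * conj(z0) for the squared norm
_, out = jax.lax.scan(step, z0, xs=None, length=9)
print(out.shape, out.dtype)  # (9, 8) complex64
```

Each step here maps z to 0.8 * z, so the stacked iterates are easy to check by hand; the carry keeps the same shape and dtype on every iteration, which lax.scan requires.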
The following code works fine:
out has shape (9, 8) and dtype complex64, as desired.
Note that this code is essentially the same two lines that appear in _step, where their results aren't even returned.
However, the following code raises a TypeError:
It gives the error:
Notice that _step doesn't return anything related to its intermediate computations! Moreover, when I comment out one particular line in _step, the error goes away.
I'm really at a loss as to why this line is completely fine outside of a scan but crashes inside one, even though it is unrelated to both the input and the output of _step. I've really been struggling to debug inside the scan. If anyone has seen this type of error before, or even just has tips on how to debug inside a scan, I'd be super grateful to hear them.
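On the request for debugging tips, here is a sketch of two general techniques on a toy scan body (body and scan_as_loop are hypothetical names, not from the post). jax.debug.print prints runtime values from inside traced code such as a scan body, whereas a plain print() only shows abstract tracers, once, at trace time. Running the same body in an ordinary Python loop executes it eagerly, with concrete values and a normal traceback. If the loop version succeeds while lax.scan fails, the problem lies in how the body is traced rather than in the arithmetic itself.

```python
import jax
import jax.numpy as jnp

def body(carry, x):
    # Toy scan body: running product of the inputs.
    new_carry = carry * x
    # Prints concrete values at run time, once per iteration, even under scan.
    jax.debug.print("x={x} carry={c}", x=x, c=new_carry)
    return new_carry, new_carry

def scan_as_loop(f, init, xs):
    # Eager equivalent of jax.lax.scan(f, init, xs) for a 1-D xs:
    # no tracing, so errors point at the offending line with real values.
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)

xs = jnp.arange(1.0, 4.0)  # [1., 2., 3.]
init = jnp.array(1.0, dtype=jnp.float32)

carry_scan, ys_scan = jax.lax.scan(body, init, xs)
carry_loop, ys_loop = scan_as_loop(body, init, xs)
print(ys_scan, ys_loop)  # both [1. 2. 6.]
```

The same bisection the poster already did (commenting out one line at a time) combines well with this: swap lax.scan for the loop, confirm the body runs eagerly, then switch back and add jax.debug.print calls around the suspect line.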