You cannot do …
As Jake says, a … One option might be to see if you can express your variable-size computation using a loop nested inside the loop body, something like:

```python
def f(i, carry):
    ...
    lax.while_loop(...)
    ...
```

That is, the body function has another loop inside of it.
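Below is a minimal sketch of that nested-loop pattern on a toy problem. The function `suffix_sums` is hypothetical and is not taken from the thread. The outer `lax.fori_loop` visits each index `i`, and an inner `lax.while_loop` runs for `n - i` steps. The trip count varies with `i`, but every array keeps a fixed shape, so no dynamically sized slice is needed.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def suffix_sums(x):
    """Hypothetical example: out[i] = x[i] + x[i+1] + ... + x[n-1]."""
    n = x.shape[0]  # static under jit

    def outer_body(i, out):
        # Inner loop: the number of iterations (n - i) depends on i,
        # but the carry is always (scalar index, scalar accumulator).
        def cond(state):
            j, _ = state
            return j < n

        def inner_body(state):
            j, acc = state
            return j + 1, acc + x[j]  # dynamic *index*, fixed-size result

        _, total = lax.while_loop(cond, inner_body, (i, jnp.zeros((), x.dtype)))
        return out.at[i].set(total)

    return lax.fori_loop(0, n, outer_body, jnp.zeros_like(x))

result = suffix_sums(jnp.array([1.0, 2.0, 3.0, 4.0]))
```

Indexing a single element with a traced `j` is allowed because its result shape does not depend on `j`. A slice like `x[i:]` is different, because its length would depend on a traced value.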
Hello! I'm working on a function that maps parameters to an orthogonal matrix. This involves a sequence of Householder transformations, where only a sub-block of the matrix iterate is operated on at each iteration. Below is a working version with a regular Python for-loop.
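Here is a hedged sketch of this kind of loop. The names `reflect_block` and `orthogonal_from_vectors` and the parameter layout are assumptions, not the poster's code. Each `vs[k]` is a length-`(n - k)` vector. The reflections are applied innermost-first, so each step only changes the bottom-right block `Q[k:, k:]`.

```python
import jax
import jax.numpy as jnp

def reflect_block(Q, v, k):
    """Apply the reflection I - 2 v v^T / (v^T v) to the bottom-right block Q[k:, k:]."""
    block = Q[k:, k:]
    block = block - 2.0 * jnp.outer(v, v @ block) / (v @ v)
    return Q.at[k:, k:].set(block)

def orthogonal_from_vectors(vs):
    """Hypothetical parameterization: vs[k] is a length-(n - k) vector."""
    n = vs[0].shape[0]
    Q = jnp.eye(n)
    # Innermost reflection first: before step k, Q = diag(I_{k+1}, B), so
    # applying diag(I_k, H_k) only changes the bottom-right block Q[k:, k:].
    for k in reversed(range(len(vs))):
        Q = reflect_block(Q, vs[k], k)
    return Q

n = 4
keys = jax.random.split(jax.random.PRNGKey(0), n)
vs = [jax.random.normal(keys[k], (n - k,)) for k in range(n)]
Q = orthogonal_from_vectors(vs)
```

Here `k` is a plain Python integer, so every slice has a static shape. Under `jax.jit` this loop would simply be unrolled.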
Each iteration of the for-loop involves an operation on a conformable bottom-right sub-block of the carry value `Q`. I want to replace the for-loop with a `jax.lax`-style loop. Below is my attempt with `jax.lax.fori_loop`. However, the test code below gives me the error:
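This is a minimal sketch of the kind of failure involved. It assumes the loop body slices `Q` with the loop index, and it is an illustration, not the poster's test code. Inside `jax.lax.fori_loop` the index `i` is a tracer, so `Q[i:, i:]` would have a shape that depends on a traced value. JAX's indexing rejects that when it traces the body.

```python
import jax
import jax.numpy as jnp

def body(i, Q):
    # Q[i:, i:] has shape (n - i, n - i), which depends on the traced index i,
    # so JAX cannot give it a static shape and refuses the slice.
    block = Q[i:, i:]
    return Q.at[i:, i:].set(-block)

try:
    jax.lax.fori_loop(0, 3, body, jnp.eye(3))
    raised = False
except Exception:  # the exact exception type and message vary across JAX versions
    raised = True
```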
It seems that the `jax.lax`-style loop does not like how the size of the vector `v` is dynamic with the way I specified it. However, I know beforehand how large `v` will be at each iteration, so I wonder if there is some way I can implement this without dynamic slicing. Ultimately, I'd like to get this working with `jax.lax.scan` in particular. Any idea how I can fix this?
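One common way around this is masking instead of slicing. The sketch below assumes a hypothetical padded layout: `V` has shape `(m, n)`, and only `V[k, k:]` is meaningful at step `k`. Every vector keeps its full length `n`, and its first `k` entries are zeroed with a mask. A full-size reflection built from such a vector equals `diag(I_k, H_k)`, so it leaves the first `k` rows of `Q` untouched.

When the reflections are applied innermost-first, `Q[k:, :k]` is still zero at step `k`. Updating whole rows `k:` therefore gives the same result as updating only `Q[k:, k:]`. With fixed shapes throughout, `jax.lax.scan` can iterate over the rows of `V`.

```python
import jax
import jax.numpy as jnp

def masked_reflect(Q, vk):
    """One scan step: apply diag(I_k, H_k) to Q using a length-n vector."""
    v, k = vk
    n = Q.shape[0]
    # Zero the first k entries instead of slicing v[k:]; all shapes stay (n,) or (n, n).
    w = jnp.where(jnp.arange(n) >= k, v, 0.0)
    # H = I - 2 w w^T / (w^T w) leaves rows 0..k-1 of Q unchanged.
    # Assumes w[k:] is not all zero (otherwise w @ w == 0 and the result is NaN).
    Q = Q - 2.0 * jnp.outer(w, w @ Q) / (w @ w)
    return Q, None

@jax.jit
def orthogonal_from_padded(V):
    """Hypothetical layout: V has shape (m, n); only V[k, k:] is used at step k."""
    m, n = V.shape  # static under jit
    ks = jnp.arange(m)
    # reverse=True visits k = m-1, ..., 0, i.e. innermost reflection first.
    Q, _ = jax.lax.scan(masked_reflect, jnp.eye(n, dtype=V.dtype), (V, ks), reverse=True)
    return Q

n = 4
V = jax.random.normal(jax.random.PRNGKey(0), (n, n))
Q = orthogonal_from_padded(V)
```

The trade-off is that each step does O(n²) work on the full matrix instead of on the shrinking block. The upside is a single compiled loop body with no dynamic shapes.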