`vmap` a function with variable length for loop #13045
Replies: 2 comments
-
tl;dr: this is an example of why we call it a "debug" print, and it is working as intended.
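This is a really great example of how … When you do …

```python
import jax.numpy as jnp
from jax import vmap

# Roughly how a while_loop with a batched predicate runs; `body`/`cond` are the loop's own functions.
def batched_body(finished_and_carry):
    finished, batched_carry = finished_and_carry
    # The body runs for every batch element, even those already finished...
    batched_carry_out = vmap(body)(batched_carry)
    # ...but finished elements keep their previous carry.
    batched_carry = jnp.where(finished, batched_carry, batched_carry_out)
    finished = ~vmap(cond)(batched_carry)  # an element is finished once cond is False
    return finished, batched_carry
```

Note that this transformation preserves correctness if the … Ordinarily, we disallow side effects inside of batched while loops like this (for example, …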
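To make the mechanism concrete, here is a simplified, hand-written version of the masking transformation described above. It is a sketch of the idea, not JAX's actual batching rule, and it assumes a hypothetical scalar loop: `cond`, `body`, and `manual_batched_while` are illustrative names, not JAX APIs. The sketch runs `body` on every batch element until all elements are finished, uses `jnp.where` to keep the carries of finished elements, and counts the outer iterations. The final values should agree with `jax.vmap` applied to `lax.while_loop`, while the counter shows that the batched loop runs `max(n)` times for every element. This is why a `jax.debug.print` placed in the body can appear to run "too many" times even though the results are correct.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical scalar loop: carry is (i, n, x); double x until i reaches n.
def cond(carry):
    i, n, _ = carry
    return i < n

def body(carry):
    i, n, x = carry
    return i + 1, n, 2.0 * x

def manual_batched_while(batched_carry):
    """Simplified hand-written batching of while_loop(cond, body, ...).

    Not JAX's internal implementation; it only mirrors the idea of
    running the body for every element and masking finished ones.
    """
    finished = ~jax.vmap(cond)(batched_carry)

    def outer_cond(state):
        finished, _, _ = state
        return ~jnp.all(finished)  # continue while any element is active

    def outer_body(state):
        finished, carry, steps = state
        carry_out = jax.vmap(body)(carry)  # body runs for ALL elements
        # Finished elements keep their old carry; active ones take the update.
        carry = jax.tree_util.tree_map(
            lambda old, new: jnp.where(finished, old, new), carry, carry_out)
        finished = ~jax.vmap(cond)(carry)
        return finished, carry, steps + 1

    init = (finished, batched_carry, jnp.int32(0))
    _, carry, steps = lax.while_loop(outer_cond, outer_body, init)
    return carry, steps

n = jnp.array([1, 3, 5])
init = (jnp.zeros(3, dtype=jnp.int32), n, jnp.ones(3))

carry, steps = manual_batched_while(init)
reference = jax.vmap(lambda c: lax.while_loop(cond, body, c))(init)

print(carry[2])      # per-element results from the manual version
print(reference[2])  # same values from jax.vmap(lax.while_loop)
print(steps)         # outer iterations = max(n), shared by every element
```

Each element stops changing once its own `cond` is false, so the per-element results are right. The shared `steps` counter is the part that a debug print inside the body would reveal.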
-
I have an array `X` of dimension `[N, dim1, ..., dimK]` and would like to perform some calculations along the first axis, i.e., on `N` arrays of dimension `[dim1, ..., dimK]`. Each of these `N` calculations involves a for loop whose length is known but varies for each of the `N` entries. I would like to `vmap` these calculations along the first axis. Here is an example of what I am trying to do, with the expected output, using a for loop in place of `vmap`:

When I try to `vmap`, the loop is not applied the correct number of times:

Any suggestions would be appreciated.
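A minimal sketch of the pattern the question describes follows. The asker's actual loop body is not shown here, so it assumes a hypothetical per-entry update `step`; `run_one` and `lengths` are likewise illustrative names. The idea is to pass the per-entry loop lengths to `jax.vmap` alongside `X` and use `lax.fori_loop` with a traced upper bound, which lowers to a `while_loop`. Under `vmap`, the per-entry results should match a plain Python loop over the first axis, even though (as the answer above explains) the batched loop internally keeps iterating until the longest entry finishes.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical per-entry update; stands in for the asker's real loop body.
def step(i, x):
    return x * 2.0 + 1.0

def run_one(x, n):
    # With a traced upper bound n, fori_loop lowers to a while_loop,
    # so under vmap each entry can have its own trip count.
    return lax.fori_loop(0, n, step, x)

X = jnp.ones((3, 2))             # N = 3 entries, each of shape (2,)
lengths = jnp.array([1, 2, 4])   # loop length for each entry

# Reference: plain Python loop over the first axis.
expected = jnp.stack([run_one(X[i], lengths[i]) for i in range(X.shape[0])])

# Batched version: vmap over both the data and the per-entry loop length.
batched = jax.vmap(run_one)(X, lengths)
print(batched)
```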