Minimal example: We have a pytree attribute

Actual context/goal: I'm writing a parallelizable reinforcement learning environment. A common pattern is to "reset" the state after a (possibly variable) number of steps (each step is a function call carrying some state). Ideally, this reset creates a new rng that changes the state for the next episode (the set of steps until the next reset). In the example described above, I'm trying to create this behavior by taking the result of a vmapped step function, checking whether any of the output states is "done", and calling a reset function on the ones that are. In

Would appreciate any pointers and happy to clarify anything :)
Replies: 1 comment 4 replies
This sounds like the precise use-case of `lax.switch`.
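A minimal sketch of what that could look like, assuming a toy environment whose state is a dict pytree. The names `reset`, `step`, `step_and_maybe_reset`, and `MAX_STEPS` are made up for illustration and are not from the original post. The per-environment function picks between "keep the stepped state" and "reset from the carried key" with `lax.switch`, and the whole thing is then vmapped:

```python
import jax
import jax.numpy as jnp
from jax import lax

MAX_STEPS = 3  # assumed episode length, purely illustrative


def reset(key):
    """Start a new episode, deriving a fresh rng from `key` (hypothetical helper)."""
    new_key, sub = jax.random.split(key)
    return {"key": new_key, "pos": jax.random.uniform(sub), "t": jnp.int32(0)}


def step(state):
    """Advance one step and report `done` once MAX_STEPS is reached (hypothetical helper)."""
    t = state["t"] + 1
    new_state = {"key": state["key"], "pos": state["pos"] + 0.1, "t": t}
    return new_state, t >= MAX_STEPS


def step_and_maybe_reset(state):
    new_state, done = step(state)
    # Branch 0 keeps the stepped state, branch 1 resets from the carried key.
    # Both branches must return pytrees with the same structure, shapes, and dtypes.
    branches = (lambda s: s, lambda s: reset(s["key"]))
    return lax.switch(done.astype(jnp.int32), branches, new_state)


batched_step = jax.jit(jax.vmap(step_and_maybe_reset))

keys = jax.random.split(jax.random.PRNGKey(0), 4)
states = jax.vmap(reset)(keys)
# Pretend the four environments are at different points in their episodes.
states = {**states, "t": jnp.array([0, 1, 2, 2], dtype=jnp.int32)}
states = batched_step(states)  # the last two environments hit MAX_STEPS and are reset
```

One caveat, as far as I understand the batching rules: when the index is batched under `vmap`, the switch is lowered to a select, so both branches are evaluated for every environment and only the result is chosen per element. That is usually fine when `reset` is cheap, but it is worth keeping in mind if it is not.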