@partial(jit, static_argnames=['num_steps'])
def dpg_steps(...
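The `static_argnames=['num_steps']` in that decorator is relevant to the gradient problem. According to the `jax.lax.fori_loop` documentation, the loop is implemented with `scan`, which supports reverse-mode autodiff, only when the trip count is known at trace time. Otherwise it is implemented with `while_loop`, which does not. Below is a small sketch for checking which form a given call produces. The `body`/`loop` pair is illustrative and not from the thread:

```python
import jax
import jax.numpy as jnp


def body(i, x):
    # Illustrative loop body; stands in for a real step() call.
    return x * 0.5 + 1.0


def loop(x, n):
    return jax.lax.fori_loop(0, n, body, x)


x = jnp.ones(3)

# n is a Python int at trace time: static trip count, so fori_loop lowers to scan.
static_jaxpr = jax.make_jaxpr(lambda x: loop(x, 10))(x)

# n is passed as a traced argument: dynamic trip count, so fori_loop lowers to while.
dynamic_jaxpr = jax.make_jaxpr(loop)(x, 10)

print(static_jaxpr)
print(dynamic_jaxpr)
```

Under `jit`, a loop bound passed as an ordinary (non-static) argument is traced, which gives the `while` form. Marking it static keeps it a Python int.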
What I have below is just a fancy for-loop over `step()`, which loops ~1e5 times and returns two `jnp.ndarray`s. I've implemented this with `jax.lax.fori_loop` for a significant speed-up in compilation time, but this seems to lack reverse-mode differentiation, so I want to convert it to use `jax.lax.scan`. I've tested this and it works (it compiles and gives correct outputs). But when I use it to compute gradients w.r.t. `alpha` and `beta`, I get:

So I am wondering if someone could guide me on the best way to convert this code to use `scan`?
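Here is a minimal sketch of the conversion, under some assumptions. The original code was not preserved, so `step` below is a hypothetical update rule on two arrays `x`, `y` with parameters `alpha` and `beta`, and the loss is made up for illustration. A `scan` body takes `(carry, xs_slice)` and returns `(new_carry, per_step_output)`. Since only the final two arrays are needed, the body emits `None` per step and the call passes `xs=None` with `length=num_steps`. A `fori_loop` version is included for comparison. With `num_steps` static, its bounds are Python ints, which is the case the JAX docs describe as reverse-mode differentiable.

```python
from functools import partial

import jax
import jax.numpy as jnp


def step(x, y, alpha, beta):
    # Hypothetical update rule standing in for the thread's step(); returns two arrays.
    x_new = x + alpha * jnp.sin(y)
    y_new = y - beta * x_new
    return x_new, y_new


@partial(jax.jit, static_argnames=['num_steps'])
def run_scan(x0, y0, alpha, beta, num_steps):
    # scan body: (carry, xs_slice) -> (new_carry, per_step_output).
    # xs=None, so the second argument is unused; no per-step output is stacked.
    def body(carry, _):
        x, y = carry
        return step(x, y, alpha, beta), None

    (x, y), _ = jax.lax.scan(body, (x0, y0), xs=None, length=num_steps)
    return x, y


@partial(jax.jit, static_argnames=['num_steps'])
def run_fori(x0, y0, alpha, beta, num_steps):
    # Same loop via fori_loop; num_steps is static, so the bounds are Python ints.
    def body(i, carry):
        x, y = carry
        return step(x, y, alpha, beta)

    return jax.lax.fori_loop(0, num_steps, body, (x0, y0))


def make_loss(run, num_steps=1000):
    # Hypothetical scalar loss on the two final arrays.
    def loss(alpha, beta):
        x, y = run(jnp.ones(3), jnp.zeros(3), alpha, beta, num_steps=num_steps)
        return jnp.sum(x ** 2) + jnp.sum(y ** 2)

    return loss


# Gradients w.r.t. alpha (argnums=0) and beta (argnums=1).
g_scan = jax.grad(make_loss(run_scan), argnums=(0, 1))(0.01, 0.02)
g_fori = jax.grad(make_loss(run_fori), argnums=(0, 1))(0.01, 0.02)
print(g_scan, g_fori)
```

Reverse mode through ~1e5 iterations keeps per-step residuals for the backward pass. If memory becomes a problem, wrapping the body in `jax.checkpoint` trades recomputation for memory.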