Implementing L2L in JAX #5871
davisyoshida asked this question in Q&A
I'm interested in implementing a decorator to swap parameters on and off the GPU, as described here: https://arxiv.org/pdf/2002.05645.pdf
To get started, I made the following decorator:
Unfortunately, functions decorated by this cannot be JIT-ed, since they take arguments living on both the CPU and GPU. Is there a way I can do something like this (or some other method entirely) and still make use of `jax.jit`?
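The following is one possible arrangement, offered as a sketch under stated assumptions and not as the original decorator, which isn't shown here. The idea is to keep every host-to-device transfer *outside* the jitted function, so the compiled code only sees arguments that already sit on the same device. The names `swap_params`, `dense_layer` and `forward` are hypothetical. The sketch assumes parameters are stored as NumPy arrays in host memory and that the compute device is whatever `jax.devices()[0]` returns: the GPU when JAX's default backend is GPU, otherwise the CPU.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def swap_params(fn):
    """Hypothetical L2L-style decorator (not from the original post).

    Parameters live in host memory between calls. On each call they are
    copied to the default JAX device, and a jitted version of `fn` runs there.
    """
    jitted = jax.jit(fn)  # compiled once per input shape/dtype, then cached

    @functools.wraps(fn)
    def wrapper(params, x):
        device = jax.devices()[0]  # GPU if the default backend is GPU, else CPU
        # Transfers happen outside jit, so the compiled function only ever
        # receives arguments that are already on the same device.
        params_dev = jax.device_put(params, device)  # host -> device (pytree)
        x_dev = jax.device_put(x, device)  # no-op if x is already there
        # params_dev is unreferenced once this returns, so its device
        # buffers can be freed before the next layer is loaded.
        return jitted(params_dev, x_dev)

    return wrapper


@swap_params
def dense_layer(params, x):
    return jnp.tanh(x @ params["w"] + params["b"])


def forward(all_params, x):
    # Hypothetical layer-by-layer loop: roughly one layer's parameters are on
    # the device at a time, while the activations x stay on the device.
    for layer_params in all_params:
        x = dense_layer(layer_params, x)
    return x


# Usage: parameters are kept as NumPy arrays, i.e. in host memory.
rng = np.random.default_rng(0)
host_params = [
    {"w": rng.standard_normal((4, 4)).astype(np.float32),
     "b": np.zeros(4, dtype=np.float32)}
    for _ in range(3)
]
x = np.ones((2, 4), dtype=np.float32)
y = forward(host_params, x)
```

This only covers the forward pass. A full L2L-style training step would also need a layer-by-layer backward pass (for example with `jax.vjp` per layer) and parameter updates applied on the host. Those parts are not shown and would need their own design.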