Help determining whether I'm using lax.scan and lax.cond correctly in re-writing existing JAX code to make it more efficient? #19969
Unanswered
Chulabhaya asked this question in Q&A
Replies: 1 comment
Hi - I think you're going to have to narrow this down for us a bit. Your link is to a file with over 1000 lines of code, and it's not really clear which part you're asking about. I would suggest doing your best to create a minimal reproducible example of what you're attempting to do, and then using that to ask a more pointed question. Thanks!
Hi! I have some JAX code for a reinforcement learning algorithm (SAC) that I worked on initially, and it is pretty fast compared to my existing PyTorch code: https://pastebin.com/tLqhSc6i. I figured I could make this version faster by making everything end-to-end jittable, using `lax.scan` to replace the main for loop and `lax.cond` to replace the if statements, as follows: https://pastebin.com/AbE1VTHh

My main question is: did I implement this rewrite correctly, in the sense that the loops and conditionals are also jit-compatible, and is it beneficial to code my algorithm this way? I'm pretty new to JAX and still learning the best practices, so any help would be greatly appreciated. Thank you in advance!
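The pattern described above (a `lax.scan` over training steps, with `lax.cond` in place of Python `if` statements) can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the code from either pastebin link: `env_step`, `critic_update`, `actor_update`, `LEARNING_STARTS`, and `POLICY_FREQ` are hypothetical stand-ins, and the environment is assumed to be written in pure JAX (a Python environment such as a Gym env cannot be called directly inside `lax.scan`).

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical hyperparameters, for illustration only.
LEARNING_STARTS = 10  # no gradient updates before this step
POLICY_FREQ = 2       # actor is updated every POLICY_FREQ steps
NUM_STEPS = 50

def env_step(obs, key):
    # Hypothetical pure-JAX environment step (stand-in for a real env).
    return 0.9 * obs + 0.1 * jax.random.normal(key, obs.shape)

def critic_update(params):
    # Hypothetical stand-in for one critic gradient step.
    return params - 0.01 * params

def actor_update(params):
    # Hypothetical stand-in for one actor gradient step.
    return params - 0.01 * params

def train_step(carry, step):
    obs, critic_params, actor_params, key, n_critic, n_actor = carry
    key, subkey = jax.random.split(key)
    obs = env_step(obs, subkey)

    # Replaces `if step >= LEARNING_STARTS:`; both branches return the
    # same pytree structure, shapes, and dtypes, as lax.cond requires.
    critic_params, n_critic = lax.cond(
        step >= LEARNING_STARTS,
        lambda p, n: (critic_update(p), n + 1),
        lambda p, n: (p, n),
        critic_params, n_critic,
    )

    # Replaces a compound `if`: combine traced predicates with `&`, not `and`.
    do_actor = (step >= LEARNING_STARTS) & (step % POLICY_FREQ == 0)
    actor_params, n_actor = lax.cond(
        do_actor,
        lambda p, n: (actor_update(p), n + 1),
        lambda p, n: (p, n),
        actor_params, n_actor,
    )

    new_carry = (obs, critic_params, actor_params, key, n_critic, n_actor)
    return new_carry, jnp.mean(obs)  # per-step outputs are stacked by scan

@jax.jit
def train(key):
    init = (jnp.zeros(3), jnp.ones(4), jnp.ones(4), key,
            jnp.int32(0), jnp.int32(0))
    # lax.scan replaces `for step in range(NUM_STEPS):`
    return lax.scan(train_step, init, jnp.arange(NUM_STEPS))

final_carry, obs_means = train(jax.random.PRNGKey(0))
```

Two points this illustrates: everything that changes across iterations (observations, parameters, PRNG key, counters) has to live in the scan carry, and both `lax.cond` branches must return values with identical shapes and dtypes. A plain Python `if step >= LEARNING_STARTS:` inside `train_step` would fail at trace time, because `step` is a traced value there. Whether the rewrite is actually faster depends on the workload: scanning removes per-step Python dispatch, but the scanned body has to be compiled up front.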
Each algorithm script is contained in a single file, except for the replay buffer, which is the following: