How to use batch_vmap() in PR #19614 #19975
Replies: 2 comments
-
#19614 is an unmerged work-in-progress PR. I think the answer to "how to use it" is "don't yet" 😁 But if you want to weigh in on the PR itself with the results of your testing, that may help make the new feature as useful as it can be when it eventually lands!
-
Please accept my apologies. I have realized a glaring mistake on my part: the array itself was exceeding the cache limit, and the problem was not related to vmap. I am truly embarrassed by my oversight. :(
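For anyone landing here with the same question: independent of the PR, this kind of batching can be sketched with APIs that already exist in JAX, namely `jax.vmap` inside `jax.lax.map`. The sketch below is not the PR's `batch_vmap()`. The function `f`, the helper `batched_apply`, and the small sizes are hypothetical, chosen only for illustration. It assumes the inputs can be generated from the batch index and that each batch can be reduced to a scalar, so neither the full input array nor the full output array is ever held in memory.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical per-element function standing in for the real workload.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


def batched_apply(f, num_batches, batch_size):
    # Hypothetical helper: runs the batches one after another with lax.map
    # and vectorizes within each batch with vmap.
    def run_batch(i):
        # Build this batch's inputs from the batch index, so the full
        # input array (num_batches * batch_size) is never materialized.
        xs = (i * batch_size + jnp.arange(batch_size)).astype(jnp.float32)
        ys = jax.vmap(f)(xs)
        # Reduce per batch; returning ys itself would stack all outputs.
        return ys.sum()

    return jax.lax.map(run_batch, jnp.arange(num_batches))


partial_sums = batched_apply(f, num_batches=10, batch_size=1000)
total = partial_sums.sum()
```

If the input array already exists and fits in memory, a similar effect can be had by reshaping it to `(num_batches, batch_size, ...)` and calling `jax.lax.map(jax.vmap(f), xs_reshaped)`. Note that this stacks all outputs, so it only helps when the intermediates inside `f`, rather than the inputs or outputs, are what exhaust memory.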
-
Hi! I recently noticed that a `batch_vmap()` function has been submitted in PR #19614. To address an OOM issue with `vmap()`, I copied the submitted code into the JAX library and tried using it. It returned the expected results in Example 1, but in Example 2 it did not seem to batch the computation or avoid the OOM problem. I believe I may have misunderstood how to use it. For instance, if I want to run 1e10 function evaluations in 10 batches (say the GPU can only handle 1e9 of them at a time), how should I use the `batch_vmap()` function?

Thank you so much for any help!
Example 1:
Example 2 (both of them):