I'm not sure what you mean by "Evaluate the model using the lookup table at x", and the answer to your question will depend on what operations that implies. In short, the answer is that …
Hi, I am doing a batch of optimizations in parallel using vmap, and each model evaluation uses a large lookup table that is a significant fraction of available GPU memory. In code I am using the following pattern:
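Roughly like this (a simplified sketch; lookup_table, evaluate_model, and the shapes here are placeholders standing in for the real code):

```python
import jax
import jax.numpy as jnp

# Placeholder for the real table, which takes up a significant fraction of GPU memory.
lookup_table = jnp.zeros((1_000_000, 64))

def evaluate_model(params, x):
    # Evaluate the model using the lookup table at x (placeholder logic).
    idx = jnp.clip(x.astype(jnp.int32), 0, lookup_table.shape[0] - 1)
    return jnp.sum(lookup_table[idx] * params)

# One model evaluation per batch element; the lookup table is closed over,
# so vmap broadcasts it rather than mapping over it.
batched_eval = jax.jit(jax.vmap(evaluate_model, in_axes=(0, 0)))

params = jnp.ones((128, 64))   # batch of parameter vectors
xs = jnp.arange(128)           # batch of inputs
values = batched_eval(params, xs)
```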
My main question is pretty general: what is going on under the hood when the lookup table is folded into the vmapped function? For example, are there competing reads of the lookup table across batch elements that could be causing slowdowns? I have not seen any particular speedup from vmap over jax.lax.map with this pattern, so I'd like to understand better what is actually happening.
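For reference, the jax.lax.map version I compared against looks roughly like this (same placeholder names as in the sketch above):

```python
# Sequential version for comparison: lax.map evaluates one batch element
# at a time instead of vectorizing across the batch.
def eval_one(args):
    params_i, x_i = args
    return evaluate_model(params_i, x_i)

sequential_eval = jax.jit(lambda p, x: jax.lax.map(eval_one, (p, x)))
values_seq = sequential_eval(params, xs)
```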
(I'll also add that any general tips for a situation like this, e.g. performance optimization or profiling, would be great.)
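For concreteness, the kind of profiling setup I have in mind is something like this (assuming jax.profiler; the trace directory is arbitrary):

```python
import jax.profiler

# Capture a trace that can be inspected in TensorBoard or Perfetto.
with jax.profiler.trace("/tmp/jax-trace"):
    out = batched_eval(params, xs)
    out.block_until_ready()  # make sure async dispatch finishes inside the trace
```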