when to use vmap #18873
Hi, thanks for the question and the suggestion! To address a few points in particular:
No, not necessarily, although using
This is true: whether your code is vectorized manually or automatically, it will often need the full mapped dataset (or its intermediate values) to fit in memory. In this case, switching to
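The following is a minimal sketch of that trade-off. It assumes a made-up function `heavy` that takes and returns a scalar but builds an N x N intermediate. Under `jax.vmap` that intermediate gains a batch dimension. `jax.lax.map` lowers to a `lax.scan` and processes one element per iteration, so only one intermediate needs to be live at a time. XLA may fuse some intermediates away, so the actual peak memory depends on the compiler.

```python
import jax
import jax.numpy as jnp
from jax import lax

N = 1_000  # size of the (hypothetical) large intermediate

# Scalar in, scalar out, but an N x N array is built internally.
def heavy(x):
    v = jnp.linspace(0.0, 1.0, N) * x
    outer = jnp.outer(v, v)          # N x N intermediate
    return jnp.sum(jnp.tanh(outer))  # reduce back to a scalar

xs = jnp.linspace(0.0, 1.0, 8)

# vmap: the intermediate becomes (8, N, N), which may be materialized at once.
vmapped = jax.jit(jax.vmap(heavy))(xs)

# lax.map: a sequential loop (scan) over the leading axis of xs,
# needing only one N x N intermediate per iteration.
looped = jax.jit(lambda batch: lax.map(heavy, batch))(xs)
```

Both produce the same values. The sequential version trades memory for speed, since the batch elements are no longer processed in parallel.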
This is a very frequent request, and we're working on it: see #11319
Hi,
I've been using JAX for the past few months and I'm incredibly impressed with its capabilities. Now that I have gained some experience, I would like your opinion on a few matters, and I also have a suggestion/feature request.
Question:
Is it considered good practice to use vmap and avoid manual vectorization whenever possible? Are there any downsides to using vmap compared to manual vectorization?
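To make the question concrete, here is a minimal sketch of a hand-vectorized batched dot product next to a vmap-ed per-example one. The function names are illustrative and not from any library. Both give the same result, and `jax.make_jaxpr` can be used to inspect what each version lowers to.

```python
import jax
import jax.numpy as jnp

# Per-example function: dot product of two vectors.
def dot(x, y):
    return jnp.dot(x, y)

# Manual vectorization: rewritten by hand for a leading batch axis.
def dot_batched_manual(xs, ys):
    return jnp.sum(xs * ys, axis=-1)

# Automatic vectorization: vmap maps `dot` over axis 0 of both arguments.
dot_batched_vmap = jax.vmap(dot)

xs = jnp.arange(12.0).reshape(4, 3)
ys = jnp.ones((4, 3))

manual = dot_batched_manual(xs, ys)
auto = dot_batched_vmap(xs, ys)
print(manual, auto)
print(jax.make_jaxpr(dot_batched_vmap)(xs, ys))  # inspect the batched program
```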
Suggestion:
One downside of vectorization is the significant memory usage. In cases where a function uses a lot of memory, it might be preferable to use conventional for loops instead, especially when each evaluation of the function takes so long that the Python overhead is negligible.
Now, let's say a function has small inputs AND small outputs, but uses huge arrays in the computation process. Conceptually, it would make sense to use a transformation similar to vmap, since both the batched inputs and the batched outputs would easily fit into memory.
Such a transformation could use a for loop internally (maybe lax.map) and would be a huge quality-of-life improvement.
Has there been any discussion about implementing such a transformation?
I know about lax.map, but it is not as general as vmap. It would be really nice to have a transformation that could be used interchangeably with vmap; this would also make it possible to modify existing code without much effort.
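For the restricted case where every argument is mapped over its leading axis, a vmap-like wrapper around `lax.map` can be sketched in a few lines. `sequential_vmap` is a hypothetical helper, not a JAX API. Unlike `jax.vmap`, it has no `in_axes`/`out_axes` and does not support unmapped arguments.

```python
import jax
import jax.numpy as jnp
from jax import lax

def sequential_vmap(f):
    """Hypothetical stand-in for jax.vmap(f) with in_axes=0 and out_axes=0.

    Every positional argument must be an array (or pytree of arrays) with a
    leading batch axis of the same length. Elements are processed one at a
    time with lax.map instead of being vectorized.
    """
    def wrapped(*args):
        # lax.map slices the leading axis of a pytree, so pack the args in a tuple
        # and unpack each per-element slice before calling f.
        return lax.map(lambda a: f(*a), args)
    return wrapped

def loss(w, x):
    return jnp.sum((x @ w) ** 2)

ws = jnp.ones((5, 3, 2))
xs = jnp.ones((5, 4, 3))

a = jax.vmap(loss)(ws, xs)         # vectorized
b = sequential_vmap(loss)(ws, xs)  # looped, same interface for this case
```

Because the wrapper is built from `lax.map`, it should compose with `jax.jit` and `jax.grad` like any other JAX function. Swapping it in for `jax.vmap` then only changes how the batch is executed, not the call sites.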