Hey, did you find a good solution to this issue? I'm having a similar problem with high memory usage. I'm starting to think `vmap` isn't the best option for running the same computation on different data many times as fast as possible.
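One alternative worth trying, sketched here under assumptions rather than taken from this thread, is `jax.lax.map`, which runs the function sequentially (it is implemented with `scan`) instead of batching it the way `vmap` does. The bootstrap body below (`randint` indices, gather, mean) and the name `bootstrap_means_map` are hypothetical stand-ins for the original code, which isn't shown here.

```python
import jax
import jax.numpy as jnp


def bootstrap_means_map(key, data, n_boot):
    """Bootstrap distribution of the mean, one resample at a time.

    Hypothetical example. lax.map applies one_resample sequentially,
    so only a single (n,) index array per iteration is needed rather
    than an (n_boot, n) batch.
    """
    n = data.shape[0]
    # One independent PRNG key per bootstrap resample.
    keys = jax.random.split(key, n_boot)

    def one_resample(k):
        # Draw n indices with replacement and average the resampled data.
        idx = jax.random.randint(k, (n,), 0, n)
        return jnp.mean(data[idx])

    return jax.lax.map(one_resample, keys)


means = bootstrap_means_map(jax.random.PRNGKey(0), jnp.arange(10.0), 1000)
print(means.shape)  # (1000,)
```

The trade-off is speed: a sequential loop gives up the parallelism that `vmap` provides, so it may be slower even though its peak memory is much lower.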
-
I'm trying to migrate some code from Numba to JAX. Along the way, I've spent a lot of time trying to figure out why the memory usage of the following code example spikes to ~1.5 GB, while the Numba version of the same algorithm has a memory footprint of only a few megabytes:

In the following example, I've implemented a bootstrap algorithm that estimates the distribution of the mean:
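The original code isn't included here, so the following is only a hypothetical reconstruction of what a `vmap`-based bootstrap of the mean might look like in JAX. It is meant to show where memory can go: under `vmap`, the per-resample index array and the gathered data are batched to shape `(n_boot, n)`, which in eager mode (and possibly under `jit`, depending on what XLA fuses) can mean holding every resample at once. The name `bootstrap_means_vmap` is an assumption.

```python
import jax
import jax.numpy as jnp


def bootstrap_means_vmap(key, data, n_boot):
    """Bootstrap distribution of the mean, all resamples batched via vmap.

    Hypothetical reconstruction, not the original poster's code.
    """
    n = data.shape[0]
    # One independent PRNG key per bootstrap resample.
    keys = jax.random.split(key, n_boot)

    def one_resample(k):
        # Draw n indices with replacement and average the resampled data.
        idx = jax.random.randint(k, (n,), 0, n)
        return jnp.mean(data[idx])

    # Under vmap, idx and data[idx] are batched to shape (n_boot, n).
    return jax.vmap(one_resample)(keys)


means = bootstrap_means_vmap(jax.random.PRNGKey(0), jnp.arange(10.0), 1000)
print(means.shape)  # (1000,)
```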
Profiling with Memory Profiler:
Also, I've tried to profile the code with JAX's Device Memory Profiler. Unfortunately, both profilers fail to spot the memory consumption problem.
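Since neither profiler points at the allocation, a back-of-the-envelope estimate may help. The hypothetical helper below counts only the two batched intermediates of a vmapped bootstrap that draws an `(n_boot, n)` index array and gathers an `(n_boot, n)` resample. It assumes JAX's default 32-bit types (4 bytes for `int32` indices and `float32` data) and ignores fusion by XLA and any other temporaries; the sizes in the example are illustrative, not taken from the post.

```python
def vmap_bootstrap_bytes(n_boot, n, index_itemsize=4, data_itemsize=4):
    """Rough size in bytes of the batched intermediates of a vmapped bootstrap.

    Hypothetical helper: counts the (n_boot, n) index array plus the
    (n_boot, n) gathered data, nothing else.
    """
    return n_boot * n * (index_itemsize + data_itemsize)


# Illustrative sizes only: 10,000 resamples of 10,000 samples each.
print(vmap_bootstrap_bytes(10_000, 10_000) / 1e9)  # 0.8 (GB)
```

A sequential Numba loop only ever needs one resample's worth of indices, which is consistent with its much smaller footprint.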
A similar discussion can be found here: #7509
It would be very helpful if you could help me answer these questions:
1. … `vmap`, and the explanation can be found in the documentation, but it isn't clear to me.
2. … `jax.random.split` into `vmap`, or is there a better way to do it?

Thanks in advance.
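On the key-splitting question: a common pattern is to split one key into one key per resample and map over those keys. If memory is the concern, one middle ground (an alternative not proposed in this thread) is to group the keys into fixed-size chunks, `vmap` within each chunk, and run the chunks sequentially with `jax.lax.map`. Peak intermediate memory then scales roughly with `chunk_size * n` instead of `n_boot * n`. The function name, the bootstrap body, and the divisibility requirement are assumptions of this sketch.

```python
import jax
import jax.numpy as jnp


def bootstrap_means_chunked(key, data, n_boot, chunk_size):
    """Bootstrap means with vmap inside chunks and lax.map across chunks.

    Hypothetical sketch; assumes n_boot is divisible by chunk_size.
    """
    if n_boot % chunk_size != 0:
        raise ValueError("n_boot must be divisible by chunk_size")
    n = data.shape[0]
    # One independent key per resample, grouped into (n_chunks, chunk_size, ...).
    keys = jax.random.split(key, n_boot)
    keys = keys.reshape((n_boot // chunk_size, chunk_size) + keys.shape[1:])

    def one_resample(k):
        # Draw n indices with replacement and average the resampled data.
        idx = jax.random.randint(k, (n,), 0, n)
        return jnp.mean(data[idx])

    # Batched within a chunk, sequential across chunks.
    chunk_means = jax.lax.map(jax.vmap(one_resample), keys)
    return chunk_means.reshape(n_boot)


means = bootstrap_means_chunked(jax.random.PRNGKey(0), jnp.arange(10.0), 1000, 100)
print(means.shape)  # (1000,)
```

`chunk_size` acts as a memory/speed dial: `chunk_size == n_boot` behaves like a single `vmap`, and `chunk_size == 1` behaves like a plain sequential map.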