Hi JAX team,

The following is a highly simplified model of what I am really doing:

```python
from jax import random, vmap, numpy as jnp

a = random.uniform(random.PRNGKey(0), (10**6,))
b = random.uniform(random.PRNGKey(1), (10**6,))

def foo(x, y):
    return jnp.sin(x + y)

batched_foo = vmap(foo, in_axes=(0, None), out_axes=0)
batched_foo(a, b)
```

Running this always exhausts the available memory.

This operation is important in my own code, and I believe there must be some workaround. However, due to my limited knowledge of software engineering, I have not been able to find one. Could you please help me? Thank you!
Hi, I don't think there's a simple plug-in method that will reduce the memory usage of the operation you've written.
The function you've written will return a 10^6 x 10^6 matrix, i.e. one with a trillion entries. So merely storing the output will require on the order of terabytes of memory.
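Concretely, a back-of-the-envelope check (assuming float32 outputs, JAX's default):

```python
# 10^6 x 10^6 output matrix, 4 bytes per float32 entry
entries = 10**6 * 10**6        # one trillion entries
bytes_total = entries * 4      # 4 bytes each
print(bytes_total / 10**12)    # → 4.0, i.e. roughly 4 TB just for the result
```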
If you're able to refactor your computational problem so that it doesn't need the entire matrix at once, this would reduce the memory usage. So if, e.g., you can deal with the rows one by one, you could loop through the rows with a `jax.lax.scan`. Inside the scan's inner function you could `vmap` over however many rows you can fit in memory at once.
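A minimal sketch of that pattern, using small arrays for illustration. The `chunk_size` value and the final sum reduction are illustrative assumptions; the point is that each `scan` step materializes only a `(chunk_size, len(b))` block and reduces it before moving on, so the full matrix never exists in memory:

```python
import jax
import jax.numpy as jnp
from jax import random, vmap

# Small sizes for illustration; the same structure applies at 10^6.
a = random.uniform(random.PRNGKey(0), (1000,))
b = random.uniform(random.PRNGKey(1), (1000,))

def foo(x, y):
    return jnp.sin(x + y)

chunk_size = 100                        # however many rows fit in memory at once
a_chunks = a.reshape(-1, chunk_size)    # (num_chunks, chunk_size)

def inner(carry, a_chunk):
    # vmap over one chunk of rows: block has shape (chunk_size, len(b)).
    block = vmap(foo, in_axes=(0, None))(a_chunk, b)
    # Reduce the block immediately (here: a running sum) instead of storing it.
    return carry + block.sum(), None

total, _ = jax.lax.scan(inner, 0.0, a_chunks)
```

Whether a running sum is the right reduction depends on what the full matrix was needed for downstream; any per-chunk consumption of the block (e.g. a matrix-vector product, an argmax, writing chunks to disk) fits the same scan skeleton.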