jax.vmap out of memory, jax.lax.map too slow #13174
I have an n x p matrix X and an m x p matrix Z, and I want to compute a length-m vector using the following code:

```python
import jax.numpy as jnp
from jax import random
import jax

def foo(X, Z):
    return jnp.sum(X @ Z.T >= 0, axis=0)

def foo_vmap(X, Z):
    return jax.vmap(lambda t: jnp.sum(X @ t >= 0))(Z)

def foo_map(X, Z):
    return jax.lax.map(lambda t: jnp.sum(X @ t >= 0), Z)

key = random.PRNGKey(0)
X = random.normal(key, shape=(100, 5))
Z = random.normal(key, shape=(10**8, 5))
ans = foo(X, Z)
```

This function as written gives a memory error because it constructs a 100 x 10^8 intermediate array, which would need 40 GB of memory. In principle, it is not necessary to construct this huge intermediate array, however. I tried a couple of alternatives using vmap and map, hoping to retain the speed of the above computation for smaller values of m while not running out of memory for larger values of m, but the vmap implementation still gives a memory error, and the lax.map implementation runs very slowly when m = 10^6. Any insights on how to get this code to run would be much appreciated!

Edit: to clarify, I am trying to run this on a GPU, so I am limited to ~8-12 GB of memory.
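For reference, a rough back-of-the-envelope check of where the 40 GB figure comes from (assuming JAX's default float32 dtype):

```python
# Size of the float32 intermediate X @ Z.T materialized by foo above.
n, m = 100, 10**8                  # rows of X, rows of Z
bytes_per_float32 = 4
intermediate_bytes = n * m * bytes_per_float32
print(intermediate_bytes / 1e9)    # 40.0 GB -- far beyond an 8-12 GB GPU
```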
Replies: 1 comment
For those interested, here is the solution due to @pschuh:

```python
def bar(X, Z):
    leading, *rest = Z.shape
    return jax.lax.map(
        lambda t: jnp.sum(X @ t.T >= 0, axis=0),
        Z.reshape(100, leading // 100, *rest),
    ).flatten()
```

My understanding is that this combines foo and foo_map above: it runs foo on sub-matrices that allow everything to fit into memory, then stitches those together to give the final result.
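A minimal usage sketch of the same chunking idea with a configurable chunk count (the name bar_chunked and the num_chunks parameter are my own illustration, not part of the original answer), assuming num_chunks divides the leading dimension of Z:

```python
import jax
import jax.numpy as jnp
from jax import random

def bar_chunked(X, Z, num_chunks=100):
    # Split Z into num_chunks row-blocks and run foo on each block in turn,
    # so only a (X.shape[0], leading // num_chunks) intermediate is live at a time.
    leading, *rest = Z.shape
    assert leading % num_chunks == 0, "num_chunks must divide Z.shape[0]"
    chunks = Z.reshape(num_chunks, leading // num_chunks, *rest)
    return jax.lax.map(lambda t: jnp.sum(X @ t.T >= 0, axis=0), chunks).flatten()

key = random.PRNGKey(0)
X = random.normal(key, shape=(100, 5))
Z = random.normal(key, shape=(10**6, 5))
print(bar_chunked(X, Z).shape)  # (1000000,)
```

The chunk count trades memory for sequential steps: a larger num_chunks means smaller intermediates per step but more lax.map iterations.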