For this specific problem, I have found 4 solutions using forward-mode autodiff (f, h, ff, and gg below succeed; the fully broadcast variant g is the one that fails):

```python
import jax
import jax.numpy as jnp

def sq_d(xi, xj):
    # pairwise squared distances between the rows of xi and xj
    xi, xj = jnp.expand_dims(xi, -2), jnp.expand_dims(xj, -3)
    return jnp.sum(jnp.square(xi - xj), -1)

def rbf_mvm(xi, xj, vj):
    # RBF kernel block K(xi, xj) applied to vj
    return jnp.exp(-0.5 * sq_d(xi, xj)) @ vj

train_x = jax.random.normal(jax.random.PRNGKey(0), (50000, 3), jnp.float32)
probes = jax.random.normal(jax.random.PRNGKey(5), (50000, 10), jnp.float32)
chunks = 50

def f(scaling):  # success: plain Python loop over chunks
    x_scaled = train_x / scaling
    x_chunked = jnp.split(x_scaled, chunks)
    v_chunked = jnp.split(probes, chunks)
    kv = 0.0
    for i in range(chunks):
        kv = kv + rbf_mvm(x_scaled, x_chunked[i], v_chunked[i])
    return jnp.sum(probes.T @ kv)

def chunk(x: jnp.ndarray, chunks):
    return jnp.reshape(x, (chunks, -1, *x.shape[1:]))

def g(scaling):  # fail: broadcasts over all chunks at once
    x_scaled = train_x / scaling
    x_chunked = chunk(x_scaled, chunks)
    v_chunked = chunk(probes, chunks)
    kv = jnp.sum(rbf_mvm(x_scaled, x_chunked, v_chunked), 0)
    return jnp.sum(probes.T @ kv)

def h(scaling):  # success: fori_loop over chunks
    x_scaled = train_x / scaling
    x_chunked = chunk(x_scaled, chunks)
    v_chunked = chunk(probes, chunks)
    def body_fun(i, val):
        return val + rbf_mvm(x_scaled, x_chunked[i], v_chunked[i])
    out = jax.eval_shape(rbf_mvm, x_scaled, x_chunked[0], v_chunked[0])
    kv = jax.lax.fori_loop(0, chunks, body_fun, jnp.zeros(out.shape, out.dtype))
    return jnp.sum(probes.T @ kv)

def ff(scaling):  # success: scan over chunks, accumulating the result
    x_scaled = train_x / scaling
    x_chunked = chunk(x_scaled, chunks)
    v_chunked = chunk(probes, chunks)
    def scan_fun(carry, xv):
        return carry + rbf_mvm(x_scaled, xv[0], xv[1]), None
    out = jax.eval_shape(rbf_mvm, x_scaled, x_chunked[0], v_chunked[0])
    kv = jax.lax.scan(scan_fun, jnp.zeros(out.shape, out.dtype), [x_chunked, v_chunked])[0]
    return jnp.sum(probes.T @ kv)

def gg(scaling):  # success: scan over row chunks, stacking the per-chunk outputs
    x_scaled = train_x / scaling
    x_chunked = chunk(x_scaled, chunks)
    v_chunked = chunk(probes, chunks)
    def scan_fun(i, x_i):
        return i + 1, jnp.sum(rbf_mvm(x_i, x_chunked, v_chunked), 0)
    val: jnp.ndarray = jax.lax.scan(scan_fun, 0, x_chunked)[1]
    kv = jnp.reshape(val, (-1, *val.shape[2:]))
    return jnp.sum(probes.T @ kv)

print(jax.jvp(f, (3.0,), (1.0,)))  # use forward-mode autodiff to save memory
print(jax.jvp(h, (3.0,), (1.0,)))
print(jax.jvp(ff, (3.0,), (1.0,)))
print(jax.jvp(gg, (3.0,), (1.0,)))
```

I think it may not be the forward pass itself that causes the OOM, since …
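For a rough sense of why the fully broadcast variant g blows up while the chunked ones fit, here is a back-of-the-envelope estimate. It is not from the original post; the shapes follow from sq_d and rbf_mvm above, and it assumes XLA does not fuse the broadcasted intermediate away:

```python
# Hedged estimate of the intermediates that variant g asks XLA to materialize.
# Shapes follow from sq_d/rbf_mvm above; sizes assume float32 and no fusion
# of the broadcasted (xi - xj) tensor.
import math

N, d, n_chunks = 50_000, 3, 50
chunk_size = N // n_chunks
bytes_f32 = 4

diff_shape = (n_chunks, N, chunk_size, d)   # broadcasted (xi - xj) inside sq_d
sqd_shape = (n_chunks, N, chunk_size)       # squared-distance / kernel blocks

print(math.prod(diff_shape) * bytes_f32 / 1e9)  # ~30 GB, about the allocation reported in the question
print(math.prod(sqd_shape) * bytes_f32 / 1e9)   # ~10 GB
```

The chunked variants only ever hold one (N, N // chunks) kernel block at a time (roughly 0.2 GB here, ~0.6 GB for its broadcasted difference), which is why they stay within memory.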
Hi,
I'm trying to implement some large matrix-vector products (where the full matrix is too large to fit in memory) and then compute the gradient with respect to the inputs, but I am running into memory errors in the forward pass when using both jax.lax.scan and explicit broadcasting. The error I get is an out-of-memory error requesting 30 GB.
Is there a more JAX-like way to perform these kinds of matrix-vector products?
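The snippet and error text referred to above are not shown here, so as a stand-in, below is a minimal sketch of the kind of chunked RBF kernel matrix-vector product under discussion. The names (rbf_mvm_chunked, lengthscale, n_chunks) and shapes are illustrative, not taken from the original code. It uses jax.lax.map so that only one row block of the kernel exists at a time, and it stays compatible with jax.jvp:

```python
# Hypothetical sketch, not the poster's actual code: a chunked RBF kernel
# matrix-vector product K(x, x) @ v that never materializes the full (N, N)
# kernel. Names and shapes are illustrative.
import jax
import jax.numpy as jnp

def rbf_mvm_chunked(x, v, lengthscale, n_chunks):
    xs = x / lengthscale
    # Split the rows of K into n_chunks blocks; lax.map evaluates one
    # (chunk, N) kernel block at a time instead of the whole (N, N) matrix.
    x_blocks = jnp.reshape(xs, (n_chunks, -1, xs.shape[-1]))

    def one_block(x_blk):
        sq_d = jnp.sum((x_blk[:, None, :] - xs[None, :, :]) ** 2, -1)
        return jnp.exp(-0.5 * sq_d) @ v           # (chunk, v.shape[-1])

    kv_blocks = jax.lax.map(one_block, x_blocks)  # (n_chunks, chunk, ...)
    return jnp.reshape(kv_blocks, (x.shape[0], -1))

x = jax.random.normal(jax.random.PRNGKey(0), (50_000, 3))
v = jax.random.normal(jax.random.PRNGKey(1), (50_000, 10))
kv, kv_dot = jax.jvp(lambda s: rbf_mvm_chunked(x, v, s, 50), (3.0,), (1.0,))
```

Each step materializes only a (chunk, N) block, a few hundred MB with these shapes, rather than the full (N, N) kernel.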
System reference:
jax = '0.2.26'
jaxlib = '0.1.75'
cuda = '11.3'
OS: Linux