jax.lax.scan & jax.lax.map going OOM under gradient computation #10131
Hello team, I have a use-case where I need to perform multi-Gumbel-sampling on each row of a matrix while computing the gradient. I have been trying to use jax.lax.scan and jax.lax.map to keep the memory footprint down, but everything goes OOM as soon as the gradient is involved. For example, taking a [NUM_ROWS x NUM_COLS] matrix of logits, I want to sum NUM_SAMPLES_PER_ROW Gumbel-softmax samples for each row. Hopefully the below code will make it easier to understand and reproduce:

import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)
NUM_SAMPLES_PER_ROW = 1000
NUM_ROWS = 3000
NUM_COLS = 1500
GUMBEL_TAU = 0.7
@jax.jit
def multi_gumbel_sample(input, key):
    # This function performs gumbel sampling on n-rows together, once
    gumbel_sample_once = jax.jit(lambda logits, key: jax.nn.softmax(
        (logits + jax.lax.stop_gradient(jax.random.gumbel(key, logits.shape))) / GUMBEL_TAU
    ))

    # Make all the keys we would need, together
    all_keys = jax.random.split(key, num = NUM_SAMPLES_PER_ROW + 1)

    # Will do gumbel sampling `NUM_SAMPLES_PER_ROW` times for each row in a given matrix
    def final_function(args):
        x_raw, keys = args

        # Loop using scan -- Memory intensive
        # Memory needed is [ NUM_BATCHES x ROW_BATCH_SIZE x NUM_SAMPLES_PER_ROW x NUM_COLS ]
        ret = jax.lax.scan(
            lambda x, step: (x + gumbel_sample_once(x_raw, keys[step + 1]), None),
            gumbel_sample_once(x_raw, keys[1]),
            jnp.arange(NUM_SAMPLES_PER_ROW - 1),
            length = NUM_SAMPLES_PER_ROW - 1,
        )[0]

        # Manual expanding & reduce -- Still memory intensive
        # ret = jnp.sum(gumbel_sample_once(x_raw[jnp.newaxis, ...].repeat(NUM_SAMPLES_PER_ROW, axis = 0), keys[1]), axis = 0)

        # Loop using fori_loop -- Still memory intensive
        # def loop_body(i, sum_till_now): return sum_till_now + gumbel_sample_once(x_raw, keys[i])
        # ret = jax.lax.fori_loop(0, NUM_SAMPLES_PER_ROW, loop_body, jnp.zeros(x_raw.shape))

        return ret # [ ROW_BATCH_SIZE x NUM_COLS ]

    ROW_BATCH_SIZE = 100
    # Do batching on rows & use jax.lax.map over the given matrix to save memory
    num_batches = NUM_ROWS // ROW_BATCH_SIZE
    send_keys = jax.lax.stop_gradient(all_keys[jnp.newaxis, :].repeat(num_batches, axis = 0)) # [ NUM_BATCHES x NUM_SAMPLES_PER_ROW x 2 ]
    send_input = input.reshape(num_batches, ROW_BATCH_SIZE, NUM_COLS) # [ NUM_BATCHES x ROW_BATCH_SIZE x NUM_COLS ] ~~ [ NUM_ROWS x NUM_COLS ]
    final = jax.lax.map(final_function, (send_input, send_keys)).reshape(input.shape) # [ NUM_ROWS x NUM_COLS ]
    return final, all_keys[-1]
if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    rand_input = jax.random.normal(key, shape = (NUM_ROWS, NUM_COLS)) # [ NUM_ROWS x NUM_COLS ]

    # Taking the sum of the samples (a proxy for the entropy in the input matrix) as a dummy loss-function;
    # with has_aux=True the loss must return (scalar, aux), so the next PRNG key is passed through as aux
    def dummy_loss(input, key):
        samples, next_key = multi_gumbel_sample(input, key)
        return jnp.sum(samples), next_key

    grad_fn = jax.grad(dummy_loss, has_aux=True)
    # PASSES
    output, key = multi_gumbel_sample(rand_input, key) # [ NUM_ROWS x NUM_COLS ]
    print(output.shape)

    # FAILS -- OOM
    output, key = grad_fn(rand_input, key) # [ NUM_ROWS x NUM_COLS ]
    print(output.shape)

And below are some OOM Debugging stats:

BufferAssignment stats:
  parameter allocation: 34.33MiB
  constant allocation: 84B
  maybe_live_out allocation: 33.61GiB
  preallocated temp allocation: 1.12GiB
  preallocated temp fragmentation: 700B (0.00%)
  total allocation: 34.76GiB
  total fragmentation: 2.9KiB (0.00%)

Peak buffers:
  Buffer 1:
      Size: 33.49GiB
      Operator: op_name="jit(jvp(multi_gumbel_sample))/jit(main)/broadcast_in_dim[shape=(30, 999, 100, 1500) broadcast_dimensions=()]" source_file="/home/noveens/test_gumbel.py" source_line=52
      XLA Label: broadcast
      Shape: f64[30,999,100,1500]
      ==========================

  Buffer 2:
      Size: 1.12GiB
      XLA Label: parameter
      Shape: f64[999,100,1500]
      ==========================
  . . .

Thanks in advance for the help!
Hi! Vanilla reverse-mode autodiff (the jax.grad default) needs to store all intermediate values of the forward pass, which is what causes the OOM. You can see this in the peak buffer above: f64[30, 999, 100, 1500] is one [ROW_BATCH_SIZE x NUM_COLS] intermediate kept for every scan step of every batch, about 33.5 GiB in float64. You can leverage jax.checkpoint to reduce memory consumption, at the cost of extra computation.

How to (for the loop in final_function, not the loop over batches; it can actually work for the loop over batches too, but there is a more efficient way for that):

- Split the N-step scan into sqrt(N) chunks of sqrt(N) steps each.
- Wrap each chunk with jax.checkpoint (in practice this can be implemented as a nested scan, with the inner scan checkpointed).

If you want to divide the large data into some batches, then you shouldn't loop over batches inside the function you differentiate. BTW, I find that the loop in the …
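A minimal sketch of the sqrt(N) scheme above, assuming the per-row sample count is divisible by the chunk size; the names gumbel_softmax_sample, sum_gumbel_samples_checkpointed and the chunk_size argument are illustrative, not from the original post:

import math
import jax
import jax.numpy as jnp

GUMBEL_TAU = 0.7

def gumbel_softmax_sample(logits, key):
    # One Gumbel-softmax sample; the noise itself carries no gradient
    noise = jax.lax.stop_gradient(jax.random.gumbel(key, logits.shape))
    return jax.nn.softmax((logits + noise) / GUMBEL_TAU)

def sum_gumbel_samples_checkpointed(logits, keys, chunk_size):
    # keys: [num_samples, 2] PRNG keys; num_samples must be divisible by chunk_size
    num_chunks = keys.shape[0] // chunk_size
    chunked_keys = keys.reshape((num_chunks, chunk_size) + keys.shape[1:])

    @jax.checkpoint
    def run_chunk(acc, chunk_keys):
        # Inner scan over one chunk of keys; jax.checkpoint makes reverse-mode
        # save only this chunk's inputs and recompute its interior on the
        # backward pass, so residual memory scales with chunk_size, not N
        def step(carry, key):
            return carry + gumbel_softmax_sample(logits, key), None
        acc, _ = jax.lax.scan(step, acc, chunk_keys)
        return acc, None

    # Outer scan over the chunks
    total, _ = jax.lax.scan(run_chunk, jnp.zeros_like(logits), chunked_keys)
    return total

# Usage: 900 samples per row in 30 checkpointed chunks of 30 steps, under grad
logits = jax.random.normal(jax.random.PRNGKey(0), (100, 1500))
keys = jax.random.split(jax.random.PRNGKey(1), 900)
loss = lambda l: jnp.sum(sum_gumbel_samples_checkpointed(l, keys, math.isqrt(900)))
grads = jax.jit(jax.grad(loss))(logits)
print(grads.shape)  # (100, 1500)

With chunk_size near sqrt(N), the backward pass keeps roughly sqrt(N) chunk boundaries plus one chunk's interior instead of all N intermediates, in exchange for recomputing each chunk's forward pass once.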
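As for the "more efficient way" to handle the loop over batches, one option, assuming the loss stays a plain sum over rows, is ordinary gradient accumulation: each row-batch's gradient is then independent, so you can loop over batches in regular Python outside jax.grad and concatenate the per-batch gradients. A hedged sketch reusing sum_gumbel_samples_checkpointed from above; per_batch_grads and batch_loss are illustrative names, not from the original discussion:

def per_batch_grads(batch_loss, inputs, keys, batch_size):
    # Differentiate one row-batch at a time instead of mapping over batches
    # inside the function handed to jax.grad; only one batch's intermediates
    # are alive on the device at any time
    grad_fn = jax.jit(jax.grad(batch_loss))
    num_batches = inputs.shape[0] // batch_size
    grads = []
    for i in range(num_batches):
        rows = inputs[i * batch_size:(i + 1) * batch_size]
        grads.append(grad_fn(rows, keys))
    return jnp.concatenate(grads, axis = 0)

# Usage: the total loss is a sum over rows, so concatenating the per-batch
# gradients gives exactly the gradient of the full loss
inputs = jax.random.normal(jax.random.PRNGKey(0), (300, 1500))
keys = jax.random.split(jax.random.PRNGKey(1), 900)
batch_loss = lambda rows, keys: jnp.sum(sum_gumbel_samples_checkpointed(rows, keys, 30))
full_grads = per_batch_grads(batch_loss, inputs, keys, batch_size = 100)
print(full_grads.shape)  # (300, 1500)

jax.jit compiles the per-batch gradient once and reuses it for every batch, so the Python loop adds little overhead.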