Hey, I have a training loop that leaks memory somewhere: while the loop is running, host RAM usage increases constantly. With jax.profiler.save_device_memory_profile I can only profile GPU memory, which looks as expected. If I point the profiler at the CPU device instead, I just get an empty graph (probably because it can't track host memory when the main device is the GPU)?
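This is roughly how I'm saving the device profile (the filename is just an example):

import jax.profiler

# snapshot the current device (GPU) memory; the resulting .prof file
# can be inspected with pprof
jax.profiler.save_device_memory_profile("memory.prof")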
I tried using memory_profiler, but all I get is the following output:
The roughly 6000 MB shows that there is a leak (I am only working with CIFAR, which shouldn't need this much RAM), but it's very hard to tell where the leak appears. update_step is the JIT-compiled training step of a score-based generative model (SGM), and I don't see how it could leak memory:
from functools import partial

import jax
import jax.numpy as jnp
import optax
from jax import jit, random

# R, mean_factor, var, optimizer and score_model are defined elsewhere in the script

def loss_fn(params, model, rng, batch):
    rng, step_rng = random.split(rng)
    N_batch = batch.shape[0]
    # sample a discrete time step per example and rescale it to (0, 1]
    t = random.randint(step_rng, (N_batch, 1), 1, R) / (R - 1)
    mean_coeff = mean_factor(t)
    # is it right to have the square root here for the loss?
    vs = var(t)
    stds = jnp.sqrt(vs)
    rng, step_rng = random.split(rng)
    noise = random.normal(step_rng, batch.shape)
    # broadcast the per-example coefficients over the image dimensions
    stds = stds[:, :, None, None]
    mean_coeff = mean_coeff[:, :, None, None]
    xt = batch * mean_coeff + noise * stds
    # apply the model passed in as the static argument
    output = model.apply(params, xt, t.flatten())
    loss = jnp.mean((noise + output * stds) ** 2)
    return loss

@partial(jit, static_argnums=[4])
def update_step(params, rng, batch, opt_state, model):
    val, grads = jax.value_and_grad(loss_fn)(params, model, rng, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return val, params, opt_state
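For context, the outer loop that drives update_step is essentially just this (a simplified sketch; num_epochs and train_batches stand in for my actual epoch count and CIFAR batch iterator):

# simplified sketch of the outer training loop
for epoch in range(num_epochs):
    for batch in train_batches:  # CIFAR batches as host-side numpy arrays
        rng, step_rng = random.split(rng)
        val, params, opt_state = update_step(
            params, step_rng, batch, opt_state, score_model)
        train_loss = float(val)  # pull the scalar loss back to the host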
Any pointers on where I could be leaking memory, or on how to go about profiling this, are greatly appreciated!
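For completeness, this is roughly how I'm running memory_profiler (the script name train.py is a placeholder):

from memory_profiler import profile

@profile
def train():
    # the training loop from above goes here
    ...

# line-by-line report:  python -m memory_profiler train.py
# memory over time:     mprof run train.py && mprof plot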