FSDP/ZeRO grad accumulation in the presence of vmapped/scanned layers #12799
-
Hi, I have a transformer implementation that is working reasonably well, except I can't quite get FSDP/ZeRO to behave the way I want. Things looked fine until I really started poking at gradient accumulation, where they don't quite work out. The issue is that JAX seems to want the gradients for all layers computed on each node before it's willing to reduce and scatter them, when ideally this would proceed layer-wise to keep memory use under control. It's of course a bit hard to decode everything that's going on, but I have, e.g., a parameter that has shape
I'm pretty sure this corresponds to a "logical array" size of

Now, I guess the real problem above is that I'm getting a 21.3x(!) blow-up for padding (other arrays with analogous shapes don't blow up nearly this much), but in an ideal world this array wouldn't exist at all: it's a temporary that gets reduced to

Is there a way to make this happen? Happy to share code or more context! Specifically, the current attempt is at https://github.com/stanford-crfm/levanter/blob/fsdp/src/levanter/modeling_utils.py#L66 (I have an in-progress named tensors library I'm using there, but it more or less does the obvious thing). An earlier attempt is at https://github.com/stanford-crfm/levanter/blob/main/src/levanter/modeling_utils.py#L65, which doesn't have the extreme blow-up problem and works fine at this scale, but it stops working at the next scale I'm targeting. (For the first linked implementation I'm cribbing off of t5x a bit: https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617 , but my attempts to copy it are what led to the massive array above. I'll also note that t5x generally seems to avoid vmapping/scanning layers, cf. https://github.com/google-research/t5x/blob/main/t5x/examples/decoder_only/network.py#L197 , which makes me think this doesn't work reliably.)
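(For reference, the microbatch gradient-accumulation pattern being described here, loosely following the linked t5x trainer code, looks roughly like the sketch below. This is not the actual levanter implementation: loss_fn, grad_specs (a pytree of PartitionSpecs mirroring params), and accumulate_grads are placeholder names, and the whole thing is assumed to run inside pjit on a mesh with a 'data' axis.)

import jax
import jax.numpy as jnp
from jax.experimental.pjit import with_sharding_constraint

def accumulate_grads(params, batch, loss_fn, num_micro_steps, grad_specs):
    # split the global batch into (num_micro_steps, microbatch_size, ...) and scan over it
    micro = batch.reshape((num_micro_steps, -1) + batch.shape[1:])

    def step(grad_accum, microbatch):
        grads = jax.grad(loss_fn)(params, microbatch)
        grad_accum = jax.tree_util.tree_map(jnp.add, grad_accum, grads)
        # ask the partitioner to keep the running gradient sharded between
        # microbatch steps so each device only holds its own shard
        grad_accum = jax.tree_util.tree_map(with_sharding_constraint, grad_accum, grad_specs)
        return grad_accum, None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    grad_accum, _ = jax.lax.scan(step, zeros, micro)
    return jax.tree_util.tree_map(lambda g: g / num_micro_steps, grad_accum)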
-
Hey David! This is quite curious! For starters, it's surprising that XLA decided the memory-optimised layout for that tensor should put the dimension of size 6 at index 0, causing it to be padded out to 128 to match the TPU MXU tile size. To help us look into that, is there any chance you could send through the unoptimised HLO? You can get it by running jit(f).lower(*args).as_text().

I've quickly made a minimal reproduction, which I've put below. It behaves sensibly when I don't vmap per example inside the microbatch, but when I do, it induces a huge all-to-all before the gradients are computed. I'm still looking into exactly what's going on, but it looks quite similar to your issue. Can you check through it quickly and see whether it matches what you're trying to do in your computation? If you don't vmap per example within the microbatch, does it work?

Finally, a couple of quick thoughts/questions:
import jax
jax.config.update('jax_array', True) # required for jax<0.4.0
import jax.numpy as jnp
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.experimental.pjit import PartitionSpec as P
from jax.sharding import MeshPspecSharding
from functools import partial
import numpy as np
num_layers = 48
num_heads = 24
head_size = 64
embed_size = 1536
batch = 512
t = 8
qkv_sharding = P(None, None, 'data', None)  # params sharded along the data axis (FSDP/ZeRO-style)
x_sharding = P('data', None, 'model')       # activations: batch along data, embed along model
o_sharding = P(None, None, 'model')
qkv = jnp.ones((num_layers, num_heads, head_size, embed_size), dtype = jnp.bfloat16)
o = jnp.ones((num_layers, num_heads, head_size, embed_size), dtype=jnp.bfloat16)
x = jnp.ones((batch, t, embed_size), dtype=jnp.bfloat16)
dp, mp = 8, 1
devices = np.reshape(jax.local_devices(), (dp, mp))
mesh = Mesh(devices, ('data', 'model'))
x = jax.device_put(x, MeshPspecSharding(mesh, x_sharding))
qkv = jax.device_put(qkv, MeshPspecSharding(mesh, qkv_sharding))
o = jax.device_put(o, MeshPspecSharding(mesh, o_sharding))
params = (qkv, o)
VMAP_MICROBATCH = False  # flip to True to vmap per example inside the microbatch (triggers the huge all-to-all)
def fwd(params, x):
@jax.checkpoint
def layer(x, params):
qkv, o = params
if VMAP_MICROBATCH:
y = jnp.einsum('te,hde->thd', x, qkv)
z = jnp.einsum('thd,hde->te', y, o)
else:
y = jnp.einsum('bte,hde->bthd', x, qkv)
z = jnp.einsum('bthd,hde->bte', y, o)
# no ffn
return z, None
x, _ = jax.lax.scan(layer, x, params)
return x
def loss_fn(params, x):
x = fwd(params, x)
l = jnp.mean(x)
return l
def grad_fn(params, x):
loss, grad = jax.value_and_grad(loss_fn)(params, x)
return loss, grad
def accumulate_gradients_sharded(params,
x,
f,
per_device_parallelism,
data_axis_size):
batch_size = jnp.shape(x)[0] # 512
microbatch_size = data_axis_size * per_device_parallelism # 8 * 4 = 32
num_micro_steps = batch_size // microbatch_size # 512 // 32 = 16
assert num_micro_steps * microbatch_size == batch_size
loss = jnp.zeros(())
grad = jax.tree_util.tree_map(jnp.zeros_like, params)
x = x.reshape((num_micro_steps, microbatch_size) + x.shape[1:])
    # keep each microbatch sharded along the data axis (the micro-step axis stays replicated)
    x = with_sharding_constraint(x, P(None, 'data', *(None,) * (len(x.shape) - 2)))
# compute microbatches
def loop(accum, microbatch):
with jax.named_scope('microbatch'):
loss, grad = accum
if VMAP_MICROBATCH:
# vmap as code is written for single examples
this_loss, this_grad = jax.vmap(f, in_axes=(None, 0))(params, microbatch)
# reduce along microbatch dimension
this_loss = jnp.mean(this_loss)
mean_along_microbatch = partial(jnp.mean, axis = 0)
this_grad = jax.tree_map(mean_along_microbatch, this_grad)
else:
this_loss, this_grad = f(params, microbatch)
with jax.named_scope('accumulate'):
return (this_loss + loss, jax.tree_map(jnp.add, grad, this_grad)), None
# loops over microbatches, accumulates
accum = (loss, grad)
accum, _ = jax.lax.scan(loop, accum, x)
loss, grad = accum
return loss/num_micro_steps, jax.tree_map(lambda x: x / num_micro_steps, grad)
pjit_fn = partial(accumulate_gradients_sharded,
f=grad_fn,
per_device_parallelism=4,
data_axis_size=dp)
with mesh:
loss, grad = pjit(pjit_fn)(params, x)
loss.block_until_ready()
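For completeness, the unoptimised HLO requested above could be dumped for this repro roughly as follows; this is a sketch that assumes the pjit-wrapped function exposes the same .lower(...).as_text() ahead-of-time API as jit:

with mesh:
    # print the pre-optimisation HLO/StableHLO text for inspection
    print(pjit(pjit_fn).lower(params, x).as_text())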