Replies: 1 comment 2 replies
-
It may help to run this with the dump flags XLA_FLAGS='--xla_dump_to=/tmp/output_folder/xla_dumps --xla_dump_hlo_pass_re=.*' and then look right after the SPMD propagation pass for any unusual partition specs (most of them should have a leading data partitioning).

Since you mention embeddings, note that there is a known performance issue with updating the embeddings: the embedding updates get all-gathered before being applied, because the scatter-add cannot be partitioned properly. The usual workaround for this is to wrap an xmap around the embedding layer with the flags:

As for the OOM message, it just prints out the top buffers regardless of whether their lifetimes overlap, so the same-sized buffers for each layer are still created, just with different lifetimes.
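For concreteness, a minimal sketch of enabling those dump flags from Python; the flags themselves are the ones quoted above, but the exact file names inside the dump directory and the pass naming are version-dependent, so treat everything beyond the flags as assumptions.

```python
# Minimal sketch: set the suggested dump flags before JAX initializes its backends.
import os
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=/tmp/output_folder/xla_dumps "
    "--xla_dump_hlo_pass_re=.*"
)

import jax  # import after setting XLA_FLAGS so the flags take effect

# ... run the pjit'd step once, then inspect the HLO files dumped right after
# the SPMD/sharding propagation pass for any unusual partition specs
# (most should have a leading data partitioning).
```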
-
Following up on my last discussion (#11798), where it seemed that maybe pjit wasn't splitting over the batch axis: I increased my model size and started getting OOMs again. This time I saved some of the logs from XLA (see below). I'm pretty sure this is a bug in either pjit or the new vmap spmd_axis_name, but I'd like to check.
As a reminder, my code looks vaguely like:
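(The original code block didn't survive here; below is a rough, hypothetical sketch of the kind of setup described, using the jax.experimental.pjit-era API. The mesh axis names, parameter layout, and toy model body are my own stand-ins, not the original code.)

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import maps
from jax.experimental.pjit import pjit, PartitionSpec as P

# V3-256: 256 devices arranged as a [128, 2] mesh ("data" x "model").
devices = np.array(jax.devices()).reshape(128, 2)
mesh = maps.Mesh(devices, ("data", "model"))

def per_example_loss(params, x):
    # x: [seqlen, embed] = [1024, 1600]; a single projection stands in for
    # the real transformer stack, which is elided here.
    h = x @ params["w"]
    return jnp.mean(h ** 2)

# vmap over the batch, asking SPMD partitioning to shard the mapped axis
# over the "data" mesh axis.
batched_loss = jax.vmap(per_example_loss, in_axes=(None, 0),
                        spmd_axis_name="data")

# Parameters sharded along the hidden/"embed" dimension over "model";
# inputs sharded along the batch dimension over "data".
loss_fn = pjit(
    batched_loss,
    in_axis_resources=({"w": P(None, "model")}, P("data", None, None)),
    out_axis_resources=P("data"),
)

params = {"w": jnp.zeros((1600, 1600))}
batch = jnp.zeros((128, 1024, 1600))

with mesh:
    per_example = loss_fn(params, batch)  # shape [128]
```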
For context, I'm using a V3-256 with a pjit mesh of shape [128, 2]; the first axis is the "DATA"/"batch" axis and the second is for model partitioning. The model is partitioned along the hidden states.
I'm pretty sure these are either dropout masks or RNG states that aren't getting sharded. In the log excerpt below,
This is followed by 19 additional identical allocations, presumably corresponding to the layers of the transformer blocks. (Confusingly, I do have gradient checkpointing turned on, so I would have thought you wouldn't get more than one of these anyway...)
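(For reference, a hedged sketch of what per-layer gradient checkpointing typically looks like with jax.checkpoint (remat); the block body and parameter layout here are placeholders, not the model from this post.)

```python
import jax
import jax.numpy as jnp

def block(p, x):
    # Placeholder for one transformer block; the real model is not shown here.
    return jax.nn.gelu(x @ p["w"])

# Wrapping each block in jax.checkpoint means its internal activations are
# recomputed during the backward pass rather than kept live.
checkpointed_block = jax.checkpoint(block)

def forward(layer_params, x):
    for p in layer_params:  # e.g. 20 layers, matching the 1 + 19 allocations above
        x = checkpointed_block(p, x)
    return x
```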
The line being identified is:
For the offending allocation, `x` here semantically has shape `[seqlen, embed]`, where `seqlen=1024` and `embed=1600`; the computation has been vmapped to a batch size of 128 (via the `spmd_axis_name` vmap mentioned above), and the model has been pjit'd so that the "embed" dim should be partitioned in half.
So the tensors being allocated above have shape [128, 819200], which I'm guessing is a flattening of [128, 1024, 800] (1024 * 800 = 819200). That would indicate that the "embed" axis is being partitioned (1600 / 2 = 800) but the "data"/"batch" axis is not...
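As a quick sanity check on that interpretation, the arithmetic spelled out (all numbers are taken from this post; the "expected" shape assumes the batch axis should also be sharded over the 128-way data axis):

```python
batch, seqlen, embed = 128, 1024, 1600
data_parallel, model_parallel = 128, 2

# Observed buffer shape from the OOM report: [128, 819200]
observed = (batch, seqlen * (embed // model_parallel))
assert observed == (128, 819200)   # embed halved, batch axis left unsharded

# What a per-device buffer would look like if "data" sharding were also applied:
expected = (batch // data_parallel, seqlen, embed // model_parallel)
assert expected == (1, 1024, 800)
```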
EDIT: I realize I wasn't clear about what my question was. Is my interpretation of the shape of these tensors correct? And if so, I think this means it's at least a bug in pjit/vmap?
A secondary question is why it needs to allocate all these buffers if I'm using gradient checkpointing between every layer...