
Commit b01e3af

Checkpoint reshaping
1 parent a82c71d commit b01e3af

1 file changed: +54 −2 lines

train/tr11-176B-ml/README.md

```
@@ -45,8 +45,8 @@ Hardware:

 Software:

-- [Megatron-DeepSpeed](https://github.com/bigscience-workshop/Megatron-DeepSpeed) @ master / BigScience fork - currently using `layer-norm-auto-syn` PR branch
-- [DeepSpeed](https://github.com/microsoft/DeepSpeed) @ master (soon) at the moment 93e9307d609620943565e639f30ef15513c76f4f
+- [Megatron-DeepSpeed](https://github.com/bigscience-workshop/Megatron-DeepSpeed) @ `ds_ckpt_reshape-with-layer-norm-auto-sync` PR branch
+- [DeepSpeed](https://github.com/microsoft/DeepSpeed) @ olruwase/elastic-ckpt-refresh PR branch
 - [PyTorch](https://github.com/pytorch/pytorch)-1.11 w/ CUDA-11.5
 - [apex](https://github.com/NVIDIA/apex) @ master
```

@@ -649,6 +649,58 @@ NHIDDEN=14336; NLAYERS=70; SEQ_LEN=2048; VOCAB_SIZE=250680; python -c "h=$NHIDDE

BF16 Transformer block size: 4.59GB, the rest is: 6.75GB, total 328.34GB

### Checkpoint reshaping

It's not trivial to switch from one 3D topology to another due to the TP and DP logic of DeepSpeed. So we developed a special mechanism called the universal checkpoint, which converts a checkpoint saved under whatever topology was last used into a universal checkpoint that stores each weight and optimizer state as a separate file. This is done after carefully merging the weights split across TP ranks (some weights are averaged, some are concatenated on the first dimension and some on the second), and then the DP ZeRO sharding is unsharded. This universal checkpoint can now be used to start training with any new topology or to create an HF Transformers checkpoint. Note that all weights are kept in fp32, so no data is lost.
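
To make the merging rules concrete, here is a minimal sketch of the per-weight TP merge logic described above. It is an illustration only: the name-based dispatch and the choice of which weights are column- vs row-parallel are assumptions, not the converter's actual code.

```
# Sketch: merge the TP slices of one parameter into a full fp32 tensor.
# Assumed rules (illustrative, not the converter's actual dispatch):
#   - replicated params (e.g. layernorm) are averaged across TP ranks
#   - column-parallel weights are concatenated on dim 0
#   - row-parallel weights are concatenated on dim 1
import torch

def merge_tp_slices(name, slices):
    slices = [s.float() for s in slices]  # universal checkpoint keeps fp32
    if "layernorm" in name:
        return torch.stack(slices).mean(dim=0)   # average replicated weights
    if "dense_h_to_4h" in name:
        return torch.cat(slices, dim=0)          # column-parallel: dim 0
    return torch.cat(slices, dim=1)              # row-parallel: dim 1

# e.g. TP=4: each rank holds an 8x2 slice of a row-parallel 8x8 weight
full = merge_tp_slices("dense_4h_to_h.weight",
                       [torch.ones(8, 2) for _ in range(4)])
print(full.shape)  # torch.Size([8, 8])
```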

As this is all new, it currently requires that the code run on the following 2 branches:

- `microsoft/DeepSpeed|olruwase/elastic-ckpt-refresh`
- `bigscience-workshop/Megatron-DeepSpeed|ds_ckpt_reshape-with-layer-norm-auto-sync`

The latter is really `bigscience-workshop/Megatron-DeepSpeed|ds_ckpt_reshape`, but since we also have another band-aid branch in use, it is merged with `layer-norm-auto-sync`.
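
Getting onto these branches could look roughly like this, assuming you already have clones of the two repos and the branches exist on the usual remotes:

```
cd Megatron-DeepSpeed
git fetch origin && git checkout ds_ckpt_reshape-with-layer-norm-auto-sync

cd ../DeepSpeed
git fetch origin && git checkout olruwase/elastic-ckpt-refresh
```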
662+
663+
So say you want to switch from 48 to 24 nodes.
664+
665+
1. allocate a new cpu node

```
srun --pty --account=six@cpu --nodes=1 --ntasks=1 --partition=cpu_p1 --cpus-per-task=40 --time 6:00:00 --hint=nomultithread --tasks-per-node=1 bash
```

2. Convert the checkpoint, e.g. for `global_step90751`:
672+
673+
```
674+
/usr/bin/time -v python tools/convert_checkpoint/ds_to_universal.py \
675+
--input_folder $six_ALL_CCFRSCRATCH/checkpoints/tr11-176B-ml/checkpoints/main/global_step90751 \
676+
--output_folder $six_ALL_CCFRSCRATCH/checkpoints/tr11-176B-ml/checkpoints/main/global_step90751_universal \
677+
--num_extract_workers 10 --num_merge_workers 4
678+
```
679+
680+
it takes about 50min for 176B
681+

3. Now edit the normal slurm script (a sketch of the edits follows this list):

   a. change its topology to the desired one

   b. add `--universal-checkpoint` to the script

   c. start the slurm job normally with the edited script
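
As an illustration, the relevant edits could look like the excerpt below. The variable names follow the usual tr11 slurm scripts, but the concrete 24-node topology values are assumptions, just one valid way to split 24 nodes:

```
#SBATCH --nodes=24       # was 48

TP_SIZE=4                # unchanged
PP_SIZE=12               # unchanged
# DP is derived: 24 nodes * 8 gpus / (TP_SIZE * PP_SIZE) = 4 (was 8)

# and add to the script's args:
#   --universal-checkpoint \
```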

You should now be running with the new topology. It's expected that a tiny difference will be seen in the lm loss, due to the averaging of TP slices.

4. Using a kill-switch or any other way, save a new checkpoint, which will be a normal Megatron-DeepSpeed checkpoint.
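
For example, assuming the job was launched with Megatron-DeepSpeed's `--kill-switch-path <file>` argument (as the tr11 scripts do), saving and exiting is just a matter of creating that file; the path below is illustrative:

```
# training saves a checkpoint and exits cleanly once this file appears
touch $six_ALL_CCFRSCRATCH/checkpoints/tr11-176B-ml/kill-switch   # illustrative path
```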

5. Remove `--universal-checkpoint` from the script.

6. Resume training normally.

Stages 5-6 are important because, in addition to `latest`, there is currently a `latest-universal` tag, which is generated by `ds_to_universal.py` and is not updated by the main training. So if you stop and restart while still having the `--universal-checkpoint` arg in the slurm script, it'll restart from the same checkpoint as the first time, and we don't want that.
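
For concreteness, assuming the usual DeepSpeed layout where each tag is a one-line file naming a checkpoint folder, the stale-tag situation looks roughly like this (the step numbers are illustrative):

```
cd $six_ALL_CCFRSCRATCH/checkpoints/tr11-176B-ml/checkpoints/main
cat latest              # advanced by every normal checkpoint save
global_step90800
cat latest-universal    # written once by ds_to_universal.py, never advanced
global_step90751_universal
```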

So basically the conversion to universal is a transitional process: it takes just a single step, plus saving a new checkpoint in the new topology, which is no longer universal. As you can tell, converting to the universal checkpoint is a very slow and expensive process, and we can't afford it at every checkpoint save/load point.

### Times

- 1 train iteration ~100sec
