Skip to content

Commit 40bd933

Browse files
authored
add start-fast doc (#278)
1 parent 640c818 commit 40bd933

File tree

2 files changed

+254
-6
lines changed

2 files changed

+254
-6
lines changed

README.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ In addition various code bits and lots of docs are to be found at https://github
88

99
Please note that the rest of this page has been trimmed to only include the info relevant to the BigScience project and also updated to usage with the integrated Deepspeed. You will find the original page with all the tables and training info on Bert and T5 [here](https://github.com/NVIDIA/Megatron-LM).
1010

11+
# Get started fast
12+
13+
Here is doc with just [instructions to going from 0 to training really fast](start_fast.md).
14+
1115
# Setup
1216

1317
1. Install `bigscience-workshop/Megatron-DeepSpeed`
@@ -33,14 +37,11 @@ pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache
3337

3438
(on JZ it's done in a special way, see [here](https://github.com/bigscience-workshop/bigscience/tree/master/jz/envs#apex).)
3539

36-
3. Install `deepspeed` / the `big-science` branch
37-
38-
Then install the `big-science` branch of `deepspeed`:
40+
3. Install `deepspeed`
3941

4042
```
41-
git clone https://github.com/microsoft/deepspeed deepspeed-big-science
42-
cd deepspeed-big-science
43-
git checkout big-science
43+
git clone https://github.com/microsoft/deepspeed
44+
cd deepspeed
4445
rm -rf build
4546
TORCH_CUDA_ARCH_LIST="7.0" DS_BUILD_CPU_ADAM=1 DS_BUILD_AIO=1 DS_BUILD_UTILS=1 pip install -e . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check
4647
```

start_fast.md

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
# Fast Setup instructions
2+
3+
This quick instructions document contains 3 steps:
4+
5+
1. installing software
6+
2. preparing data
7+
3. running the script
8+
9+
This is useful if you need to ask someone to reproduce problems with `Megatron-Deepspeed`
10+
11+
## 1. Software
12+
13+
Please follow this exact order.
14+
15+
16+
0. Create a new conda env if need be or activate an existing environment.
17+
18+
1. Install `pytorch`. Choose the desired version install instructions [here](https://pytorch.org/get-started/locally/), but for conda it'd be:
19+
20+
```
21+
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
22+
```
23+
24+
2. Install system-wide `cuda` if you don't have it already. [NVIDIA instruction](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html). Of course ideally use [the premade packages for your distro](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#package-manager-installation).
25+
Use the same major version as pytorch's cuda build. To check use:
26+
27+
```
28+
python -c 'import torch; print(f"pt={torch.__version__}, cuda={torch.version.cuda}")'
29+
```
30+
31+
The minor versions don't actually have to match, but then you will need to hack `apex` installer to ignore minor version changes, see below.
32+
33+
3. Install `apex`
34+
35+
```
36+
git clone https://github.com/NVIDIA/apex
37+
cd apex
38+
pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache -v --disable-pip-version-check . 2>&1 | tee build.log
39+
cd -
40+
```
41+
42+
If the pytorch and system-wide cuda minor versions mismatch, it's not a problem, you just need to hack `apex`'s build to bypass the check by applying this patch first and then build it.
43+
```
44+
diff --git a/setup.py b/setup.py
45+
index d76e998..f224dae 100644
46+
--- a/setup.py
47+
+++ b/setup.py
48+
@@ -31,6 +31,8 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
49+
print(raw_output + "from " + cuda_dir + "/bin\n")
50+
51+
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
52+
+ # allow minor diffs
53+
+ if bare_metal_minor != torch_binary_minor: return
54+
raise RuntimeError(
55+
"Cuda extensions are being compiled with a version of Cuda that does "
56+
"not match the version used to compile Pytorch binaries. "
57+
```
58+
59+
60+
4. Checkout and prepare `Megatron-DeepSpeed` and install its requirements
61+
62+
```
63+
git clone https://github.com/bigscience-workshop/Megatron-DeepSpeed
64+
cd Megatron-DeepSpeed
65+
pip install -r requirements.txt
66+
```
67+
68+
69+
70+
71+
## 2. Data
72+
73+
Will work under the `Megatron-DeepSpeed` clone
74+
75+
```
76+
cd Megatron-DeepSpeed
77+
```
78+
79+
80+
81+
Prepare data for preprocessing
82+
```
83+
mkdir -p data
84+
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json -O data/gpt2-vocab.json
85+
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -O data/gpt2-merges.txt
86+
python -c 'from datasets import load_dataset; ds = load_dataset("stas/oscar-en-10k", split="train", keep_in_memory=False); ds.to_json(f"data/oscar-en-10k.jsonl", orient="records", lines=True, force_ascii=False)'
87+
```
88+
89+
Pre-process a small dataset to be used for training
90+
91+
```
92+
python tools/preprocess_data.py \
93+
--input data/oscar-en-10k.jsonl \
94+
--output-prefix data/meg-gpt2-oscar-en-10k \
95+
--dataset-impl mmap \
96+
--tokenizer-type GPT2BPETokenizer \
97+
--merge-file data/gpt2-merges.txt \
98+
--vocab data/gpt2-vocab.json \
99+
--append-eod \
100+
--workers 4
101+
```
102+
103+
now you have data/meg-gpt2-oscar-en-10k, vocab and merges files to pass as arguments to training, the next section shows how to use them.
104+
105+
Note that Megatron wants `data/meg-gpt2-oscar-en-10k_text_document` prefix later in `--data-path`
106+
107+
## 3. Train
108+
109+
Here is a tiny model training setup configured over 2 gpus to train on the data we prepared in step 2. Put it in a script or run it directly:
110+
111+
```
112+
CHECKPOINT_PATH=checkpoints/gpt2
113+
114+
VOCAB_FILE=data/gpt2-vocab.json
115+
MERGE_FILE=data/gpt2-merges.txt
116+
DATA_PATH=data/meg-gpt2-oscar-en-10k_text_document
117+
TENSORBOARD_PATH=output_dir/tensorboard
118+
119+
N_GPUS=2
120+
MICRO_BATCH_SIZE=1
121+
GLOBAL_BATCH_SIZE=16
122+
TP_SIZE=2
123+
PP_SIZE=1
124+
125+
NLAYERS=2
126+
NHIDDEN=8
127+
NHEADS=2
128+
SEQ_LEN=512
129+
VOCAB_SIZE=50257
130+
131+
SAVE_INTERVAL=50
132+
133+
TRAIN_SAMPLES=10_000
134+
135+
GPT_ARGS=" \
136+
--num-layers $NLAYERS \
137+
--hidden-size $NHIDDEN \
138+
--num-attention-heads $NHEADS \
139+
--seq-length $SEQ_LEN \
140+
--max-position-embeddings $SEQ_LEN \
141+
--micro-batch-size $MICRO_BATCH_SIZE \
142+
--rampup-batch-size 2 2 1_000 \
143+
--global-batch-size $GLOBAL_BATCH_SIZE \
144+
--train-samples $TRAIN_SAMPLES \
145+
--optimizer adam \
146+
--adam-beta1 0.9 \
147+
--adam-beta2 0.95 \
148+
--adam-eps 1e-8 \
149+
--lr 1e-4 \
150+
--lr-warmup-samples 5 \
151+
--min-lr 1e-6 \
152+
--lr-decay-style cosine \
153+
--lr-decay-samples 12 \
154+
--clip-grad 1.0 \
155+
--weight-decay 1e-1 \
156+
--embed-layernorm \
157+
--fp16 \
158+
--partition-activations \
159+
--seed 42 \
160+
--vocab-file $VOCAB_FILE \
161+
--merge-file $MERGE_FILE \
162+
"
163+
164+
OUTPUT_ARGS=" \
165+
--exit-interval 100 \
166+
--log-interval 10 \
167+
--save-interval $SAVE_INTERVAL \
168+
--eval-interval 100 \
169+
--eval-iters 10 \
170+
--checkpoint-activations \
171+
"
172+
173+
DATA_ARGS=" \
174+
--save $CHECKPOINT_PATH \
175+
--load $CHECKPOINT_PATH \
176+
--data-path $DATA_PATH \
177+
--tensorboard-dir $TENSORBOARD_PATH \
178+
--tensorboard-queue-size 5 \
179+
--log-timers-to-tensorboard \
180+
--log-batch-size-to-tensorboard \
181+
--log-validation-ppl-to-tensorboard \
182+
--kill-switch-path /tmp/kill-switch \
183+
"
184+
185+
ZERO_STAGE=1
186+
187+
config_json="./ds_config.json"
188+
189+
# Deepspeed figures out GAS dynamically from dynamic GBS via set_train_batch_size()
190+
cat <<EOT > $config_json
191+
{
192+
"train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
193+
"train_batch_size": $GLOBAL_BATCH_SIZE,
194+
"gradient_clipping": 1.0,
195+
"zero_optimization": {
196+
"stage": $ZERO_STAGE
197+
},
198+
"fp16": {
199+
"enabled": true,
200+
"loss_scale": 0,
201+
"loss_scale_window": 500,
202+
"hysteresis": 2,
203+
"min_loss_scale": 1,
204+
"initial_scale_power": 12
205+
},
206+
"steps_per_print": 2000,
207+
"wall_clock_breakdown": false
208+
}
209+
EOT
210+
211+
DEEPSPEED_ARGS=" \
212+
--deepspeed \
213+
--deepspeed_config ${config_json} \
214+
--zero-stage ${ZERO_STAGE} \
215+
--deepspeed-activation-checkpointing \
216+
"
217+
218+
ALL_ARGS="$GPT_ARGS $OUTPUT_ARGS $DATA_ARGS $DEEPSPEED_ARGS"
219+
220+
MASTER_ADDR=localhost
221+
MASTER_PORT=6777
222+
223+
export LAUNCHER="python -u -m torch.distributed.run \
224+
--nproc_per_node $N_GPUS \
225+
--nnodes 1 \
226+
--rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
227+
--rdzv_backend c10d \
228+
--max_restarts 0 \
229+
--tee 3 \
230+
"
231+
export CMD=" \
232+
$LAUNCHER pretrain_gpt.py \
233+
--tensor-model-parallel-size $TP_SIZE \
234+
--pipeline-model-parallel-size $PP_SIZE \
235+
--distributed-backend nccl \
236+
$ALL_ARGS \
237+
"
238+
239+
echo $CMD
240+
241+
$CMD
242+
243+
```
244+
245+
You can, of course, run this as a slurm script, but here is [a full slurm script example](https://github.com/bigscience-workshop/bigscience/blob/d57b76bb592832bb4d2054cd5cbf132796be2d83/train/tr11-176B-ml/setup-test-n2.slurm), which has some tweaks to get `MASTER_ADDR` and a few other bits right under the SLURM environment on JeanZay, which may or may not be needed if you run it elsewhere.
246+
247+
Remember to wipe out `$CHECKPOINT_PATH`, if you change the model shape and there is a checkpoint with the old shapes saved already.

0 commit comments

Comments
 (0)