
Commit 49d21af

stas00, jeffra, and tjruwase authored
[sync] bf16 (#250)
* add .so/.swp to gitignore
* progress
* Temporarily add run scripts; remove later

Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent cf2d23a commit 49d21af

File tree: 5 files changed (+369, -21 lines)


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -142,3 +142,6 @@ cython_debug/
 media
 staticfiles
 /tags
+
+# tmp files
+*.swp

megatron/training.py

Lines changed: 14 additions & 6 deletions
@@ -388,6 +388,10 @@ def setup_model_and_optimizer(model_provider_func):
             args=args,
             lr_scheduler=lr_scheduler
         )
+
+        assert model.fp16_enabled() == args.fp16, "megatron fp16 config does not match deepspeed"
+        assert model.bfloat16_enabled() == args.bf16, "megatron bf16 config does not match deepspeed"
+
         if isinstance(model, deepspeed.PipelineEngine):
             # hack to get batch_fn from pretrain_gpt.py
             model.set_batch_fn(model.module._megatron_batch_fn)
@@ -622,9 +626,10 @@ def add_to_logging(name):
                               args.consumed_train_samples)
             writer.add_scalar(f"lm-loss-training/{key}" + ' vs tokens', loss_dict[key],
                               args.consumed_train_tokens)
+
             writer.add_scalar(f"lm-loss-training/{key}" + ' vs gigaflos (without embeddings)', loss_dict[key],
                               args.gigaflos_no_embeds)
-        if args.log_loss_scale_to_tensorboard:
+        if args.log_loss_scale_to_tensorboard and args.fp16:
             writer.add_scalar('loss-scale/loss-scale', loss_scale, iteration)
             writer.add_scalar('loss-scale/loss-scale vs samples', loss_scale,
                               args.consumed_train_samples)
@@ -724,7 +729,8 @@ def add_to_logging(name):
             if avg > 0.0:
                 log_string += ' {}: {:.6E} |'.format(key, avg)
             total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
-        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
+        if args.fp16:
+            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
         if grad_norm is not None:
             log_string += ' grad norm: {:.3f} |'.format(grad_norm)
         if num_zeros_in_grad is not None:
@@ -861,10 +867,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         args.gigaflos_no_embeds += (6 * new_samples * args.seq_length * get_parameters_in_billions(model, exclude_embeddings=True))

         # Logging.
-        if args.deepspeed:
-            loss_scale = model[0].optimizer.cur_scale
-        else:
-            loss_scale = optimizer.get_loss_scale().item()
+        loss_scale = None
+        if args.fp16:
+            if args.deepspeed:
+                loss_scale = model[0].optimizer.cur_scale
+            else:
+                loss_scale = optimizer.get_loss_scale().item()
         params_norm = None
         if args.log_params_norm:
             params_norm = calc_params_l2_norm(model)

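Taken together, the training.py changes make the loss scale an fp16-only concept: it is queried and logged only when args.fp16 is set, since bf16 training uses no dynamic loss scaler. Below is a condensed sketch of the resulting logic; the helper names are illustrative and do not exist in the file, but the attribute and method accesses mirror the diff above.

def _current_loss_scale(args, model, optimizer):
    """Return the current loss scale, or None for bf16/fp32 runs (illustrative helper)."""
    loss_scale = None
    if args.fp16:
        if args.deepspeed:
            # The DeepSpeed engine owns the dynamic loss scaler.
            loss_scale = model[0].optimizer.cur_scale
        else:
            # Megatron's own mixed-precision optimizer.
            loss_scale = optimizer.get_loss_scale().item()
    return loss_scale

def _loss_scale_log_fragment(args, loss_scale):
    # Only fp16 runs append a loss-scale field to the log string; bf16 runs skip it.
    if args.fp16 and loss_scale is not None:
        return ' loss scale: {:.1f} |'.format(loss_scale)
    return ''

The same guard drives the tensorboard side: args.log_loss_scale_to_tensorboard now only takes effect when combined with args.fp16.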
run.sh

Lines changed: 25 additions & 15 deletions
@@ -3,7 +3,8 @@

 DIR=`pwd`
 DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
-mkdir -p $DIR/logs
+#mkdir -p $DIR/logs
+#mkdir -p /tmp/logs


 #DATASET_1="<PATH TO THE FIRST DATASET>"
@@ -19,7 +20,8 @@ MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt

 script_path=$(realpath $0)
 script_dir=$(dirname $script_path)
-CONFIG_JSON="$script_dir/ds_config.json"
+#CONFIG_JSON="$script_dir/ds_config.json"
+CONFIG_JSON="/tmp/ds_config.json"

 USE_DEEPSPEED=1
 ZERO_STAGE=0
@@ -35,16 +37,20 @@ ZERO_STAGE=0
 #WORKER_STR="-i worker-0"


-# 52B
-TP=4
-PP=16
-HIDDEN=8192
-LAYERS=64
+TP=1
+PP=2
+HIDDEN=1024
+LAYERS=24
 SEQ=1024
-GLOBAL_BATCH=1024
+GLOBAL_BATCH=2
 WORKER_STR=""

-MICRO_BATCH=4
+MICRO_BATCH=1
+
+DTYPE="bf16"
+
+LOG_DIR="/tmp/tensorboard/tp${TP}_pp${PP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_${DTYPE}_fix3"
+mkdir -p $LOG_DIR

 while [[ $# -gt 0 ]]
 do
@@ -89,15 +95,17 @@ options=" \
 --data-path ${DATASET} \
 --vocab-file ${VOCAB_PATH} \
 --merge-file ${MERGE_PATH} \
---save-interval 1000 \
+--save-interval 10000 \
 --split 98,2,0 \
 --clip-grad 1.0 \
 --weight-decay 0.1 \
 --adam-beta1 0.9 \
 --adam-beta2 0.95 \
 --init-method-std 0.006 \
---fp16 \
---checkpoint-activations
+--${DTYPE} \
+--checkpoint-activations \
+--exit-interval 10000 \
+--tensorboard-dir $LOG_DIR
 "


@@ -122,11 +130,12 @@ cat <<EOT > $CONFIG_JSON
     "stage": $ZERO_STAGE
   },

-  "gradient_clipping": 1.0,
-  "prescale_gradients": true,
+  "bf16": {
+    "enabled": true
+  },

   "fp16": {
-    "enabled": true,
+    "enabled": false,
     "loss_scale": 0,
     "loss_scale_window": 500,
     "hysteresis": 2,
@@ -138,6 +147,7 @@ cat <<EOT > $CONFIG_JSON
 }
 EOT

+WORKER_STR="-i worker-0:0,1"
 #run_cmd="deepspeed -i worker-0:0,1,2,3 ${DIR}/pretrain_gpt.py $@ ${options}"
 #run_cmd="deepspeed -i worker-0 ${DIR}/pretrain_gpt.py $@ ${options}"
 run_cmd="deepspeed $WORKER_STR ${DIR}/pretrain_gpt.py $@ ${options}"

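run.sh now drives the precision choice through a single DTYPE variable: Megatron receives --${DTYPE} while the generated ds_config.json enables the matching DeepSpeed section (bf16 on, fp16 off). The script itself writes the config with a bash heredoc; the Python sketch below builds an equivalent dictionary with the dtype switch made explicit. The function name and the standalone writing step are illustrative assumptions, not part of the commit.

import json

def make_ds_config(global_batch, micro_batch, zero_stage, dtype="bf16"):
    """Sketch of the ds_config generated by run.sh, keyed off the dtype choice."""
    return {
        "train_batch_size": global_batch,
        "train_micro_batch_size_per_gpu": micro_batch,
        "steps_per_print": 1,
        "zero_optimization": {"stage": zero_stage},
        # Exactly one of the two precision sections is enabled, matching --${DTYPE}.
        "bf16": {"enabled": dtype == "bf16"},
        "fp16": {
            "enabled": dtype == "fp16",
            "loss_scale": 0,          # 0 selects dynamic loss scaling in DeepSpeed
            "loss_scale_window": 500,
            "hysteresis": 2,
            "min_loss_scale": 1,
            "initial_scale_power": 12,
        },
        "wall_clock_breakdown": True,
    }

if __name__ == "__main__":
    # Write the config to the same temporary path the script uses.
    with open("/tmp/ds_config.json", "w") as f:
        json.dump(make_ds_config(2, 1, 0, "bf16"), f, indent=2)

Keeping the Megatron flag and the DeepSpeed config derived from one variable is what the new asserts in training.py protect: if the two ever disagree, setup_model_and_optimizer now fails fast.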
run_bf16.sh

Lines changed: 164 additions & 0 deletions
New file (all 164 lines added):

#!/bin/bash


DIR=`pwd`
DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
#mkdir -p $DIR/logs
#mkdir -p /tmp/logs


#DATASET_1="<PATH TO THE FIRST DATASET>"
#DATASET_2="<PATH TO THE SECOND DATASET>"
#DATASET_3="<PATH TO THE THIRD DATASET>"
#DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}"

BASE_DATA_PATH=/data/Megatron-LM/data
DATASET=${BASE_DATA_PATH}/indexed_datasets/megatron
VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json
MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt


script_path=$(realpath $0)
script_dir=$(dirname $script_path)
#CONFIG_JSON="$script_dir/ds_config.json"
CONFIG_JSON="/tmp/ds_config.json"

USE_DEEPSPEED=1
ZERO_STAGE=0


# Debug
#TP=4
#PP=4
#LAYERS=8
#HIDDEN=512
#SEQ=1024
#GLOBAL_BATCH=128
#WORKER_STR="-i worker-0"


TP=1
PP=1
DP=2
WORLD_SIZE=$((TP*PP*DP))
HIDDEN=1024
LAYERS=24
SEQ=1024
GLOBAL_BATCH=1
WORKER_STR=""

MICRO_BATCH=1

LR=6.0e-4
MIN_LR=6.0e-5
DTYPE="bf16"
EXP_DIR=${HOME}/experiments/results/bf16
LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_fix3"
mkdir -p $LOG_DIR

while [[ $# -gt 0 ]]
do
key="$1"
case $key in
    --no-deepspeed)
    USE_DEEPSPEED=0;
    shift
    ;;
    -z|--zero-stage)
    ZERO_STAGE=$2;
    shift
    ;;
    *)
    echo "Unknown argument(s)"
    usage
    exit 1
    shift
    ;;
esac
done


options=" \
    --tensor-model-parallel-size $TP \
    --pipeline-model-parallel-size $PP \
    --num-layers $LAYERS \
    --hidden-size $HIDDEN \
    --num-attention-heads 32 \
    --seq-length $SEQ \
    --loss-scale 12 \
    --max-position-embeddings $SEQ \
    --micro-batch-size $MICRO_BATCH \
    --global-batch-size $GLOBAL_BATCH \
    --train-iters 1000 \
    --lr $LR \
    --min-lr $MIN_LR \
    --lr-decay-style cosine \
    --log-interval 1 \
    --eval-iters 40 \
    --eval-interval 10 \
    --data-path ${DATASET} \
    --vocab-file ${VOCAB_PATH} \
    --merge-file ${MERGE_PATH} \
    --save-interval 10000 \
    --split 98,2,0 \
    --clip-grad 1.0 \
    --weight-decay 0.1 \
    --adam-beta1 0.9 \
    --adam-beta2 0.95 \
    --init-method-std 0.006 \
    --${DTYPE} \
    --checkpoint-activations \
    --exit-interval 10000 \
    --tensorboard-dir $LOG_DIR
    "


if [[ ${USE_DEEPSPEED} -eq 1 ]]; then
    echo "Using DeepSpeed"
    options="${options} \
        --deepspeed \
        --deepspeed_config=${CONFIG_JSON} \
        --zero-stage=${ZERO_STAGE} \
        --deepspeed-activation-checkpointing \
    "
fi


cat <<EOT > $CONFIG_JSON
{
  "train_batch_size" : $GLOBAL_BATCH,
  "train_micro_batch_size_per_gpu": $MICRO_BATCH,
  "steps_per_print": 1,

  "zero_optimization": {
    "stage": $ZERO_STAGE
  },

  "bf16": {
    "enabled": true
  },

  "fp16": {
    "enabled": false,
    "loss_scale": 0,
    "loss_scale_window": 500,
    "hysteresis": 2,
    "min_loss_scale": 1,
    "initial_scale_power": 12
  },

  "wall_clock_breakdown" : true
}
EOT

WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE"
#WORKER_STR="-i worker-0:0,1,2,3"
#run_cmd="deepspeed -i worker-0:0,1,2,3 ${DIR}/pretrain_gpt.py $@ ${options}"
#run_cmd="deepspeed -i worker-0 ${DIR}/pretrain_gpt.py $@ ${options}"
run_cmd="deepspeed --master_port 29700 $WORKER_STR ${DIR}/pretrain_gpt.py $@ ${options}"


echo ${run_cmd}
eval ${run_cmd}

set +x

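run_bf16.sh sizes the job from the parallelism degrees: WORLD_SIZE=$((TP*PP*DP)) GPUs are requested from the deepspeed launcher via --num_gpus, and DeepSpeed separately checks that train_batch_size equals micro-batch per GPU times gradient-accumulation steps times the data-parallel size. The Python sketch below just restates that arithmetic; the function names and the example batch numbers are illustrative, not taken from the script.

def world_size(tp: int, pp: int, dp: int) -> int:
    # run_bf16.sh: TP=1, PP=1, DP=2  ->  --num_gpus 2
    return tp * pp * dp

def grad_accum_steps(global_batch: int, micro_batch: int, dp: int) -> int:
    # Solve the DeepSpeed batch-size identity for the accumulation steps.
    assert global_batch % (micro_batch * dp) == 0, \
        "global batch must be a multiple of micro_batch * data-parallel size"
    return global_batch // (micro_batch * dp)

print(world_size(1, 1, 2))           # 2 GPUs handed to the launcher
print(grad_accum_steps(16, 1, 2))    # hypothetical global batch of 16 -> 8 accumulation steps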