
Commit 3ab0ad1

Authored by DanielHesslow, thomasw21, stas00, and Muennighoff
Eval harness (#212)
* Add functionality for running the evaluation harness on single gpu
* Add support for pipelining
* support tensor parallel
* save the results
* Minor cleanup
* Experimental Deepspeed support
* Proper deepspeed integration, now working on combined tp and pp
* Update model loading and clean up code.
* Add some options
* Fix pipelining + fp32 evaluation.
* Remove dummy paths in examples/run_evalharness.sh

  Co-authored-by: Thomas Wang <[email protected]>

* Simplify offline loading with export HF_DATASETS_OFFLINE=1
* Remove accidental copy-paste.
* Experimental deepspeed evaluation path
* make it work with deepspeed; add instructions
* improve
* make adaptive_seq_len work with deepspeed
* move to slurm
* fixes
* cleanup
* add instructions on how to import data into the spreadsheet
* not tracking ppl/em
* add task version
* make compatible with lm-eval@master
* switch to 16gb slurm; simplify; improve instructions
* Deepspeed model loading hack
* Restore correct zero state.
* fix conversion script
* simpler config
* corrections
* add logiqa
* dealing with custom tokenizers
* fix
* Update examples/run_evalharness_deepspeed.md
* check that the checkpoint path is valid
* skip --abort_on_unmet_fused_kernel_constraints during eval
* disable sanity check on layers-2%pp==0
* sort skip_keys
* make the default path unique to avoid overwrite
* Add bootstrap_iters arg
* Explain bootstrap_iters flag
* Intermediate results flag
* Add backup file
* Add arg to reduce bubble for pipeline parallel
* Fix adaptive_seq_len via resetting activation shape
* Extract args.load prior to load_ds_checkpoint_and_setup_megatron
* Parse args prior to loading function to get load_path
* Add run_evalharness-tr11-176b-ml slurm script

Co-authored-by: Daniel Hesslow <[email protected]>
Co-authored-by: Thomas Wang <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
Co-authored-by: Muennighoff <[email protected]>
1 parent 55f8cf8 commit 3ab0ad1

File tree

11 files changed: +962 -14 lines changed


examples/run_evalharness.sh

Lines changed: 34 additions & 0 deletions
CHECKPOINT_PATH=/gpfsscratch/rech/bbv/utw68ny/checkpoints/tr3m-1B3-pile/global_step296023/

PP_SIZE=1
TP_SIZE=1
VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt

export HF_DATASETS_OFFLINE=1

# dummy arguments to make megatron happy
MEGATRON_REQUIRED_ARGS=" \
    --num-layers -1 \
    --hidden-size -1 \
    --num-attention-heads -1 \
    --seq-length -1 \
    --max-position-embeddings -1 \
"

CMD="./tasks/eval_harness/evaluate.py \
    --load $CHECKPOINT_PATH \
    --tensor-model-parallel-size $TP_SIZE \
    --pipeline-model-parallel-size $PP_SIZE \
    --vocab-file $VOCAB_FILE \
    --merge-file $MERGE_FILE \
    --micro-batch-size 64 \
    --adaptive_seq_len \
    --eval_fp32 \
    --task_list hellaswag,mrpc,piqa \
    $MEGATRON_REQUIRED_ARGS \
"

N_GPUS=1
LAUNCHER="deepspeed --num_gpus $N_GPUS"
$LAUNCHER $CMD
examples/run_evalharness_deepspeed.md

Lines changed: 156 additions & 0 deletions
# How to run lm-eval on a Megatron-DeepSpeed checkpoint using the original setup

This particular setup uses the normal deepspeed checkpoint and requires no conversion to Megatron-LM.

This doc assumes usage on JZ, so there are some peculiar requirements in places. Ignore these if you're not running on JZ.

## Prerequisites
1. Install software

On a login console with external network access, get the lm-eval harness (https://github.com/EleutherAI/lm-evaluation-harness) and `best-download==0.0.7`, which is needed to download some tasks:
```
start-prod
pip install best-download==0.0.7
pip install git+https://github.com/EleutherAI/lm-evaluation-harness
```
2. Pre-download needed datasets

Create some symlinks, needed due to lm-harness's issues with the relative position of the data:
```
mkdir data
ln -s `pwd`/data tasks/eval_harness/data
```
Also make sure `data` is not on one of the limited partitions like WORKSF.

Then install the datasets for the tasks:
```
python ./tasks/eval_harness/download.py --task_list arc_challenge,arc_easy,boolq,copa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sciq,sst,triviaqa,webqs,wic,winogrande,wnli,wsc
```
and make sure that `export HF_DATASETS_OFFLINE=1` is set.
If there are things like custom tokenizers, pre-download those too, e.g.:

```
python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('bigscience/oscar_13_languages_alpha_weight')"
```
and make sure that `export TRANSFORMERS_OFFLINE=1` is in the script.

You know there is a custom tokenizer if the training script had something like:

```
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path bigscience/oscar_13_languages_alpha_weight \
```
3. Prepare the slurm script

Prepare the run script, replacing `variant` with a unique identifier for the current eval, so that multiple evals can run in parallel without all logging into the same `results.json` file, e.g. `tr9c-1B3-swiglu`:

```
cp examples/run_evalharness_deepspeed.slurm run_evalharness-variant.slurm
```

Now edit `run_evalharness-variant.slurm`.

Note that the eval code knows to pull the original training args from the checkpoint, so we don't need to pass any of those; we just need to set up the evaluation args.
1. Edit:

```
PP_SIZE=1
TP_SIZE=1
```
to match the eval topology. If the model fits into 1 GPU, then there is nothing to change.

The eval script will automatically reshape the model if it was saved with a different topology.
2. Adjust the following to fit the chosen GPU. As of the last check, for a 1.3B model the settings are one of:
```
EVAL_MICRO_BATCH_SIZE=6  # 16GB GPU 1.3B model
EVAL_MICRO_BATCH_SIZE=12 # 32GB GPU 1.3B model
```

If you get OOM, lower it further.
3. If not using the deepspeed path, disable it by removing:

```
--deepspeed \
--deepspeed_config ds_config.json \
```

If you didn't disable it and the program crashed on checkpoint loading, unable to find some key, disable deepspeed as explained above.
4. Additional flags

- To reduce the number of iterations used for stderr estimation, use e.g. `--bootstrap_iters 2`. This saves 1-2 minutes per dataset (see the sketch after this list for what this estimation does).
- To print intermediate results when running multiple tasks, use `--intermed_results`.
- To reduce the pipeline bubble when setting PP, use the flag `--micro_bs_multiplier`. Reducing `--micro-batch-size` may be needed when increasing the multiplier.
- Running the 176B model with PP=8, `--micro_bs_multiplier 8` and `--micro-batch-size 4` produced the fastest results for PiQA on 1 node, in 2min18s.
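
For intuition, here is a minimal sketch of what bootstrap stderr estimation does; this is illustrative only, not lm-eval's actual implementation. Per-example scores are resampled with replacement, and the spread of the resampled means gives the stderr, so fewer iterations means a faster but noisier estimate:

```
# Illustrative sketch only -- not lm-eval's code. Assumes per-example
# scores (e.g. 0/1 accuracy indicators) are already available.
import random
import statistics

def bootstrap_stderr(scores, iters):
    # resample the scores with replacement `iters` times; the std of the
    # resampled means estimates the standard error of the metric
    means = []
    for _ in range(iters):
        resample = random.choices(scores, k=len(scores))
        means.append(sum(resample) / len(resample))
    return statistics.stdev(means)

# 1000 fake accuracy indicators at ~60% accuracy
scores = [1 if random.random() < 0.6 else 0 for _ in range(1000)]
print(bootstrap_stderr(scores, iters=100))  # more iters -> stabler estimate
```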

## Eval

Currently it takes 2-3 hours to run the 1.3B model on a 32GB GPU, and 6-7h on a 16GB GPU, so a 20h slurm job should be enough.

When ready, launch:
```
sbatch ./run_evalharness-variant.slurm
```

To monitor progress:
```
tail -f $VARIANT-eval-harness.log
```
where the variant is what you set `$VARIANT` to in the slurm script.

The template is set up for 16GB GPUs since they are easier to come by. If you change to 32GB, adjust:
```
#SBATCH --constraint=v100-32g
...
EVAL_MICRO_BATCH_SIZE=12 # 32GB GPU 1.3B model
```

Note that the original ETA at the start of the run can be 10x longer than the actual outcome. For example, it may suggest 18 hours but complete in 2 hours.
## Short eval

If you just want to quickly test that everything can run to the end, edit `tasks/eval_harness/evaluate.py`, e.g. to run only 10 batches:
```
- results = evaluator.evaluate(adaptor, task_dict, False, 0, None)
+ results = evaluator.evaluate(adaptor, task_dict, False, 0, 10)
```

(XXX: this could be a cmd line option so that the code won't need to be modified; see the sketch below)
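
A minimal sketch of what such an option might look like; the flag name `--eval_limit` is hypothetical and not part of the repo:

```
# Hypothetical sketch: expose the per-task example limit as a CLI flag
# instead of editing evaluate.py; the flag name --eval_limit is illustrative.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--eval_limit", type=int, default=None,
                    help="evaluate only this many examples per task (smoke test)")
args = parser.parse_args()

# the parsed value would then replace the hardcoded argument, e.g.:
# results = evaluator.evaluate(adaptor, task_dict, False, 0, args.eval_limit)
print(f"per-task example limit: {args.eval_limit}")
```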

## Import into spreadsheet

https://docs.google.com/spreadsheets/d/1CI8Q9RCblLRzUOPJ6ViqBmo284-8ojluQ-CmaEuhuv0/edit?usp=sharing

Note that the spreadsheet format is quite different, so use this script:
```
./tasks/eval_harness/report-to-csv.py results.json
```
to reformat the json results into csv while changing their shape to match the spreadsheet format (a sketch of the conversion follows below).
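
As a rough illustration only, here is a minimal sketch of such a json-to-csv conversion, assuming lm-eval's usual `{"results": {task: {metric: value}}}` layout; the actual `report-to-csv.py` may differ:

```
# Illustrative sketch only, not the actual report-to-csv.py.
# Assumes results.json has lm-eval's usual layout:
#   {"results": {task_name: {metric_name: value}}}
import csv
import json
import sys

with open(sys.argv[1]) as f:
    data = json.load(f)

with open("results.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["task", "metric", "value"])
    for task, metrics in sorted(data["results"].items()):
        for metric, value in sorted(metrics.items()):
            writer.writerow([task, metric, value])
```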

Since some records might be missing or extraneous, here is the best way to do it:

1. Copy the data from the first 2 columns to some place under the main spreadsheet.

2. Put the pointer in the 3rd column, next to where the first 2 columns were copied.

3. Import `results.csv` using File -> Import -> File ->

   Import location: Replace data at selected cell

4. Now it should be easy to align the new records with the old ones: delete irrelevant records, and Insert -> Cells where data is missing, until the first 2 columns match.

5. Now create 2 columns in the main table on top, and it should be safe to copy-n-paste the 2-column data range, without the task/metrics columns, into the newly created space.
examples/run_evalharness_deepspeed.slurm

Lines changed: 98 additions & 0 deletions
#!/bin/bash
#SBATCH --job-name=eval-harness-deepspeed
#SBATCH --constraint=v100-16g
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1          # crucial - only 1 task per dist per node!
#SBATCH --cpus-per-task=40           # number of cores per task
#SBATCH --hint=nomultithread         # we get physical cores not logical
#SBATCH --gres=gpu:1                 # number of gpus
#SBATCH --time 20:00:00              # maximum execution time (HH:MM:SS)
#SBATCH --output=%x-%j.out           # output file name
#SBATCH --account=six@gpu

set -x -e

source $six_ALL_CCFRWORK/start-prod

echo "START TIME: $(date)"

# a unique identifier for the current eval, so that multiple evals can run in
# parallel without all logging into the same "results.json" file
VARIANT="tr9c-1B3-swiglu"

CHECKPOINT_PATH=/gpfsdsstore/projects/rech/six/commun/checkpoints/tr3m-1B3-emb-norm-pile/global_step296023
MEGATRON_DEEPSPEED_REPO=/gpfsssd/worksf/projects/rech/six/commun/code/eval/Megatron-DeepSpeed

# you want these 2 on JZ, and pre-download/cache any datasets/tokenizers/models,
# but comment these out if you're running on a node with Internet access
export HF_DATASETS_OFFLINE=1
export TRANSFORMERS_OFFLINE=1

cd $MEGATRON_DEEPSPEED_REPO

# eval topology
PP_SIZE=1
TP_SIZE=1

VOCAB_FILE=$MEGATRON_DEEPSPEED_REPO/data/gpt2-vocab.json
MERGE_FILE=$MEGATRON_DEEPSPEED_REPO/data/gpt2-merges.txt
SEQ_LEN=2048

# different from the training MICRO_BATCH_SIZE - no optimizer memory, so we can use a bigger BS
# make it as big as can fit into the gpu w/o OOM, but not too close to 100%

EVAL_MICRO_BATCH_SIZE=6  # 16GB GPU 1.3B model
#EVAL_MICRO_BATCH_SIZE=12 # 32GB GPU 1.3B model

# dummy arguments to make megatron happy
MEGATRON_REQUIRED_ARGS=" \
    --num-layers -1 \
    --hidden-size -1 \
    --num-attention-heads -1 \
    --seq-length -1 \
    --max-position-embeddings -1 \
"

ZERO_STAGE=0

config_json="./ds_config.json"
cat <<EOT > $config_json
{
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": 1,
    "zero_optimization": { "stage": $ZERO_STAGE },
    "fp16": { "enabled": true },
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
EOT

CMD="./tasks/eval_harness/evaluate.py \
    --load $CHECKPOINT_PATH \
    --results_path $VARIANT-results.json \
    --tensor-model-parallel-size $TP_SIZE \
    --pipeline-model-parallel-size $PP_SIZE \
    --vocab-file $VOCAB_FILE \
    --merge-file $MERGE_FILE \
    --micro-batch-size $EVAL_MICRO_BATCH_SIZE \
    --no-load-optim \
    --no-load-rng \
    --inference \
    --deepspeed \
    --deepspeed_config ds_config.json \
    --seq-length $SEQ_LEN \
    --adaptive_seq_len \
    --eval_fp32 \
    --task_list arc_challenge,arc_easy,boolq,copa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sst,webqs,wic,winogrande,wnli,wsc,triviaqa,sciq \
    $MEGATRON_REQUIRED_ARGS \
"

N_GPUS=1
LAUNCHER="deepspeed --num_gpus $N_GPUS"
echo $LAUNCHER $CMD

export PYTHONPATH=$MEGATRON_DEEPSPEED_REPO

$LAUNCHER $CMD 2>&1 | tee $VARIANT-eval-harness.log
