Implement vLLM FSDP LoRA hot-swapping integration #10

```diff
@@ -9,4 +9,5 @@ data/
 **/*.pyc
 /.cache
 /.vscode
 /data
+/env
```

@@ -0,0 +1,68 @@

```yaml
model: google/gemma-2b
enable_wandb_logging: False

wandb_config:
  project: vector-lm-verify
  name: benchmark-lora
  # tags: ["20240418-1a-preemption"]

train_parameters:
  output_dir: weights
  max_seq_len: 128
  epochs: 10
  seed: 11

  # Sharding strategy
  sharding_strategy: FULL_SHARD

  # Memory
  use_mp: True
  use_activation_checkpointing: True
  # use_flash_attention is automatically enabled
  # for CUDA capability > 8.0
  use_flash_attention: False
  low_cpu_mem_usage: True

  lora_peft_config:
    task_type: CAUSAL_LM
    inference_mode: False
    r: 8
    lora_alpha: 32
    lora_dropout: 0.1

  # Gradient norm clipping
  max_grad_norm: 1
  gradient_accumulation_steps: 4

  # Optimizer
  optimizer:
    lr: 1.0e-4
    weight_decay: 0.1
    betas: [0.9, 0.95]
    eps: 1.0e-5

  # Scheduler
  lr_scheduler_type: cosine
  warmup_ratio: 0.05

  # Checkpointing
  checkpointing_enabled: False
  logging_steps: 10
  save_frequency: 0.10

  # Sampling during training
  sampler:
    sample_frequency: 8
    output_jsonl_path: data/output-5e-5-2b.jsonl
    vllm_dtype: half
    prompts:
      - "Vector Institute of the"
      - "Vector Institute is located in the city of"
      - "The answer to life the universe and everything is"

dataset:
  ignore_index: -100
  eval_bs: 8
  train_bs: 8
  train_ds: data/processed/vector-west/train
  eval_ds: data/processed/vector-west/test
```
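
The keys under `lora_peft_config` line up with the parameters of `peft.LoraConfig`. As a minimal sketch, assuming the block is forwarded to PEFT as keyword arguments (that wiring is not shown in this diff):

```python
# Sketch only: mirrors the lora_peft_config block above, assuming it is
# passed through to peft.LoraConfig as keyword arguments.
from peft import LoraConfig, TaskType, get_peft_model

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
# model = get_peft_model(base_model, lora_config)  # base_model: the loaded google/gemma-2b
```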

```diff
@@ -51,6 +51,22 @@ Similar to the wandb config above, these keyword parameters are fed directly int
 * `logging_steps`: How often evaluation is run using the evaluation dataset.
 * `save_frequency`: The frequency at which checkpointing occurs. This must be between 0 and 1.
 
+### Sampling during Training
+
+To disable sampling during training, delete the entire "sampler" section.
+
+* `sample_frequency`: Number of training steps between two consecutive sampling steps.
+* `output_jsonl_path`: Optional; write sampled output to the specified JSONL file.
+* `prompts`: YAML list of prompt strings.
+
+Each line of the output JSONL file is a dictionary with the keys:
+
+* `tr_step`: integer, the trainer step at which this line was generated.
+* `prompt`: string.
+* `options`: list of strings, one for each option the sampler produced.
+* `time_taken`: float, number of seconds taken to generate **all** prompts at this step.
+
 ## Dataset
 
 * `ignore_index`: The integer index used to ignore a given token in the loss calculation. Cross-entropy loss by default uses `-100`.
```
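
For illustration, one line of the sampler output JSONL described above could look like the following (all values are made up):

```json
{"tr_step": 80, "prompt": "Vector Institute is located in the city of", "options": ["Toronto, Ontario."], "time_taken": 1.27}
```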

**Collaborator:** Great writeup!

@@ -0,0 +1,26 @@

# Efficient Sampling during training

Some training objectives, notably PPO, require "sampling" from the language model many times during training. The most straightforward approach is to invoke `model.generate` from within the training loop. However, a number of alternative inference solutions, vLLM among them, promise over 10x the sampling throughput (in tokens generated per second) when using a large sampling batch size. If `model.generate` takes up too much of the training time, it may be worthwhile to look into these third-party solutions to speed up sampling.

One main challenge of running these third-party solutions, however, is that most of them assume the weights of the language model are fixed, so there is no straightforward way to update those weights. Usually, updating the weights requires restarting the sampling engine, which can take minutes. At the same time, the performance of PPO and similar techniques relies heavily on the ability to replace the weights efficiently; otherwise, training is no longer on-policy and convergence takes substantially more training steps. To resolve this issue, we implemented techniques to "hot-swap" the model parameters used in the sampling process.
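
As a minimal sketch of the idea (not the exact mechanism used in this repository; `sampling_model` is a hypothetical stand-in for the module the sampling engine serves), hot-swapping amounts to copying the current training weights into the already-running sampling model rather than rebuilding the engine:

```python
# Sketch, assuming an FSDP-wrapped training model and direct access to the
# nn.Module used by the sampling engine ("sampling_model", hypothetical).
import torch
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def hot_swap_weights(fsdp_model: FSDP, sampling_model: torch.nn.Module) -> None:
    """Copy the latest training weights into the live sampling model."""
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        full_state = fsdp_model.state_dict()  # gather the unsharded weights
    # Parameters are replaced in place; no sampling-engine restart required.
    sampling_model.load_state_dict(full_state, strict=False)
```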

Additionally, it is not straightforward to maintain consistently high GPU utilization when combining sampling with training. This repository enables you to make the most of all your GPUs by fitting vLLM and your training loop onto the same set of devices. This way, no GPU sits idle: if a GPU is not running training, it is busy sampling with vLLM. These slides ([link](https://docs.google.com/presentation/d/1FCa5O8RYYkRRCAAcXhqCvomePo5fEfhjQciSteTEJ30/edit?usp=sharing)) provide an overview of the architecture behind this approach.

## Example: Supervised fine-tuning

We provide a basic example that samples from the language model while fine-tuning it with a standard causal language modelling objective. To run the example, uncomment the "sampler" section in your configuration YAML, choose a port for `nccl` coordination, and run the following command (not via torchrun):

```bash
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=19132
python3 examples/llama_example_mp.py \
    --yaml_path configs/config.yaml \
    --world_size 2
```

## Bring your own training loop

While the reference implementation covers only supervised fine-tuning, we provide abstractions that make it easier for you to implement your own training loop, be it PPO RLHF, TWIST, or something else. The goal is to abstract away all of the synchronization logic, so that a training loop built for a single GPU can scale to multiple GPUs on the same server with minimal modification.

To get started, refer to `examples/llama_example.py` and `vectorlm/trainer.py`. Usually, the vLLM engine is accessible only from rank 0, which makes synchronization challenging. When invoked through `llama_example_mp`, the `SamplingEngine` interface in VectorLM enables your training loop to access `vLLM.LLM.generate` from all ranks, returning the same result on every rank. Because the synchronization barriers require all ranks to reach the synchronization point, you must invoke `generate` from all ranks.
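
For instance, a custom training loop would call `generate` collectively at a fixed interval, roughly as in the sketch below (the `trainer` and `sampling_engine` objects are illustrative stand-ins, not the repository's exact API):

```python
# Illustrative sketch of a custom loop that samples every `sample_frequency`
# steps; the trainer and sampling_engine interfaces are hypothetical.
from typing import Iterable, List


def training_loop(
    trainer,            # assumed to expose train_step(batch) -> float
    sampling_engine,    # assumed to expose generate(prompts) -> List[str], collective across ranks
    dataloader: Iterable,
    prompts: List[str],
    sample_frequency: int,
    rank: int,
) -> None:
    for step, batch in enumerate(dataloader):
        loss = trainer.train_step(batch)

        # Every rank must reach this point and call generate(); the call is
        # synchronized and returns identical results on all ranks.
        if step % sample_frequency == 0:
            outputs = sampling_engine.generate(prompts)
            if rank == 0:
                print(f"step {step}: loss={loss:.4f}, sampled {len(outputs)} continuations")
```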

@@ -0,0 +1,26 @@

```bash
#!/bin/bash
#SBATCH --job-name=llama7b-2
#SBATCH --nodes=1
#SBATCH --mem=0
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-gpu=6
#SBATCH --gres=gpu:4
#SBATCH --output=llama-2-7b.%j.out
#SBATCH --error=llama-2-7b.%j.err
#SBATCH --partition=a100
#SBATCH --qos=your_assigned_qos # CHANGE
#SBATCH --open-mode=append
#SBATCH --wait-all-nodes=1
#SBATCH --time=3-00

export NCCL_IB_DISABLE=1  # Our cluster does not have InfiniBand; disable it with this flag.
export NCCL_DEBUG=WARN
export NCCL_DEBUG_SUBSYS=WARN

# export TORCH_DISTRIBUTED_DEBUG=DETAIL  # Uncomment these flags to debug communication.
# export TORCH_CPP_LOG_LEVEL=INFO
export LOGLEVEL=INFO
export PYTHONFAULTHANDLER=1
# export CUDA_LAUNCH_BLOCKING=0

torchrun --nnodes=1 --nproc-per-node=${SLURM_GPUS_ON_NODE} example_lora.py --yaml_path configs/config-lora.yaml
```

@@ -0,0 +1,26 @@

```bash
#!/bin/bash
#SBATCH --job-name=llama7b-2-lora
#SBATCH --nodes=1
#SBATCH --mem=32GB
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-gpu=6
#SBATCH --gres=gpu:1
#SBATCH --output=llama-2-7b-lora.%j.out
#SBATCH --error=llama-2-7b-lora.%j.err
#SBATCH --partition=a100
#SBATCH --qos=your_assigned_qos # CHANGE
#SBATCH --open-mode=append
#SBATCH --wait-all-nodes=1
#SBATCH --time=3-00

export NCCL_IB_DISABLE=1  # Our cluster does not have InfiniBand; disable it with this flag.
export NCCL_DEBUG=WARN
export NCCL_DEBUG_SUBSYS=WARN

# export TORCH_DISTRIBUTED_DEBUG=DETAIL  # Uncomment these flags to debug communication.
# export TORCH_CPP_LOG_LEVEL=INFO
export LOGLEVEL=INFO
export PYTHONFAULTHANDLER=1
# export CUDA_LAUNCH_BLOCKING=0

torchrun --nnodes=1 --nproc-per-node=1 example_lora.py --yaml_path configs/config-lora.yaml
```

**Reviewer:** Is this file required to be a part of the main codebase?

**Reply:** That config file has been included by mistake. I will delete that from version control.