Skip to content

Commit 808a3f7

Browse files
Adding support for FSDP+Qlora. (meta-llama#572)
Co-authored-by: Matthias Reso <[email protected]>
1 parent ba44797 commit 808a3f7

File tree

12 files changed

+104
-62
lines changed

12 files changed

+104
-62
lines changed

.github/scripts/spellcheck_conf/wordlist.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,7 @@ Weaviate
13511351
MediaGen
13521352
SDXL
13531353
SVD
1354+
QLORA
13541355
Agentic
13551356
AutoGen
13561357
DeepLearning
@@ -1399,6 +1400,8 @@ sqlite
13991400
customerservice
14001401
fn
14011402
ExecuTorch
1403+
nf
1404+
quant
14021405
DLAI
14031406
agentic
14041407
containts

docs/multi_gpu.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ torchrun --nnodes 1 --nproc_per_node 8 recipes/quickstart/finetuning/finetuning
5656

5757
```
5858

59+
### Fine-tuning using FSDP + QLORA
60+
61+
This has been tested on 4 H100s GPUs.
62+
63+
```bash
64+
FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 torchrun --nnodes 1 --nproc_per_node 4 finetuning.py --enable_fsdp --quantization int4 --model_name /path_of_model_folder/70B --mixed_precision False --low_cpu_fsdp --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model
65+
```
66+
5967
### Fine-tuning using FSDP on 70B Model
6068

6169
If you are interested in running full parameter fine-tuning on the 70B model, you can enable `low_cpu_fsdp` mode as the following command. This option will load model on rank0 only before moving model to devices to construct FSDP. This can dramatically save cpu memory when loading large models like 70B (on a 8-gpu node, this reduces cpu memory from 2+T to 280G for 70B model). This has been tested with `BF16` on 16xA100, 80GB GPUs.

docs/single_gpu.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ To run the examples, make sure to install the llama-recipes package (See [README
1717

1818
Get access to a machine with one GPU or if using a multi-GPU machine please make sure to only make one of them visible using `export CUDA_VISIBLE_DEVICES=GPU:id` and run the following. It runs by default with `samsum_dataset` for summarization application.
1919

20+
**NOTE** To run the fine-tuning with `QLORA`, make sure to set `--peft_method lora` and `--quantization int4`.
2021

2122
```bash
2223

23-
python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization --use_fp16 --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
24+
python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization 8bit --use_fp16 --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
2425

2526
```
2627
The args used in the command above are:
@@ -51,16 +52,16 @@ to run with each of the datasets set the `dataset` flag in the command as shown
5152
```bash
5253
# grammer_dataset
5354

54-
python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization --dataset grammar_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
55+
python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization 8bit --dataset grammar_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
5556

5657
# alpaca_dataset
5758

58-
python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization --dataset alpaca_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
59+
python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization 8bit --dataset alpaca_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
5960

6061

6162
# samsum_dataset
6263

63-
python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization --dataset samsum_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
64+
python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization 8bit --dataset samsum_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
6465

6566
```
6667

recipes/quickstart/finetuning/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ It lets us specify the training settings for everything from `model_name` to `da
5454
output_dir: str = "PATH/to/save/PEFT/model"
5555
freeze_layers: bool = False
5656
num_freeze_layers: int = 1
57-
quantization: bool = False
57+
quantization: str = None
5858
one_gpu: bool = False
5959
save_model: bool = True
6060
dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
@@ -101,7 +101,7 @@ It lets us specify the training settings for everything from `model_name` to `da
101101
You can enable [W&B](https://wandb.ai/) experiment tracking by using `use_wandb` flag as below. You can change the project name, entity and other `wandb.init` arguments in `wandb_config`.
102102

103103
```bash
104-
python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model --use_wandb
104+
python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization 8bit --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model --use_wandb
105105
```
106106
You'll be able to access a dedicated project or run link on [wandb.ai](https://wandb.ai) and see your dashboard like the one below.
107107
<div style="display: flex;">

recipes/quickstart/finetuning/multigpu_finetuning.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ We will also need 2 packages:
1818
## How to run it
1919
Get access to a machine with multiple GPUs (in this case we tested with 4 A100 and A10s).
2020

21+
### With FSDP + QLORA
22+
23+
This has been tested on 4 H100s GPUs.
24+
25+
```bash
26+
FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 torchrun --nnodes 1 --nproc_per_node 4 finetuning.py --enable_fsdp --quantization int4 --model_name /path_of_model_folder/70B --mixed_precision False --low_cpu_fsdp --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model
27+
```
28+
2129
### With FSDP + PEFT
2230

2331
<details open>

recipes/quickstart/finetuning/singlegpu_finetuning.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@ To run fine-tuning on a single GPU, we will make use of two packages:
1515

1616
## How to run it?
1717

18+
**NOTE** To run the fine-tuning with `QLORA`, make sure to set `--peft_method lora` and `--quantization 4bit --quantization_config.quant_type nf4`.
19+
20+
1821
```bash
19-
python finetuning.py --use_peft --peft_method lora --quantization --use_fp16 --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
22+
FSDP_CPU_RAM_EFFICIENT_LOADING=1 python finetuning.py --use_peft --peft_method lora --quantization 8bit --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
2023
```
2124
The args used in the command above are:
2225

2326
* `--use_peft` boolean flag to enable PEFT methods in the script
2427
* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`, `prefix`.
25-
* `--quantization` boolean flag to enable int8 quantization
28+
* `--quantization` string flag to enable 8bit or 4bit quantization
2629

2730
> [!NOTE]
2831
> In case you are using a multi-GPU machine please make sure to only make one of them visible using `export CUDA_VISIBLE_DEVICES=GPU:id`.
@@ -48,16 +51,16 @@ to run with each of the datasets set the `dataset` flag in the command as shown
4851
```bash
4952
# grammar_dataset
5053

51-
python -m finetuning.py --use_peft --peft_method lora --quantization --dataset grammar_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
54+
python -m finetuning.py --use_peft --peft_method lora --quantization 8bit --dataset grammar_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
5255

5356
# alpaca_dataset
5457

55-
python -m finetuning.py --use_peft --peft_method lora --quantization --dataset alpaca_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
58+
python -m finetuning.py --use_peft --peft_method lora --quantization 8bit --dataset alpaca_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
5659

5760

5861
# samsum_dataset
5962

60-
python -m finetuning.py --use_peft --peft_method lora --quantization --dataset samsum_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
63+
python -m finetuning.py --use_peft --peft_method lora --quantization 8bit --dataset samsum_dataset --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
6164

6265
```
6366

recipes/quickstart/inference/local_inference/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Padding would be required for batch inference. In this this [example](inference.
4646
The inference folder also includes a chat completion example, that adds built-in safety features in fine-tuned models to the prompt tokens. To run the example:
4747

4848
```bash
49-
python chat_completion/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file chat_completion/chats.json --quantization --use_auditnlg
49+
python chat_completion/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file chat_completion/chats.json --quantization 8bit --use_auditnlg
5050

5151
```
5252

@@ -55,7 +55,7 @@ python chat_completion/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --pro
5555
Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up inference when used for batched inputs. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
5656

5757
```bash
58-
python chat_completion/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file chat_completion/chats.json --quantization --use_auditnlg --use_fast_kernels
58+
python chat_completion/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file chat_completion/chats.json --quantization 8bit --use_auditnlg --use_fast_kernels
5959

6060
python inference.py --model_name <training_config.output_dir> --peft_model <training_config.output_dir> --prompt_file <test_prompt_file> --use_auditnlg --use_fast_kernels
6161

recipes/responsible_ai/llama_guard/README.md

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
<!-- markdown-link-check-disable -->
33
Meta Llama Guard is a language model that provides input and output guardrails for LLM inference. For more details and model cards, please visit the main repository for each model, [Meta Llama Guard](https://github.com/meta-llama/PurpleLlama/tree/main/Llama-Guard) and Meta [Llama Guard 2](https://github.com/meta-llama/PurpleLlama/tree/main/Llama-Guard2).
44

5-
This folder contains an example file to run inference with a locally hosted model, either using the Hugging Face Hub or a local path.
5+
This folder contains an example file to run inference with a locally hosted model, either using the Hugging Face Hub or a local path.
66

77
## Requirements
88
1. Access to Llama guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download)
99
2. Llama recipes package and it's dependencies [installed](https://github.com/meta-llama/llama-recipes?tab=readme-ov-file#installing)
1010

1111

1212
## Llama Guard inference script
13-
For testing, you can add User or User/Agent interactions into the prompts list and the run the script to verify the results. When the conversation has one or more Agent responses, it's considered of type agent.
13+
For testing, you can add User or User/Agent interactions into the prompts list and the run the script to verify the results. When the conversation has one or more Agent responses, it's considered of type agent.
1414

1515

1616
```
@@ -66,7 +66,4 @@ In this case, the default categories are applied by the tokenizer, using the `ap
6666

6767
Use this command for testing with a quantized Llama model, modifying the values accordingly:
6868

69-
`python examples/inference.py --model_name <path_to_regular_llama_model> --prompt_file <path_to_prompt_file> --quantization --enable_llamaguard_content_safety`
70-
71-
72-
69+
`python examples/inference.py --model_name <path_to_regular_llama_model> --prompt_file <path_to_prompt_file> --quantization 8bit --enable_llamaguard_content_safety`

src/llama_recipes/configs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
from llama_recipes.configs.fsdp import fsdp_config
66
from llama_recipes.configs.training import train_config
77
from llama_recipes.configs.wandb import wandb_config
8+
from llama_recipes.configs.quantization import quantization_config
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3+
4+
from dataclasses import dataclass
5+
from typing import Optional
6+
import torch
7+
from transformers import BitsAndBytesConfig
8+
9+
@dataclass
10+
class quantization_config:
11+
quant_type: str = "fp4" # "fp4" or "nf4"
12+
compute_dtype: torch.dtype = torch.bfloat16
13+
use_double_quant: bool = False
14+
quant_storage: torch.dtype = torch.bfloat16
15+
16+
def create_bnb_config(self, quantization: str) -> BitsAndBytesConfig:
17+
if quantization not in {"4bit", "8bit"}:
18+
raise ValueError("quantization must be either '4bit' or '8bit'")
19+
20+
if quantization == "4bit":
21+
config_params = {
22+
"bnb_4bit_quant_type": self.quant_type,
23+
"bnb_4bit_compute_dtype": self.compute_dtype,
24+
"bnb_4bit_use_double_quant": self.use_double_quant,
25+
"bnb_4bit_quant_storage": self.quant_storage,
26+
}
27+
28+
return BitsAndBytesConfig(load_in_4bit=True, **config_params)
29+
else:
30+
return BitsAndBytesConfig(load_in_8bit=True)

0 commit comments

Comments
 (0)