
Commit 044a072

fixes to the blog (#1485)

1 parent 374397c · commit 044a072

File tree

1 file changed: +15 -9 lines changed


ram-efficient-pytorch-fsdp.md

Lines changed: 15 additions & 9 deletions
@@ -7,6 +7,12 @@ authors:
- user: letwun
- user: philschmid
---
+# Fine-tuning Llama 2 70B using PyTorch FSDP
+
+<!-- {blog_metadata} -->
+<!-- {authors} -->
+
+## Introduction

In this blog post, we will look at how to fine-tune Llama 2 70B using PyTorch FSDP and related best practices. We will be leveraging Hugging Face Transformers, Accelerate and TRL. We will also learn how to use Accelerate with SLURM.
@@ -16,7 +22,7 @@ Fully Sharded Data Parallelism (FSDP) is a paradigm in which the optimizer state

(Source: [link](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/))

-# Hardware Used
+## Hardware Used

Number of nodes: 2. Minimum required is 1.
Number of GPUs per node: 8
@@ -27,7 +33,7 @@ RAM per node: 1TB
CPU cores per node: 96
inter-node connection: Elastic Fabric Adapter

-# Challenges with fine-tuning LLaMa 70B
+## Challenges with fine-tuning LLaMa 70B

We encountered three main challenges when trying to fine-tune LLaMa 70B with FSDP:
@@ -51,13 +57,13 @@ https://github.com/pacman100/DHS-LLM-Workshop/tree/main/chat_assistant/training

5. Dataset: [smangrul/code-chat-assistant-v1](https://huggingface.co/datasets/smangrul/code-chat-assistant-v1) (mix of LIMA+GUANACO with proper formatting in a ready-to-train format)

-## Pre-requisites
+### Pre-requisites

First, follow the steps in [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention) to install Flash Attention V2. Install the latest PyTorch nightlies with CUDA ≥ 11.8. Install the remaining requirements as per DHS-LLM-Workshop/code_assistant/training/requirements.txt. Here, we will be installing 🤗 Accelerate and 🤗 Transformers from the main branch.

-# Fine-Tuning
+## Fine-Tuning

-## Addressing Challenge 1
+### Addressing Challenge 1
PRs [huggingface/transformers#25107](https://github.com/huggingface/transformers/pull/25107) and [huggingface/accelerate#1777](https://github.com/huggingface/accelerate/pull/1777) solve the first challenge and require no code changes on the user's side. They do the following:

1. Create the model with no weights on all ranks (using the `meta` device).
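To make step 1 concrete, here is a minimal, hedged sketch of the idea behind the fix (not the exact code from the PRs above): tensors created on the `meta` device carry shapes and dtypes but no storage, so every rank can build the full architecture without consuming CPU RAM. The checkpoint name is only an assumption for illustration.

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Assumed checkpoint name, for illustration only.
config = AutoConfig.from_pretrained("meta-llama/Llama-2-70b-hf")

# Under the `meta` device, parameters have shapes/dtypes but no real storage,
# so instantiating the model skeleton costs essentially no CPU RAM on any rank.
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

print(next(model.parameters()).device)  # meta
```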
@@ -88,7 +94,7 @@ accelerator.process_index=1 CPU Peak Memory consumed during the loading (max-beg
accelerator.process_index=1 CPU Total Peak Memory consumed during the loading (max): 1506
```

-## Addressing Challenge 2
+### Addressing Challenge 2
It is addressed by choosing the `SHARDED_STATE_DICT` state dict type when creating the FSDP config. `SHARDED_STATE_DICT` saves one shard per GPU separately, which makes it quick to save or to resume training from an intermediate checkpoint. With `FULL_STATE_DICT`, the first process (rank 0) gathers the whole model on CPU and then saves it in the standard format.
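For intuition, here is a hedged sketch of the underlying PyTorch FSDP mechanism using a tiny stand-in module (this is not the exact Accelerate/Trainer code path): under `SHARDED_STATE_DICT`, each rank materializes only its own shard, so no single rank has to assemble the full model in CPU memory.

```python
# Hypothetical sketch; assumes a launch such as: torchrun --nproc_per_node=8 sharded_sd.py
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Tiny stand-in for the 70B model, just to show the state-dict behaviour.
model = FSDP(torch.nn.Linear(1024, 1024).cuda())

# SHARDED_STATE_DICT: each rank returns only its local shard. With
# FULL_STATE_DICT, rank 0 would gather the whole model on CPU before saving.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    shard = model.state_dict()

for name, value in shard.items():
    print(dist.get_rank(), name, type(value).__name__)

dist.destroy_process_group()
```

In the blog's setup, this choice is made through the Accelerate FSDP config rather than by calling FSDP directly.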

Let's create the accelerate config via the command below:
@@ -109,7 +115,7 @@ if trainer.is_fsdp_enabled:
trainer.save_model(script_args.output_dir) # alternatively, trainer.push_to_hub() if the whole ckpt is below 50GB as the LFS limit per file is 50GB
```

-## Addressing Challenge 3
+### Addressing Challenge 3
Flash Attention and gradient checkpointing are required for faster training and for reducing VRAM usage, which makes fine-tuning feasible and saves compute costs. The codebase currently uses monkey patching and the implementation is at [chat_assistant/training/llama_flash_attn_monkey_patch.py](https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/llama_flash_attn_monkey_patch.py).

[FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/pdf/2205.14135.pdf) introduces a way to compute exact attention faster and more memory-efficiently by leveraging knowledge of the memory hierarchy of the underlying hardware/GPUs: the higher the bandwidth/speed of a memory tier, the smaller its capacity, as it becomes more expensive.
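As a hedged aside (not what the blog's codebase does, which relies on the monkey patch above): recent 🤗 Transformers releases expose Flash Attention 2 natively via the `attn_implementation` argument, so on newer versions enabling both optimizations might look roughly like the sketch below; the checkpoint name is an assumption for illustration.

```python
import torch
from transformers import AutoModelForCausalLM

# Assumed checkpoint name, for illustration only; requires the flash-attn
# package and a sufficiently recent transformers release.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# Gradient checkpointing recomputes activations during the backward pass,
# trading extra compute for a large reduction in VRAM.
model.gradient_checkpointing_enable()
```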
@@ -131,7 +137,7 @@ This is precisely the problem that Flash Attention addresses. The idea is to **r
with IO-Awareness](https://arxiv.org/pdf/2205.14135.pdf).

-# Bringing it all-together
+## Bringing it all-together

To run the training with the `Accelerate` launcher on SLURM, refer to this gist: [launch.slurm](https://gist.github.com/pacman100/1cb1f17b2f1b3139a63b764263e70b25). Below is an equivalent command showcasing how to use the `Accelerate` launcher to run the training. Notice that we are overriding the `main_process_ip`, `main_process_port`, `machine_rank`, `num_processes` and `num_machines` values from `fsdp_config.yaml`. Another important point to note here is that the storage is shared across all the nodes.
@@ -214,5 +220,5 @@ The whole conversation is formatted as below:
<|system|> system message <|endoftext|> <|prompter|> Q1 <|endoftext|> <|assistant|> A1 <|endoftext|> ...
```

-# Conclusion
+## Conclusion
We successfully fine-tuned a 70B Llama model using PyTorch FSDP in a multi-node, multi-GPU setting while addressing various challenges. We saw how 🤗 Transformers and 🤗 Accelerate now support an efficient way of initializing large models under FSDP that avoids running out of CPU RAM. This was followed by recommended practices for saving/loading intermediate checkpoints and for saving the final model in a readily usable form. To enable faster training and reduce GPU memory usage, we outlined the importance of Flash Attention and gradient checkpointing. Overall, we can see how a simple config using 🤗 Accelerate enables fine-tuning of such large models in a multi-node, multi-GPU setting.
