ram-efficient-pytorch-fsdp.md (15 additions, 9 deletions)
@@ -7,6 +7,12 @@ authors:
- user: lewtun
- user: philschmid
---
+# Fine-tuning Llama 2 70B using PyTorch FSDP
+
+<!-- {blog_metadata} -->
+<!-- {authors} -->
+
+## Introduction

In this blog post, we will look at how to fine-tune Llama 2 70B using PyTorch FSDP and related best practices. We will be leveraging Hugging Face Transformers, Accelerate and TRL. We will also learn how to use Accelerate with SLURM.

@@ -16,7 +22,7 @@ Fully Sharded Data Parallelism (FSDP) is a paradigm in which the optimizer state
5. Dataset: [smangrul/code-chat-assistant-v1](https://huggingface.co/datasets/smangrul/code-chat-assistant-v1) (mix of LIMA+GUANACO with proper formatting in a ready-to-train format)
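
As a quick aside (not part of the original diff), the dataset above can be pulled and inspected with 🤗 Datasets; the `train` split name below is an assumption about how the dataset is organized:

```python
from datasets import load_dataset

# Load the LIMA + Guanaco mix referenced in the list above.
dataset = load_dataset("smangrul/code-chat-assistant-v1")
print(dataset)              # show the available splits and columns
print(dataset["train"][0])  # peek at one ready-to-train chat example (assumes a `train` split)
```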

-## Pre-requisites
+### Pre-requisites

First, follow these steps to install Flash Attention V2: [Dao-AILab/flash-attention: Fast and memory-efficient exact attention](https://github.com/Dao-AILab/flash-attention). Install the latest nightlies of PyTorch with CUDA ≥ 11.8. Install the remaining requirements as per `DHS-LLM-Workshop/code_assistant/training/requirements.txt`. Here, we will be installing 🤗 Accelerate and 🤗 Transformers from the main branch.
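
Before launching anything heavy, a short sanity check (ours, not from the post) can confirm that the installed stack matches these prerequisites; the expected versions are the ones the post assumes:

```python
# Quick environment check for the prerequisites above.
import torch
import flash_attn

print("torch:", torch.__version__, "| CUDA build:", torch.version.cuda)  # expect CUDA >= 11.8
print("flash-attn:", flash_attn.__version__)                             # expect a 2.x release
assert torch.cuda.is_available(), "a CUDA-capable GPU is required"
```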

-# Fine-Tuning
+## Fine-Tuning

-## Addressing Challenge 1
+### Addressing Challenge 1
PRs [huggingface/transformers#25107](https://github.com/huggingface/transformers/pull/25107) and [huggingface/accelerate#1777](https://github.com/huggingface/accelerate/pull/1777) solve the first challenge and require no code changes on the user side. They do the following (see the sketch below):

1. Create the model with no weights on all ranks (using the `meta` device).
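
To make step 1 concrete, here is a minimal sketch (ours, not the Transformers internals; the checkpoint name is illustrative) of creating a model skeleton on the `meta` device with 🤗 Accelerate:

```python
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-2-70b-hf")
with init_empty_weights():
    # Parameters are created on the `meta` device: they carry shapes and dtypes
    # but no storage, so the 70B skeleton costs essentially no CPU RAM.
    model = AutoModelForCausalLM.from_config(config)

print(next(model.parameters()).device)  # -> meta
```

FSDP then materializes and shards the real weights without every rank first holding a full copy of the model in CPU RAM.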
@@ -88,7 +94,7 @@ accelerator.process_index=1 CPU Peak Memory consumed during the loading (max-beg
accelerator.process_index=1 CPU Total Peak Memory consumed during the loading (max): 1506
```

-## Addressing Challenge 2
+### Addressing Challenge 2
It is addressed by choosing the `SHARDED_STATE_DICT` state dict type when creating the FSDP config. `SHARDED_STATE_DICT` saves a shard per GPU separately, which makes it quick to save or resume training from an intermediate checkpoint. When `FULL_STATE_DICT` is used, the first process (rank 0) gathers the whole model on CPU and then saves it in a standard format.
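
For illustration only, a rough in-code equivalent of that config choice (a sketch meant to run under a distributed launch; exact plugin arguments can vary across Accelerate versions) would be:

```python
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

# With SHARDED_STATE_DICT, each rank saves only its own shard at checkpoint time,
# instead of rank 0 gathering the whole 70B model on CPU as FULL_STATE_DICT does.
fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_type="SHARDED_STATE_DICT")
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```

In the post, the same choice is made declaratively in the `fsdp_config.yaml` generated below.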

Let's create the accelerate config via the command below:
@@ -109,7 +115,7 @@ if trainer.is_fsdp_enabled:
trainer.save_model(script_args.output_dir) # alternatively, trainer.push_to_hub() if the whole ckpt is below 50GB as the LFS limit per file is 50GB
```

-## Addressing Challenge 3
+### Addressing Challenge 3
Flash Attention and gradient checkpointing are required to speed up training and reduce VRAM usage, which makes fine-tuning feasible and saves compute costs. The codebase currently uses monkey patching and the implementation is at [chat_assistant/training/llama_flash_attn_monkey_patch.py](https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/llama_flash_attn_monkey_patch.py).
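
As a sketch of how these two switches show up in a training script (the model identifier and output directory below are illustrative): the monkey patch from the linked file swaps LLaMA's attention forward pass for a Flash Attention implementation, and gradient checkpointing is enabled through the standard Transformers API.

```python
from transformers import AutoModelForCausalLM, TrainingArguments

# The Flash Attention monkey patch from llama_flash_attn_monkey_patch.py is applied
# before training, so LlamaAttention runs the Flash Attention kernel.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-70b-hf")

# Gradient checkpointing recomputes activations during the backward pass instead of
# storing them all, trading some extra compute for a large reduction in VRAM usage.
model.gradient_checkpointing_enable()

# The same switch when configuring the 🤗 Trainer:
training_args = TrainingArguments(output_dir="llama-70b-sft", gradient_checkpointing=True)
```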
[FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/pdf/2205.14135.pdf) introduces a way to compute exact attention while being faster and memory-efficient by leveraging knowledge of the memory hierarchy of the underlying hardware/GPUs: the higher the bandwidth/speed of a memory tier, the smaller its capacity, as it becomes more expensive.
@@ -131,7 +137,7 @@ This is precisely the problem that Flash Attention addresses. The idea is to **r
with IO-Awareness](https://arxiv.org/pdf/2205.14135.pdf).

-# Bringing it all-together
+## Bringing it all together

To run the training using the `Accelerate` launcher with SLURM, refer to this gist: [launch.slurm](https://gist.github.com/pacman100/1cb1f17b2f1b3139a63b764263e70b25). Below is an equivalent command showcasing how to use the `Accelerate` launcher to run the training. Notice that we are overriding the `main_process_ip`, `main_process_port`, `machine_rank`, `num_processes` and `num_machines` values of `fsdp_config.yaml`. Another important point to note here is that the storage is shared across all the nodes.

@@ -214,5 +220,5 @@ The whole conversation is formatted as below:
We successfully fine-tuned a 70B Llama model using PyTorch FSDP in a multi-node, multi-GPU setting while addressing various challenges. We saw how 🤗 Transformers and 🤗 Accelerate now support an efficient way of initializing large models when using FSDP, overcoming the problem of CPU RAM running out of memory. This was followed by recommended practices for saving/loading intermediate checkpoints and for saving the final model in a readily usable way. To enable faster training and reduce GPU memory usage, we outlined the importance of Flash Attention and gradient checkpointing. Overall, we can see how a simple config using 🤗 Accelerate enables fine-tuning of such large models in a multi-node, multi-GPU setting.