Commit 44b6637

Resume the fine-tuning process from the previous PEFT checkpoint folder (meta-llama#531)
2 parents cf29a56 + f1d90d0 commit 44b6637

5 files changed: 22 additions, 10 deletions

docs/multi_gpu.md

Lines changed: 3 additions & 2 deletions
@@ -34,7 +34,7 @@ The args used in the command above are:

 * `--use_peft` boolean flag to enable PEFT methods in the script

-* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`, `prefix`.
+* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`.

 We use `torchrun` here to spawn multiple processes for FSDP.

@@ -138,8 +138,9 @@ It lets us specify the training settings for everything from `model_name` to `da
 mixed_precision: bool=True
 val_batch_size: int=1
 dataset = "samsum_dataset"
-peft_method: str = "lora" # None,llama_adapter, prefix
+peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
 use_peft: bool=False
+from_peft_checkpoint: str="" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
 output_dir: str = "PATH/to/save/PEFT/model"
 freeze_layers: bool = False
 num_freeze_layers: int = 1

docs/single_gpu.md

Lines changed: 4 additions & 2 deletions
@@ -27,7 +27,7 @@ The args used in the command above are:

 * `--use_peft` boolean flag to enable PEFT methods in the script

-* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`, `prefix`.
+* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`.

 * `--quantization` boolean flag to enable int8 quantization

@@ -94,8 +94,9 @@ It let us specify the training settings, everything from `model_name` to `datase
 mixed_precision: bool=True
 val_batch_size: int=1
 dataset = "samsum_dataset"
-peft_method: str = "lora" # None,llama_adapter, prefix
+peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
 use_peft: bool=False
+from_peft_checkpoint: str="" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
 output_dir: str = "PATH/to/save/PEFT/model"
 freeze_layers: bool = False
 num_freeze_layers: int = 1
@@ -112,6 +113,7 @@ It let us specify the training settings, everything from `model_name` to `datase
 flop_counter_start: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
 use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
 profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
+
 ```

 * [Datasets config file](../src/llama_recipes/configs/datasets.py) provides the available options for datasets.

recipes/finetuning/README.md

Lines changed: 3 additions & 1 deletion
@@ -48,8 +48,9 @@ It lets us specify the training settings for everything from `model_name` to `da
 mixed_precision: bool=True
 val_batch_size: int=1
 dataset = "samsum_dataset"
-peft_method: str = "lora" # None,llama_adapter, prefix
+peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
 use_peft: bool=False
+from_peft_checkpoint: str="" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
 output_dir: str = "PATH/to/save/PEFT/model"
 freeze_layers: bool = False
 num_freeze_layers: int = 1
@@ -66,6 +67,7 @@ It lets us specify the training settings for everything from `model_name` to `da
 flop_counter_start: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
 use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
 profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
+
 ```

 * [Datasets config file](../../src/llama_recipes/configs/datasets.py) provides the available options for datasets.

src/llama_recipes/configs/training.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ class train_config:
     dataset = "samsum_dataset"
     peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
     use_peft: bool=False
+    from_peft_checkpoint: str="" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
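
The comment on `from_peft_checkpoint` implies a simple rule: the field only matters when `use_peft` is true, and an empty string means "start fresh". A minimal sketch of that rule follows. `TrainConfigSketch` and `peft_start_mode` are hypothetical names used for illustration only; they are not part of llama-recipes, and the dataclass keeps just the PEFT-related fields shown in the diff.

```python
from dataclasses import dataclass


@dataclass
class TrainConfigSketch:
    """Hypothetical, trimmed stand-in for train_config (PEFT fields only)."""
    peft_method: str = "lora"
    use_peft: bool = False
    from_peft_checkpoint: str = ""  # empty string means "no checkpoint to resume from"
    output_dir: str = "PATH/to/save/PEFT/model"


def peft_start_mode(cfg: TrainConfigSketch) -> str:
    """Classify how fine-tuning would start under the rule in the field comment."""
    if not cfg.use_peft:
        # Without use_peft the checkpoint path is ignored entirely.
        return "no_peft"
    if cfg.from_peft_checkpoint:
        # Non-empty path plus use_peft=True: resume from the saved adapter.
        return "resume"
    # use_peft=True with an empty path: build a new adapter on the base model.
    return "fresh"


if __name__ == "__main__":
    cfg = TrainConfigSketch(use_peft=True, from_peft_checkpoint="PATH/to/previous/PEFT/checkpoint")
    print(peft_start_mode(cfg))
```

Keeping the default as an empty string rather than `None` matches the diff and lets the truthiness check `if train_config.from_peft_checkpoint:` cover the "unset" case.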

src/llama_recipes/finetuning.py

Lines changed: 11 additions & 5 deletions
@@ -8,7 +8,7 @@
 import random
 import torch
 import torch.optim as optim
-from peft import get_peft_model, prepare_model_for_kbit_training
+from peft import get_peft_model, prepare_model_for_kbit_training, PeftModel
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
@@ -134,7 +134,7 @@ def main(**kwargs):
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id

-    # If there is a mismatch between tokenizer vocab size and embedding matrix,
+    # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
         print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
@@ -151,11 +151,17 @@ def main(**kwargs):
         model.to(torch.bfloat16)

     if train_config.use_peft:
-        peft_config = generate_peft_config(train_config, kwargs)
-        model = get_peft_model(model, peft_config)
-        model.print_trainable_parameters()
+        # Load the pre-trained peft model checkpoint and setup its configuration
+        if train_config.from_peft_checkpoint:
+            model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
+            peft_config = model.peft_config
+        # Generate the peft config and start fine-tuning from original model
+        else:
+            peft_config = generate_peft_config(train_config, kwargs)
+            model = get_peft_model(model, peft_config)
         if wandb_run:
             wandb_run.config.update(peft_config)
+        model.print_trainable_parameters()


     hsdp_device_mesh = None
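
The branch added to `main` can be sketched as a standalone helper. This is an illustration under stated assumptions, not the code from the commit: `attach_peft`, `make_peft_config`, `load_adapter`, and `wrap_model` are hypothetical names. The two loader arguments default to `PeftModel.from_pretrained` and `get_peft_model` and exist only so the branching can be exercised without loading a real model. The sketch also assumes that `peft_config` on a `PeftModel` is an attribute holding a dict keyed by adapter name, so it is read without parentheses; check this against the installed PEFT version.

```python
from typing import Any, Callable, Optional, Tuple


def attach_peft(
    model: Any,
    from_peft_checkpoint: str,
    make_peft_config: Callable[[], Any],
    load_adapter: Optional[Callable[..., Any]] = None,
    wrap_model: Optional[Callable[[Any, Any], Any]] = None,
) -> Tuple[Any, Any]:
    """Return (peft_model, peft_config), resuming from a checkpoint folder if given.

    make_peft_config is a hypothetical zero-argument callable standing in for
    generate_peft_config(train_config, kwargs).
    """
    if load_adapter is None or wrap_model is None:
        # Imported lazily so the sketch can be loaded without peft installed.
        from peft import PeftModel, get_peft_model

        load_adapter = load_adapter or PeftModel.from_pretrained
        wrap_model = wrap_model or get_peft_model

    if from_peft_checkpoint:
        # is_trainable=True loads the saved adapter in trainable mode,
        # which is what allows fine-tuning to continue.
        model = load_adapter(model, from_peft_checkpoint, is_trainable=True)
        # Assumption: peft_config is a dict attribute (adapter name -> config).
        peft_config = model.peft_config
    else:
        # No checkpoint: build a new config and wrap the base model.
        peft_config = make_peft_config()
        model = wrap_model(model, peft_config)
    return model, peft_config
```

A caller would follow this with `model.print_trainable_parameters()`, as the diff does after both branches, to confirm that the adapter weights are trainable in either case.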
