Skip to content

Commit 480c4f2

Browse files
committed
Resume fine-tuning from the path of a previous PEFT checkpoint folder
1 parent 182885e commit 480c4f2

File tree

2 files changed

+12 -6 lines changed (12 additions, 6 deletions)

2 files changed

+12
-6
lines changed

src/llama_recipes/configs/training.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class train_config:
3131
dataset = "samsum_dataset"
3232
peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
3333
use_peft: bool=False
34+
from_peft_checkpoint: str="" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
3435
output_dir: str = "PATH/to/save/PEFT/model"
3536
freeze_layers: bool = False
3637
num_freeze_layers: int = 1

src/llama_recipes/finetuning.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import random
99
import torch
1010
import torch.optim as optim
11-
from peft import get_peft_model, prepare_model_for_kbit_training
11+
from peft import get_peft_model, prepare_model_for_kbit_training, PeftModel
1212
from torch.distributed.fsdp import (
1313
FullyShardedDataParallel as FSDP,
1414
ShardingStrategy
@@ -134,7 +134,7 @@ def main(**kwargs):
134134
tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
135135
tokenizer.pad_token_id = tokenizer.eos_token_id
136136

137-
# If there is a mismatch between tokenizer vocab size and embedding matrix,
137+
# If there is a mismatch between tokenizer vocab size and embedding matrix,
138138
# throw a warning and then expand the embedding matrix
139139
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
140140
print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
@@ -151,11 +151,16 @@ def main(**kwargs):
151151
model.to(torch.bfloat16)
152152

153153
if train_config.use_peft:
154-
peft_config = generate_peft_config(train_config, kwargs)
155-
model = get_peft_model(model, peft_config)
154+
# Load the pre-trained peft model checkpoint and setup its configuration
155+
if train_config.from_peft_checkpoint:
156+
model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
157+
# Generate the peft config and start fine-tuning from original model
158+
else:
159+
peft_config = generate_peft_config(train_config, kwargs)
160+
model = get_peft_model(model, peft_config)
161+
if wandb_run:
162+
wandb_run.config.update(peft_config)
156163
model.print_trainable_parameters()
157-
if wandb_run:
158-
wandb_run.config.update(peft_config)
159164

160165

161166
hsdp_device_mesh = None

0 commit comments

Comments
 (0)