8
8
import random
9
9
import torch
10
10
import torch .optim as optim
11
- from peft import get_peft_model , prepare_model_for_kbit_training
11
+ from peft import get_peft_model , prepare_model_for_kbit_training , PeftModel
12
12
from torch .distributed .fsdp import (
13
13
FullyShardedDataParallel as FSDP ,
14
14
ShardingStrategy
@@ -134,7 +134,7 @@ def main(**kwargs):
134
134
tokenizer = AutoTokenizer .from_pretrained (train_config .model_name if train_config .tokenizer_name is None else train_config .tokenizer_name )
135
135
tokenizer .pad_token_id = tokenizer .eos_token_id
136
136
137
- # If there is a mismatch between tokenizer vocab size and embedding matrix,
137
+ # If there is a mismatch between tokenizer vocab size and embedding matrix,
138
138
# throw a warning and then expand the embedding matrix
139
139
if len (tokenizer ) > model .get_input_embeddings ().weight .shape [0 ]:
140
140
print ("WARNING: Resizing the embedding matrix to match the tokenizer vocab size." )
@@ -151,11 +151,16 @@ def main(**kwargs):
151
151
model .to (torch .bfloat16 )
152
152
153
153
if train_config .use_peft :
154
- peft_config = generate_peft_config (train_config , kwargs )
155
- model = get_peft_model (model , peft_config )
154
+ # Load the pre-trained peft model checkpoint and setup its configuration
155
+ if train_config .from_peft_checkpoint :
156
+ model = PeftModel .from_pretrained (model , train_config .from_peft_checkpoint , is_trainable = True )
157
+ # Generate the peft config and start fine-tuning from original model
158
+ else :
159
+ peft_config = generate_peft_config (train_config , kwargs )
160
+ model = get_peft_model (model , peft_config )
161
+ if wandb_run :
162
+ wandb_run .config .update (peft_config )
156
163
model .print_trainable_parameters ()
157
- if wandb_run :
158
- wandb_run .config .update (peft_config )
159
164
160
165
161
166
hsdp_device_mesh = None
0 commit comments