Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion F2LLM/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class Args:
checkpointing_steps: int = 100
validation_steps: int = 100
# just placeholder, for logging purpose
num_processes: int=0
num_processes: int = 0
finetuning_type: str = "full"

def dict(self):
    """Return this Args instance's fields as a plain dict (via dataclasses.asdict).

    NOTE(review): the method name shadows the builtin ``dict`` for readers of
    the class body; presumably kept for logging/serialization convenience —
    confirm before renaming.
    """
    return asdict(self)
Expand Down
3 changes: 2 additions & 1 deletion F2LLM/configs/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@
"warmup_steps": 500,
"train_epochs": 2,
"log_interval": 100,
"num_hard_neg": 7
"num_hard_neg": 7,
"finetuning_type": "lora"
}
32 changes: 27 additions & 5 deletions F2LLM/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from model import F2LLM
from peft import LoraConfig, get_peft_model, TaskType

os.environ["TOKENIZERS_PARALLELISM"] = "false"

Expand Down Expand Up @@ -120,14 +121,35 @@ def __iter__(self):

accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************")
model = F2LLM(args.model_path, args.max_seq_length, args=args)
if not hasattr(model.lm, 'prepare_inputs_for_generation'):
model.lm.prepare_inputs_for_generation = lambda *args, **kwargs: None

model.lm.gradient_checkpointing_enable()
# set seed again to make sure that different models share the same seed
set_seed(0)

optimizer = AdamW(model.lm.parameters(),
weight_decay=args.weight_decay,
lr=args.learning_rate,
betas=(0.9, 0.98))
if args.finetuning_type == 'lora':
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=16,
lora_alpha=32,
lora_dropout=0.1,
target_modules=["q_proj", "v_proj"],
)
model.lm = get_peft_model(model.lm, lora_config)
if accelerator.is_main_process:
model.lm.print_trainable_parameters()
if hasattr(model.lm, "enable_input_require_grads"):
model.lm.enable_input_require_grads()
optimizer = AdamW(filter(lambda p: p.requires_grad, model.lm.parameters()),
weight_decay=args.weight_decay,
lr=args.learning_rate,
betas=(0.9, 0.98))
else:
optimizer = AdamW(model.lm.parameters(),
weight_decay=args.weight_decay,
lr=args.learning_rate,
betas=(0.9, 0.98))

lr_scheduler = get_scheduler("cosine",
optimizer=optimizer,
Expand All @@ -150,4 +172,4 @@ def __iter__(self):


accelerate_train(args, accelerator, model, train_dataloader, valid_loaders,
optimizer, lr_scheduler, len(dataset))
optimizer, lr_scheduler, len(dataset))