diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py
index b967c8f..8b3411d 100644
--- a/F2LLM/arguments.py
+++ b/F2LLM/arguments.py
@@ -28,7 +28,8 @@ class Args:
     checkpointing_steps: int = 100
     validation_steps: int = 100
     # just placeholder, for logging purpose
-    num_processes: int=0
+    num_processes: int = 0
+    finetuning_type: str = "full"
 
     def dict(self):
         return asdict(self)
diff --git a/F2LLM/configs/config.json b/F2LLM/configs/config.json
index 2ac3708..c3642df 100644
--- a/F2LLM/configs/config.json
+++ b/F2LLM/configs/config.json
@@ -15,5 +15,6 @@
     "warmup_steps": 500,
     "train_epochs": 2,
     "log_interval": 100,
-    "num_hard_neg": 7
+    "num_hard_neg": 7,
+    "finetuning_type": "lora"
 }
diff --git a/F2LLM/run.py b/F2LLM/run.py
index e40b707..f5e6c52 100644
--- a/F2LLM/run.py
+++ b/F2LLM/run.py
@@ -14,6 +14,7 @@
 from torch.nn.utils.rnn import pad_sequence
 from torch.optim import AdamW
 from model import F2LLM
+from peft import LoraConfig, get_peft_model, TaskType
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -120,14 +121,35 @@ def __iter__(self):
 accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************")
 
 model = F2LLM(args.model_path, args.max_seq_length, args=args)
+if not hasattr(model.lm, 'prepare_inputs_for_generation'):
+    model.lm.prepare_inputs_for_generation = lambda *args, **kwargs: None
+
 model.lm.gradient_checkpointing_enable()
 
 # set seed again to make sure that different models share the same seed
 set_seed(0)
-optimizer = AdamW(model.lm.parameters(),
-                  weight_decay=args.weight_decay,
-                  lr=args.learning_rate,
-                  betas=(0.9, 0.98))
+if args.finetuning_type == 'lora':
+    lora_config = LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        r=16,
+        lora_alpha=32,
+        lora_dropout=0.1,
+        target_modules=["q_proj", "v_proj"],
+    )
+    model.lm = get_peft_model(model.lm, lora_config)
+    if accelerator.is_main_process:
+        model.lm.print_trainable_parameters()
+    if hasattr(model.lm, "enable_input_require_grads"):
+        model.lm.enable_input_require_grads()
+    optimizer = AdamW(filter(lambda p: p.requires_grad, model.lm.parameters()),
+                      weight_decay=args.weight_decay,
+                      lr=args.learning_rate,
+                      betas=(0.9, 0.98))
+else:
+    optimizer = AdamW(model.lm.parameters(),
+                      weight_decay=args.weight_decay,
+                      lr=args.learning_rate,
+                      betas=(0.9, 0.98))
 
 lr_scheduler = get_scheduler("cosine",
                              optimizer=optimizer,
@@ -150,4 +172,4 @@ def __iter__(self):
 
 accelerate_train(args, accelerator, model,
                  train_dataloader, valid_loaders,
-                 optimizer, lr_scheduler, len(dataset))
\ No newline at end of file
+                 optimizer, lr_scheduler, len(dataset))