From d4694d184b0cf926dfe873c705ac119f68f8abc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=97=E7=A6=8F?= Date: Mon, 24 Nov 2025 22:28:44 +0800 Subject: [PATCH] support gradient accumulation --- F2LLM/arguments.py | 3 +- F2LLM/configs/config.json | 3 +- F2LLM/run.py | 4 +- F2LLM/utils.py | 103 ++++++++++++++++++++------------------ 4 files changed, 61 insertions(+), 52 deletions(-) diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py index b967c8f..67ccf15 100644 --- a/F2LLM/arguments.py +++ b/F2LLM/arguments.py @@ -28,7 +28,8 @@ class Args: checkpointing_steps: int = 100 validation_steps: int = 100 # just placeholder, for logging purpose - num_processes: int=0 + num_processes: int = 0 + gradient_accumulation_steps: int = 1 def dict(self): return asdict(self) diff --git a/F2LLM/configs/config.json b/F2LLM/configs/config.json index 2ac3708..7b8505b 100644 --- a/F2LLM/configs/config.json +++ b/F2LLM/configs/config.json @@ -15,5 +15,6 @@ "warmup_steps": 500, "train_epochs": 2, "log_interval": 100, - "num_hard_neg": 7 + "num_hard_neg": 7, + "gradient_accumulation_steps": 1 } diff --git a/F2LLM/run.py b/F2LLM/run.py index e40b707..a31f3f3 100644 --- a/F2LLM/run.py +++ b/F2LLM/run.py @@ -115,7 +115,7 @@ def __iter__(self): # determine training steps override_train_step = False if args.train_steps < 0: - args.train_steps = sum(len(v) for v in train_loaders.values()) * args.train_epochs + args.train_steps = int(sum(len(v) for v in train_loaders.values()) * args.train_epochs / args.gradient_accumulation_steps) override_train_step = True accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************") @@ -145,7 +145,7 @@ def __iter__(self): # if training on multiple GPUs, length of dataloader would have changed if override_train_step: - args.train_steps = len(train_dataloader) * args.train_epochs + args.train_steps = int(len(train_dataloader) * args.train_epochs / args.gradient_accumulation_steps) accelerator.print(f"******************************** Training step after prepare: {args.train_steps} ********************************") diff --git a/F2LLM/utils.py b/F2LLM/utils.py index b167d3c..2ee747f 100644 --- a/F2LLM/utils.py +++ b/F2LLM/utils.py @@ -124,7 +124,8 @@ def accelerate_train(args, accelerator.print(f" Num train samples = {num_train_samples}") accelerator.print(f" Num epochs = {args.train_epochs}") accelerator.print(f" Per device batch size = {args.train_batch_size}") - accelerator.print(f" Global batch size = {args.train_batch_size * accelerator.num_processes}") + accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + accelerator.print(f" Global batch size = {args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps}") accelerator.print(f" Step per epoch = {len(train_dataloader)}") accelerator.print(f" Total training steps = {args.train_steps}") accelerator.print("************************************************************************************************") @@ -137,6 +138,8 @@ def accelerate_train(args, criterion = CrossEntropyLoss(reduction='none') pbar = tqdm(range(args.train_steps), disable=not accelerator.is_local_main_process) completed_steps = 0 + mini_batch_steps = 0 + gradient_accumulation_steps = args.gradient_accumulation_steps loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} @@ -163,57 +166,61 @@ def accelerate_train(args, else: loss = 0.0 - loss_total = loss + loss_hard + loss_total = (loss + loss_hard) / gradient_accumulation_steps # backward, optimizer, scheduler accelerator.backward(loss_total) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - if optimizer.param_groups[0]['lr'] < args.min_lr: - for i in range(len(optimizer.param_groups)): - optimizer.param_groups[i]['lr'] = args.min_lr + mini_batch_steps += 1 - # log - completed_steps += 1 - if completed_steps % args.log_interval == 0: - pbar.update(args.log_interval) - - train_log_dict = {"lr": optimizer.param_groups[0]['lr']} - for k in loss_dict.keys(): - count = accelerator.gather(count_dict[k]).sum() - if count > 0: - train_log_dict[f"{k}/training_loss_in_batch"] = accelerator.gather(loss_dict[k]).sum() / count - for k in loss_hard_dict.keys(): - count = accelerator.gather(count_hard_dict[k]).sum() - if count > 0: - train_log_dict[f"{k}/training_loss_hard"] = accelerator.gather(loss_hard_dict[k]).sum() / count - train_log_dict['Avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_in_batch')]).mean() - train_log_dict['Avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_hard')]).mean() - train_log_dict['Avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean() - train_log_dict['Avg/clustering/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean() - - accelerator.print(f"[Train] Step = {completed_steps}") - if accelerator.is_main_process: - write_tensorboard(summary_writer, train_log_dict, completed_steps) - loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} - loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} - count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} - count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} - - # validation - if completed_steps % args.validation_steps == 0: - model.lm.eval() - validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer) - model.lm.train() - - # step checkpoint - if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0: - output_dir = os.path.join(args.output_dir, f"step_{completed_steps}") - save_checkpoint(args, accelerator, model, output_dir, lr_scheduler) - - if completed_steps >= args.train_steps: - break + if mini_batch_steps % gradient_accumulation_steps == 0: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if optimizer.param_groups[0]['lr'] < args.min_lr: + for param_group in optimizer.param_groups: + param_group['lr'] = args.min_lr + + # log + completed_steps += 1 + if completed_steps % args.log_interval == 0: + pbar.update(args.log_interval) + + train_log_dict = {"lr": optimizer.param_groups[0]['lr']} + for k in loss_dict.keys(): + count = accelerator.gather(count_dict[k]).sum() + if count > 0: + train_log_dict[f"{k}/training_loss_in_batch"] = accelerator.gather(loss_dict[k]).sum() / count + for k in loss_hard_dict.keys(): + count = accelerator.gather(count_hard_dict[k]).sum() + if count > 0: + train_log_dict[f"{k}/training_loss_hard"] = accelerator.gather(loss_hard_dict[k]).sum() / count + train_log_dict['Avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_in_batch')]).mean() + train_log_dict['Avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_hard')]).mean() + train_log_dict['Avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean() + train_log_dict['Avg/clustering/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean() + + accelerator.print(f"[Train] Step = {completed_steps}") + if accelerator.is_main_process and summary_writer is not None: + write_tensorboard(summary_writer, train_log_dict, completed_steps) + loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} + loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} + count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} + count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} + + if completed_steps % args.validation_steps == 0: + model.lm.eval() + validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer) + model.lm.train() + + # step checkpoint + if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0: + output_dir = os.path.join(args.output_dir, f"step_{completed_steps}") + save_checkpoint(args, accelerator, model, output_dir, lr_scheduler) + if completed_steps >= args.train_steps: + break + if completed_steps >= args.train_steps: + break # epoch checkpoint output_dir = os.path.join(args.output_dir, f"epoch_{epoch+1}")