From b1c3ac2f7a7996eb48695b7e86045ae0c5458a4f Mon Sep 17 00:00:00 2001
From: mulin
Date: Mon, 24 Nov 2025 17:39:24 +0800
Subject: [PATCH] Implement gradient accumulation

---
 F2LLM/arguments.py                   |  1 +
 F2LLM/configs/accelerate_config.yaml |  2 +-
 F2LLM/configs/config.json            |  1 +
 F2LLM/utils.py                       | 79 +++++++++++++++++++---------
 4 files changed, 58 insertions(+), 25 deletions(-)

diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py
index b967c8f..74bd98e 100644
--- a/F2LLM/arguments.py
+++ b/F2LLM/arguments.py
@@ -29,6 +29,7 @@ class Args:
     validation_steps: int = 100
     # just placeholder, for logging purpose
     num_processes: int=0
+    gradient_accumulation_steps: int=4
 
     def dict(self):
         return asdict(self)
diff --git a/F2LLM/configs/accelerate_config.yaml b/F2LLM/configs/accelerate_config.yaml
index 5133305..fcfc48d 100644
--- a/F2LLM/configs/accelerate_config.yaml
+++ b/F2LLM/configs/accelerate_config.yaml
@@ -1,7 +1,7 @@
 compute_environment: LOCAL_MACHINE
 debug: false
 deepspeed_config:
-  gradient_accumulation_steps: 1
+  gradient_accumulation_steps: 4
   gradient_clipping: 1.0
   offload_optimizer_device: none
   offload_param_device: none
diff --git a/F2LLM/configs/config.json b/F2LLM/configs/config.json
index 2ac3708..87f52dd 100644
--- a/F2LLM/configs/config.json
+++ b/F2LLM/configs/config.json
@@ -14,6 +14,7 @@
     "weight_decay": 0.01,
     "warmup_steps": 500,
     "train_epochs": 2,
+    "gradient_accumulation_steps": 4,
     "log_interval": 100,
     "num_hard_neg": 7
 }
diff --git a/F2LLM/utils.py b/F2LLM/utils.py
index b167d3c..cacd5a2 100644
--- a/F2LLM/utils.py
+++ b/F2LLM/utils.py
@@ -142,40 +142,71 @@ def accelerate_train(args,
     count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
     count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
 
+    # ----- verify that gradient accumulation takes effect -----
+    # step1: snapshot an initial trainable weight
+    unwrapped_model = accelerator.unwrap_model(model.lm)
+    if accelerator.is_main_process:
+        try:
+            initial_weight = unwrapped_model.get_parameter("transformer.h.0.attn.c_attn.weight").clone().detach()
+        except AttributeError:
+            for name, param in unwrapped_model.named_parameters():
+                if param.requires_grad:
+                    initial_weight = param.clone().detach()
+                    weight_name = name
+                    break
+
     model.lm.train()
     for epoch in range(args.train_epochs):
-        accelerator.print(f"*************** Starting epoch {epoch+1} ***************")
+        accelerator.print(f"*************** Starting epoch {epoch + 1} ***************")
         train_dataloader.reset_epoch(epoch)
         for batch in train_dataloader:
-            # forward and compute loss
-            outputs = model.forward(batch)
-            # passage features: [bs, 1, d]
-            # hard_neg_features: [bs, num_hard_neg, d]
-
-            loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator)
-            dataset_name = batch['dataset_name']
-            count_hard_dict[dataset_name] += 1
-            loss_hard_dict[dataset_name] += loss_hard.detach().float()
-            if dataset_name in RETRIEVAL_DATASETS:
-                loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator)
-                count_dict[dataset_name] += 1
-                loss_dict[dataset_name] += loss.detach().float()
-            else:
-                loss = 0.0
-
-            loss_total = loss + loss_hard
-
-            # backward, optimizer, scheduler
-            accelerator.backward(loss_total)
-            optimizer.step()
+            with accelerator.accumulate(model.lm):
+                # forward and compute loss
+                outputs = model.forward(batch)
+                # passage features: [bs, 1, d]
+                # hard_neg_features: [bs, num_hard_neg, d]
+
+                loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1),
+                                      outputs['passage_passage_features'].squeeze(1),
+                                      outputs['negative_passage_features'], criterion, accelerator)
+                dataset_name = batch['dataset_name']
+                count_hard_dict[dataset_name] += 1
+                loss_hard_dict[dataset_name] += loss_hard.detach().float()
+                if dataset_name in RETRIEVAL_DATASETS:
+                    loss = inbatch_loss(outputs['query_passage_features'].squeeze(1),
+                                        outputs['passage_passage_features'].squeeze(1), criterion, accelerator)
+                    count_dict[dataset_name] += 1
+                    loss_dict[dataset_name] += loss.detach().float()
+                else:
+                    loss = 0.0
+
+                loss_total = loss + loss_hard
+
+                # backward, optimizer, scheduler
+                accelerator.backward(loss_total)
+                # optimizer.step()
             lr_scheduler.step()
-            optimizer.zero_grad()
+            # optimizer.zero_grad()
             if optimizer.param_groups[0]['lr'] < args.min_lr:
                 for i in range(len(optimizer.param_groups)):
                     optimizer.param_groups[i]['lr'] = args.min_lr
-
+            # log
             completed_steps += 1
+
+            # step2: check for a weight change every 2 steps
+            if completed_steps % 2 == 0 and accelerator.is_main_process:
+                unwrapped_model = accelerator.unwrap_model(model.lm)
+                try:
+                    current_weight = unwrapped_model.get_parameter("transformer.h.0.attn.c_attn.weight")
+                except AttributeError:
+                    current_weight = dict(unwrapped_model.named_parameters())[weight_name]
+
+                weight_changed = not torch.equal(initial_weight, current_weight)
+                accelerator.print(f"[DEBUG] Step {completed_steps}: Weight changed = {weight_changed}")
+                if weight_changed:
+                    initial_weight = current_weight.clone().detach()
+
             if completed_steps % args.log_interval == 0:
                 pbar.update(args.log_interval)
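
Reviewer note: two pieces of Accelerate behavior make the utils.py hunk read oddly
at first. In the stock (non-DeepSpeed) accumulation pattern, optimizer.step(),
lr_scheduler.step(), and optimizer.zero_grad() are all called on every micro-batch
inside the accumulate() context, and the wrappers returned by accelerator.prepare()
turn them into no-ops until accelerator.sync_gradients is True. With the DeepSpeed
plugin enabled, as in accelerate_config.yaml here, accelerator.backward() delegates
to the DeepSpeed engine, which applies the real optimizer step itself; that is
presumably why this patch comments out optimizer.step() and optimizer.zero_grad()
rather than keeping them. For comparison, a minimal sketch of the stock pattern,
where build_model, build_optimizer, build_dataloader, build_scheduler, and
compute_loss are hypothetical placeholders, not F2LLM code:

    from accelerate import Accelerator

    # Mirrors the gradient_accumulation_steps=4 added to config.json.
    accelerator = Accelerator(gradient_accumulation_steps=4)
    # build_* are placeholder constructors for illustration only.
    model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        build_model(), build_optimizer(), build_dataloader(), build_scheduler()
    )

    model.train()
    for batch in dataloader:
        # Gradients sync (and the optimizer really steps) only on every 4th
        # micro-batch; in between, step()/zero_grad() are no-ops.
        with accelerator.accumulate(model):
            loss = compute_loss(model, batch)  # placeholder for the task loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

With this in mind, the step1/step2 probe is expected to print an alternating
pattern: with gradient_accumulation_steps=4 and a check every 2 completed
micro-batches, the snapshot should be unchanged mid-window (Weight changed = False)
and should move only after every 4th micro-batch (Weight changed = True).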