F2LLM/arguments.py (1 addition, 0 deletions)

@@ -29,6 +29,7 @@ class Args:
     validation_steps: int = 100
     # just placeholder, for logging purpose
     num_processes: int=0
+    gradient_accumulation_steps: int=4
 
     def dict(self):
         return asdict(self)
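For context, Args is a plain dataclass (the hunk already shows its asdict-based dict() helper), so the new field rides along with the existing ones. A minimal sketch of how the default can be overridden, assuming config keys map one-to-one onto field names (the mapping code itself is not part of this diff):

    from dataclasses import asdict, dataclass

    @dataclass
    class Args:
        # trimmed to the fields visible in this hunk
        validation_steps: int = 100
        num_processes: int = 0
        gradient_accumulation_steps: int = 4

        def dict(self):
            return asdict(self)

    # hypothetical override from a config dict
    args = Args(**{"gradient_accumulation_steps": 8})
    assert args.dict()["gradient_accumulation_steps"] == 8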
F2LLM/configs/accelerate_config.yaml (1 addition, 1 deletion)

@@ -1,7 +1,7 @@
 compute_environment: LOCAL_MACHINE
 debug: false
 deepspeed_config:
-  gradient_accumulation_steps: 1
+  gradient_accumulation_steps: 4
   gradient_clipping: 1.0
   offload_optimizer_device: none
   offload_param_device: none
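Raising gradient_accumulation_steps from 1 to 4 multiplies the effective batch size per optimizer step without raising per-step memory: four micro-batch gradients are summed before one update. A worked example with hypothetical numbers (the repo's actual micro-batch size and process count are not shown in this diff):

    per_device_batch = 8    # hypothetical
    num_processes = 8       # hypothetical
    grad_accum = 4          # the new value in this config

    effective_batch = per_device_batch * grad_accum * num_processes
    print(effective_batch)  # 256 samples now contribute to each optimizer step

When the DeepSpeed plugin is active, this YAML value is the one DeepSpeed itself uses to decide update boundaries, so it should stay consistent with the value the training code assumes.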
F2LLM/configs/config.json (1 addition, 0 deletions)

@@ -14,6 +14,7 @@
"weight_decay": 0.01,
"warmup_steps": 500,
"train_epochs": 2,
"gradient_accumulation_steps": 4,
"log_interval": 100,
"num_hard_neg": 7
}
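The same value is duplicated here in config.json, presumably the file that populates Args at startup (the loader itself is not shown in this diff). A minimal sketch of that wiring, assuming JSON keys match the dataclass fields one-to-one:

    import json

    from arguments import Args  # F2LLM/arguments.py

    with open("F2LLM/configs/config.json") as f:
        cfg = json.load(f)

    args = Args(**cfg)  # raises TypeError if a key has no matching field
    print(args.gradient_accumulation_steps)  # 4

Note that the setting now lives in three places (the dataclass default, the DeepSpeed YAML, and this JSON file) and has to be kept in sync by hand.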
F2LLM/utils.py (55 additions, 24 deletions)

@@ -142,40 +142,71 @@ def accelerate_train(args,
     count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
     count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
 
+    # ----- verify that gradient accumulation takes effect -----
+    # step 1: snapshot a trainable weight before training
+    unwrapped_model = accelerator.unwrap_model(model.lm)
+    if accelerator.is_main_process:
+        try:
+            initial_weight = unwrapped_model.get_parameter("transformer.h.0.attn.c_attn.weight").clone().detach()
+        except AttributeError:
+            for name, param in unwrapped_model.named_parameters():
+                if param.requires_grad:
+                    initial_weight = param.clone().detach()
+                    weight_name = name
+                    break
+
     model.lm.train()
     for epoch in range(args.train_epochs):
-        accelerator.print(f"*************** Starting epoch {epoch+1} ***************")
+        accelerator.print(f"*************** Starting epoch {epoch + 1} ***************")
         train_dataloader.reset_epoch(epoch)
         for batch in train_dataloader:
-            # forward and compute loss
-            outputs = model.forward(batch)
-            # passage features: [bs, 1, d]
-            # hard_neg_features: [bs, num_hard_neg, d]
-
-            loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator)
-            dataset_name = batch['dataset_name']
-            count_hard_dict[dataset_name] += 1
-            loss_hard_dict[dataset_name] += loss_hard.detach().float()
-            if dataset_name in RETRIEVAL_DATASETS:
-                loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator)
-                count_dict[dataset_name] += 1
-                loss_dict[dataset_name] += loss.detach().float()
-            else:
-                loss = 0.0
-
-            loss_total = loss + loss_hard
-
-            # backward, optimizer, scheduler
-            accelerator.backward(loss_total)
-            optimizer.step()
-            lr_scheduler.step()
-            optimizer.zero_grad()
+            with accelerator.accumulate(model.lm):
+                # forward and compute loss
+                outputs = model.forward(batch)
+                # passage features: [bs, 1, d]
+                # hard_neg_features: [bs, num_hard_neg, d]
+
+                loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1),
+                                      outputs['passage_passage_features'].squeeze(1),
+                                      outputs['negative_passage_features'], criterion, accelerator)
+                dataset_name = batch['dataset_name']
+                count_hard_dict[dataset_name] += 1
+                loss_hard_dict[dataset_name] += loss_hard.detach().float()
+                if dataset_name in RETRIEVAL_DATASETS:
+                    loss = inbatch_loss(outputs['query_passage_features'].squeeze(1),
+                                        outputs['passage_passage_features'].squeeze(1), criterion, accelerator)
+                    count_dict[dataset_name] += 1
+                    loss_dict[dataset_name] += loss.detach().float()
+                else:
+                    loss = 0.0
+
+                loss_total = loss + loss_hard
+
+                # backward, optimizer, scheduler
+                accelerator.backward(loss_total)
+                # optimizer.step()
+                lr_scheduler.step()
+                # optimizer.zero_grad()
             if optimizer.param_groups[0]['lr'] < args.min_lr:
                 for i in range(len(optimizer.param_groups)):
                     optimizer.param_groups[i]['lr'] = args.min_lr
 
             # log
             completed_steps += 1
 
+            # step 2: check for a weight change every 2 steps
+            if completed_steps % 2 == 0 and accelerator.is_main_process:
+                unwrapped_model = accelerator.unwrap_model(model.lm)
+                try:
+                    current_weight = unwrapped_model.get_parameter("transformer.h.0.attn.c_attn.weight")
+                except AttributeError:
+                    current_weight = dict(unwrapped_model.named_parameters())[weight_name]
+
+                weight_changed = not torch.equal(initial_weight, current_weight)
+                accelerator.print(f"[DEBUG] Step {completed_steps}: Weight changed = {weight_changed}")
+                if weight_changed:
+                    initial_weight = current_weight.clone().detach()
+
             if completed_steps % args.log_interval == 0:
                 pbar.update(args.log_interval)
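For comparison, the canonical accelerator.accumulate pattern in the Accelerate docs keeps optimizer.step() and optimizer.zero_grad() in the loop: the wrappers returned by accelerator.prepare skip the actual step and zeroing on non-boundary iterations. Under the DeepSpeed plugin (which this repo's accelerate_config.yaml enables), the engine applies the optimizer step inside accelerator.backward, which is presumably why this PR comments those two calls out; it is worth confirming that lr_scheduler.step() then fires at the cadence intended. A self-contained sketch of the standard pattern, with a step-boundary weight probe like the PR's debug check (toy model and data, not F2LLM's real loop):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(gradient_accumulation_steps=4)

    model = torch.nn.Linear(16, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = accelerator.prepare(model, optimizer)

    snapshot = accelerator.unwrap_model(model).weight.detach().clone()
    for step in range(1, 9):
        batch = torch.randn(4, 16, device=accelerator.device)
        with accelerator.accumulate(model):
            loss = model(batch).pow(2).mean()
            accelerator.backward(loss)   # scales the loss by 1/4 internally
            optimizer.step()             # a no-op off accumulation boundaries
            optimizer.zero_grad()        # likewise skipped off boundaries
        weight = accelerator.unwrap_model(model).weight
        changed = not torch.equal(snapshot, weight)
        accelerator.print(f"step {step}: weight changed = {changed}")  # True at 4 and 8
        if changed:
            snapshot = weight.detach().clone()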