Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion F2LLM/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class Args:
checkpointing_steps: int = 100
validation_steps: int = 100
# just placeholder, for logging purpose
num_processes: int=0
num_processes: int = 0
gradient_accumulation_steps: int = 1

def dict(self):
    """Return this dataclass's fields as a plain dict (via dataclasses.asdict)."""
    return asdict(self)
Expand Down
3 changes: 2 additions & 1 deletion F2LLM/configs/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@
"warmup_steps": 500,
"train_epochs": 2,
"log_interval": 100,
"num_hard_neg": 7
"num_hard_neg": 7,
"gradient_accumulation_steps": 1
}
4 changes: 2 additions & 2 deletions F2LLM/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __iter__(self):
# determine training steps
override_train_step = False
if args.train_steps < 0:
args.train_steps = sum(len(v) for v in train_loaders.values()) * args.train_epochs
args.train_steps = int(sum(len(v) for v in train_loaders.values()) * args.train_epochs / args.gradient_accumulation_steps)
override_train_step = True

accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************")
Expand Down Expand Up @@ -145,7 +145,7 @@ def __iter__(self):

# if training on multiple GPUs, length of dataloader would have changed
if override_train_step:
args.train_steps = len(train_dataloader) * args.train_epochs
args.train_steps = int(len(train_dataloader) * args.train_epochs / args.gradient_accumulation_steps)
accelerator.print(f"******************************** Training step after prepare: {args.train_steps} ********************************")


Expand Down
103 changes: 55 additions & 48 deletions F2LLM/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def accelerate_train(args,
accelerator.print(f" Num train samples = {num_train_samples}")
accelerator.print(f" Num epochs = {args.train_epochs}")
accelerator.print(f" Per device batch size = {args.train_batch_size}")
accelerator.print(f" Global batch size = {args.train_batch_size * accelerator.num_processes}")
accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
accelerator.print(f" Global batch size = {args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps}")
accelerator.print(f" Step per epoch = {len(train_dataloader)}")
accelerator.print(f" Total training steps = {args.train_steps}")
accelerator.print("************************************************************************************************")
Expand All @@ -137,6 +138,8 @@ def accelerate_train(args,
criterion = CrossEntropyLoss(reduction='none')
pbar = tqdm(range(args.train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
mini_batch_steps = 0
gradient_accumulation_steps = args.gradient_accumulation_steps
loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
Expand All @@ -163,57 +166,61 @@ def accelerate_train(args,
else:
loss = 0.0

loss_total = loss + loss_hard
loss_total = (loss + loss_hard) / gradient_accumulation_steps

# backward, optimizer, scheduler
accelerator.backward(loss_total)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
if optimizer.param_groups[0]['lr'] < args.min_lr:
for i in range(len(optimizer.param_groups)):
optimizer.param_groups[i]['lr'] = args.min_lr
mini_batch_steps += 1

# log
completed_steps += 1
if completed_steps % args.log_interval == 0:
pbar.update(args.log_interval)

train_log_dict = {"lr": optimizer.param_groups[0]['lr']}
for k in loss_dict.keys():
count = accelerator.gather(count_dict[k]).sum()
if count > 0:
train_log_dict[f"{k}/training_loss_in_batch"] = accelerator.gather(loss_dict[k]).sum() / count
for k in loss_hard_dict.keys():
count = accelerator.gather(count_hard_dict[k]).sum()
if count > 0:
train_log_dict[f"{k}/training_loss_hard"] = accelerator.gather(loss_hard_dict[k]).sum() / count
train_log_dict['Avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_in_batch')]).mean()
train_log_dict['Avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_hard')]).mean()
train_log_dict['Avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean()
train_log_dict['Avg/clustering/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean()

accelerator.print(f"[Train] Step = {completed_steps}")
if accelerator.is_main_process:
write_tensorboard(summary_writer, train_log_dict, completed_steps)
loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}

# validation
if completed_steps % args.validation_steps == 0:
model.lm.eval()
validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer)
model.lm.train()

# step checkpoint
if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0:
output_dir = os.path.join(args.output_dir, f"step_{completed_steps}")
save_checkpoint(args, accelerator, model, output_dir, lr_scheduler)

if completed_steps >= args.train_steps:
break
if mini_batch_steps % gradient_accumulation_steps == 0:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

if optimizer.param_groups[0]['lr'] < args.min_lr:
for param_group in optimizer.param_groups:
param_group['lr'] = args.min_lr

# log
completed_steps += 1
if completed_steps % args.log_interval == 0:
pbar.update(args.log_interval)

train_log_dict = {"lr": optimizer.param_groups[0]['lr']}
for k in loss_dict.keys():
count = accelerator.gather(count_dict[k]).sum()
if count > 0:
train_log_dict[f"{k}/training_loss_in_batch"] = accelerator.gather(loss_dict[k]).sum() / count
for k in loss_hard_dict.keys():
count = accelerator.gather(count_hard_dict[k]).sum()
if count > 0:
train_log_dict[f"{k}/training_loss_hard"] = accelerator.gather(loss_hard_dict[k]).sum() / count
train_log_dict['Avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_in_batch')]).mean()
train_log_dict['Avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_hard')]).mean()
train_log_dict['Avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean()
train_log_dict['Avg/clustering/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean()

accelerator.print(f"[Train] Step = {completed_steps}")
if accelerator.is_main_process and summary_writer is not None:
write_tensorboard(summary_writer, train_log_dict, completed_steps)
loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}

if completed_steps % args.validation_steps == 0:
model.lm.eval()
validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer)
model.lm.train()

# step checkpoint
if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0:
output_dir = os.path.join(args.output_dir, f"step_{completed_steps}")
save_checkpoint(args, accelerator, model, output_dir, lr_scheduler)
if completed_steps >= args.train_steps:
break
if completed_steps >= args.train_steps:
break

# epoch checkpoint
output_dir = os.path.join(args.output_dir, f"epoch_{epoch+1}")
Expand Down