7 changes: 5 additions & 2 deletions F2LLM/arguments.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass, asdict
+from dataclasses import dataclass, asdict, field
 import argparse, json


@@ -28,7 +28,10 @@ class Args:
     checkpointing_steps: int = 100
     validation_steps: int = 100
     # just placeholder, for logging purpose
-    num_processes: int=0
+    num_processes: int = 0
+    matryoshka_dims: list = field(
+        default_factory=lambda: [64, 128, 256, 512, 1024]
+    )

     def dict(self):
         return asdict(self)
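Note on the dataclass change: a mutable default such as a list cannot be assigned directly to a dataclass field; matryoshka_dims: list = [64, ...] would raise "ValueError: mutable default ... use default_factory" at class-definition time, which is why the new field goes through field(default_factory=...). A minimal sketch of the pattern:

import dataclasses
from dataclasses import dataclass, field

@dataclass
class Example:
    # default_factory builds a fresh list per instance instead of sharing one list
    dims: list = field(default_factory=lambda: [64, 128, 256, 512, 1024])

print(Example().dims is Example().dims)  # False: each instance owns its own list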
3 changes: 2 additions & 1 deletion F2LLM/configs/config.json
@@ -15,5 +15,6 @@
     "warmup_steps": 500,
     "train_epochs": 2,
     "log_interval": 100,
-    "num_hard_neg": 7
+    "num_hard_neg": 7,
+    "matryoshka_dims": [64, 128, 256, 512, 1024]
 }
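The JSON key has to match the Args field name exactly, since configs like this are typically unpacked straight into the dataclass. A hedged sketch of that loading step (the actual loading code in run.py may differ):

import json
from arguments import Args

with open('F2LLM/configs/config.json') as f:
    args = Args(**json.load(f))   # works only if every JSON key is an Args field
print(args.matryoshka_dims)       # [64, 128, 256, 512, 1024]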
5 changes: 5 additions & 0 deletions F2LLM/run.py
@@ -121,6 +121,11 @@ def __iter__(self):
 accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************")
 model = F2LLM(args.model_path, args.max_seq_length, args=args)
 model.lm.gradient_checkpointing_enable()
+model_embedding_dim = model.lm.config.hidden_size
+if any(d > model_embedding_dim for d in args.matryoshka_dims):
+    raise ValueError(
+        f"Dimensions in matryoshka_dims cannot exceed the model's embedding dimension of {model_embedding_dim}."
+    )
 # set seed again to make sure that different models share the same seed
 set_seed(0)

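The explicit check above matters because PyTorch never raises on an over-long slice: x[:, :dim] with dim larger than the hidden size silently returns the full tensor, so a misconfigured matryoshka_dims entry would otherwise train without error while logging misleadingly-named dimensions. A quick demonstration:

import torch

x = torch.randn(2, 1024)
print(x[:, :4096].shape)  # torch.Size([2, 1024]) -- no error, the slice is silently clamped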
131 changes: 75 additions & 56 deletions F2LLM/utils.py
@@ -36,81 +36,93 @@ def inbatch_loss(
     context_embeddings,    # [bs, d]
     criterion,
     accelerator,
-    temperature=0.05,
+    matryoshka_dims=[64, 128, 256, 512, 1024],
+    temperature=0.05
 ):

     bs = query_embeddings.size(0)
-    a_norm = F.normalize(query_embeddings, p=2, dim=-1)
     # b_norm = torch.nn.functional.normalize(context_embeddings, p=2, dim=-1)
     b_cross_gpus = accelerator.gather(context_embeddings)    # [bs*process, d]
     # print((context_embeddings - b_cross_gpus[bs * accelerator.process_index : bs * accelerator.process_index+bs]).abs().sum())
-    b_norm_cross_gpus = F.normalize(b_cross_gpus, p=2, dim=-1)

-    student_logits = torch.matmul(a_norm, b_norm_cross_gpus.t()) / temperature    # [bs, bs*process]

-    labels = torch.arange(bs, device=student_logits.device) + bs * accelerator.process_index
-    loss_bs = criterion(student_logits, labels)    # (bs)

-    loss = loss_bs.mean()

-    return loss
+    losses_in_batch = []
+    for dim in matryoshka_dims:
+        query_embedding_slice = query_embeddings[:, :dim]
+        a_norm = F.normalize(query_embedding_slice, p=2, dim=-1)
+        b_cross_gpus_slice = b_cross_gpus[:, :dim]
+        b_norm_cross_gpus = F.normalize(b_cross_gpus_slice, p=2, dim=-1)
+        student_logits = torch.matmul(a_norm, b_norm_cross_gpus.t()) / temperature    # [bs, bs*process]
+        labels = torch.arange(bs, device=student_logits.device) + bs * accelerator.process_index
+        loss_bs = criterion(student_logits, labels)    # (bs)
+        loss = loss_bs.mean()
+        losses_in_batch.append(loss)
+
+    return losses_in_batch
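Stripped of the accelerate plumbing, the per-dimension loop above is the core of Matryoshka-style training: slice a prefix of each embedding, re-normalize it, and compute the usual in-batch InfoNCE loss at that width. A self-contained single-GPU sketch (names and shapes here are illustrative, not from the repo):

import torch
import torch.nn.functional as F

def matryoshka_inbatch_losses(q, p, dims=(64, 128, 256), temperature=0.05):
    # q, p: [bs, d] query/passage embeddings; row i of p is the positive for row i of q
    labels = torch.arange(q.size(0), device=q.device)
    losses = []
    for dim in dims:
        # re-normalize each prefix: a prefix of a unit vector is not itself unit-norm
        qn = F.normalize(q[:, :dim], p=2, dim=-1)
        pn = F.normalize(p[:, :dim], p=2, dim=-1)
        logits = qn @ pn.t() / temperature  # [bs, bs]; the diagonal holds the positives
        losses.append(F.cross_entropy(logits, labels))
    return losses

q, p = torch.randn(8, 512), torch.randn(8, 512)
print([round(l.item(), 3) for l in matryoshka_inbatch_losses(q, p)])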

 def hard_loss(
     query_embeddings,      # [bs, d]
     context_embeddings,    # [bs, d]
     hard_neg_embeddings,   # [bs, num, d]
     criterion,
     accelerator,
-    temperature=0.05,
+    matryoshka_dims=[64, 128, 256, 512, 1024],
+    temperature=0.05
 ):

     if hard_neg_embeddings is None:
         return 0.0

     bs = query_embeddings.size(0)
-    a_norm = F.normalize(query_embeddings, p=2, dim=-1)

     hard_neg_embeddings = torch.concat([
         context_embeddings.unsqueeze(1),
         hard_neg_embeddings
     ], dim=1)    # [bs, num_hard+1, d]

-    hard_norm = F.normalize(hard_neg_embeddings, p=2, dim=-1)
-    logits = (a_norm.unsqueeze(1) * hard_norm).sum(-1) / temperature    # [bs, num_hard+1]
-    loss_hard = criterion(logits, torch.zeros((bs), dtype=torch.long, device=logits.device)).mean()
-
-    return loss_hard
+    losses_hard = []
+    for dim in matryoshka_dims:
+        query_embeddings_slice = query_embeddings[:, :dim]
+        a_norm = F.normalize(query_embeddings_slice, p=2, dim=-1)
+        hard_neg_embeddings_slice = hard_neg_embeddings[:, :, :dim]
+        hard_norm = F.normalize(hard_neg_embeddings_slice, p=2, dim=-1)
+        logits = (a_norm.unsqueeze(1) * hard_norm).sum(-1) / temperature    # [bs, num_hard+1]
+        loss_hard = criterion(logits, torch.zeros((bs), dtype=torch.long, device=logits.device)).mean()
+        losses_hard.append(loss_hard)
+
+    return losses_hard
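Both loops normalize after slicing rather than slicing an already-normalized tensor; the two orders are not interchangeable, because truncating a unit vector shrinks its norm and would deflate every similarity score. A quick check:

import torch
import torch.nn.functional as F

v = F.normalize(torch.randn(1024), p=2, dim=-1)
print(v.norm().item())                                  # 1.0
print(v[:64].norm().item())                             # < 1.0: the prefix is no longer unit-norm
print(F.normalize(v[:64], p=2, dim=-1).norm().item())   # 1.0 again after re-normalizing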


-def validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer):
+def validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer, matryoshka_dims):
     eval_log_dict = {}
     for dataset_name, valid_dataloader in valid_loader_dict.items():
-        loss_ls, loss_hard_ls = [], []
+        loss_ls, loss_hard_ls = [[] for _ in range(len(matryoshka_dims))], [[] for _ in range(len(matryoshka_dims))]
         for batch in valid_dataloader:
             with torch.no_grad():
                 outputs = model.forward(batch)
-                loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator)
-                loss_hard_ls.append(accelerator.gather(loss_hard).float())
+                losses_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator, matryoshka_dims)
+                for i in range(len(matryoshka_dims)):
+                    loss_hard_ls[i].append(accelerator.gather(losses_hard[i]).float())
                 if dataset_name in RETRIEVAL_DATASETS:
-                    loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator)
-                    loss_ls.append(accelerator.gather(loss).float())
+                    losses_in_batch = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator, matryoshka_dims)
+                    for i in range(len(matryoshka_dims)):
+                        loss_ls[i].append(accelerator.gather(losses_in_batch[i]).float())

         accelerator.wait_for_everyone()
-        loss_hard_ls = torch.cat(loss_hard_ls)
-        eval_log_dict[f'{dataset_name}/valid_loss_hard'] = loss_hard_ls.mean()
-        if dataset_name in RETRIEVAL_DATASETS:
-            loss_ls = torch.cat(loss_ls)
-            eval_log_dict[f"{dataset_name}/valid_loss_in_batch"] = loss_ls.mean()
+        for i, dim in enumerate(matryoshka_dims):
+            loss_hard_ls[i] = torch.stack(loss_hard_ls[i])
+            eval_log_dict[f'{dataset_name}/valid_loss_hard_dim_{dim}'] = loss_hard_ls[i].mean()
+            if dataset_name in RETRIEVAL_DATASETS:
+                loss_ls[i] = torch.stack(loss_ls[i])
+                eval_log_dict[f"{dataset_name}/valid_loss_in_batch_dim_{dim}"] = loss_ls[i].mean()
+    for i, dim in enumerate(matryoshka_dims):
+        eval_log_dict[f'Avg/retrieval/valid_loss_in_batch_dim_{dim}'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith(f'valid_loss_in_batch_dim_{dim}')]).mean()
+        eval_log_dict[f'Avg/retrieval/valid_loss_hard_dim_{dim}'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith(f'valid_loss_hard_dim_{dim}')]).mean()
+        eval_log_dict[f'Avg/classification/valid_loss_hard_dim_{dim}'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS and k.endswith(f'valid_loss_hard_dim_{dim}')]).mean()
+        eval_log_dict[f'Avg/clustering/valid_loss_hard_dim_{dim}'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS and k.endswith(f'valid_loss_hard_dim_{dim}')]).mean()
+    eval_log_dict = {k: v for k, v in eval_log_dict.items() if not is_nan_tensor(v)}

-    eval_log_dict['Avg/retrieval/valid_loss_in_batch'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_in_batch')]).mean()
-    eval_log_dict['Avg/retrieval/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_hard')]).mean()
-    eval_log_dict['Avg/classification/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean()
-    eval_log_dict['Avg/clustering/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean()
     if accelerator.is_main_process:
         write_tensorboard(summary_writer, eval_log_dict, completed_steps)
         accelerator.print(f"[Validation] Step = {completed_steps}")


+def is_nan_tensor(x):
+    return isinstance(x, torch.Tensor) and x.numel() == 1 and x.isnan().item()
+
 def accelerate_train(args,
                      accelerator,
@@ -137,11 +149,11 @@ def accelerate_train(args,
     criterion = CrossEntropyLoss(reduction='none')
     pbar = tqdm(range(args.train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
-    loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
-    loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
+    loss_dict = {ds_name: [torch.tensor(0.0, device=model.lm.device) for _ in args.matryoshka_dims] for ds_name in RETRIEVAL_DATASETS}
+    loss_hard_dict = {ds_name: [torch.tensor(0.0, device=model.lm.device) for _ in args.matryoshka_dims] for ds_name in train_dataloader.loader_dict.keys()}
     count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
     count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}

+    matryoshka_dims = args.matryoshka_dims
     model.lm.train()
     for epoch in range(args.train_epochs):
         accelerator.print(f"*************** Starting epoch {epoch+1} ***************")
Expand All @@ -152,14 +164,16 @@ def accelerate_train(args,
             # passage features: [bs, 1, d]
             # hard_neg_features: [bs, num_hard_neg, d]

-            loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator)
+            losses_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator, matryoshka_dims)
+            loss_hard = torch.sum(torch.stack(losses_hard))
             dataset_name = batch['dataset_name']
             count_hard_dict[dataset_name] += 1
-            loss_hard_dict[dataset_name] += loss_hard.detach().float()
+            for i, l in enumerate(losses_hard):
+                loss_hard_dict[dataset_name][i] += l.detach().float()
             if dataset_name in RETRIEVAL_DATASETS:
-                loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator)
+                losses_in_batch = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator, matryoshka_dims)
+                loss = torch.sum(torch.stack(losses_in_batch))
                 count_dict[dataset_name] += 1
-                loss_dict[dataset_name] += loss.detach().float()
+                for i, l in enumerate(losses_in_batch):
+                    loss_dict[dataset_name][i] += l.detach().float()
             else:
                 loss = 0.0
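torch.sum(torch.stack(losses_hard)) gives every Matryoshka dimension equal weight in the gradient. The original MRL formulation also allows per-dimension importance weights; if that were wanted here, a sketch of the change (the weight values are illustrative, not part of this PR):

import torch

# hypothetical per-dimension weights, one per entry of matryoshka_dims
weights = torch.tensor([1.0, 1.0, 1.0, 0.5, 0.5], device=losses_hard[0].device)
loss_hard = torch.sum(weights * torch.stack(losses_hard))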

Expand All @@ -183,28 +197,33 @@ def accelerate_train(args,
             for k in loss_dict.keys():
                 count = accelerator.gather(count_dict[k]).sum()
                 if count > 0:
-                    train_log_dict[f"{k}/training_loss_in_batch"] = accelerator.gather(loss_dict[k]).sum() / count
+                    for i, dim in enumerate(matryoshka_dims):
+                        train_log_dict[f"{k}/training_loss_in_batch_dim_{dim}"] = accelerator.gather(loss_dict[k][i]).sum() / count
             for k in loss_hard_dict.keys():
                 count = accelerator.gather(count_hard_dict[k]).sum()
                 if count > 0:
-                    train_log_dict[f"{k}/training_loss_hard"] = accelerator.gather(loss_hard_dict[k]).sum() / count
-            train_log_dict['Avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_in_batch')]).mean()
-            train_log_dict['Avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_hard')]).mean()
-            train_log_dict['Avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean()
-            train_log_dict['Avg/clustering/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean()
+                    for i, dim in enumerate(matryoshka_dims):
+                        train_log_dict[f"{k}/training_loss_hard_dim_{dim}"] = accelerator.gather(loss_hard_dict[k][i]).sum() / count
+            for dim in matryoshka_dims:
+                train_log_dict[f'dim_{dim}/avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith(f'training_loss_in_batch_dim_{dim}')]).mean()
+                train_log_dict[f'dim_{dim}/avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith(f'training_loss_hard_dim_{dim}')]).mean()
+                train_log_dict[f'dim_{dim}/avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS and k.endswith(f'training_loss_hard_dim_{dim}')]).mean()
+                train_log_dict[f'dim_{dim}/avg/clustering/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS and k.endswith(f'training_loss_hard_dim_{dim}')]).mean()
+            train_log_dict['Avg/loss_total'] = accelerator.gather(loss_total).sum() / accelerator.num_processes

+            train_log_dict = {k: v for k, v in train_log_dict.items() if not is_nan_tensor(v)}
             accelerator.print(f"[Train] Step = {completed_steps}")
             if accelerator.is_main_process:
                 write_tensorboard(summary_writer, train_log_dict, completed_steps)
-            loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
-            loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
+            loss_dict = {ds_name: [torch.tensor(0.0, device=model.lm.device) for _ in matryoshka_dims] for ds_name in RETRIEVAL_DATASETS}
+            loss_hard_dict = {ds_name: [torch.tensor(0.0, device=model.lm.device) for _ in matryoshka_dims] for ds_name in train_dataloader.loader_dict.keys()}
             count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
             count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}

             # validation
             if completed_steps % args.validation_steps == 0:
                 model.lm.eval()
-                validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer)
+                validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer, matryoshka_dims)
                 model.lm.train()

             # step checkpoint
Expand All @@ -220,7 +239,7 @@ def accelerate_train(args,
     save_checkpoint(args, accelerator, model, output_dir, lr_scheduler)
     if completed_steps % args.validation_steps != 0:
         model.lm.eval()
-        validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer)
+        validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer, matryoshka_dims)
         model.lm.train()

     if summary_writer:
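At inference time, the payoff of this training scheme is that a single forward pass yields usable embeddings at every configured width: truncate to a prefix and re-normalize before computing similarities. A hedged sketch (the encode call is a hypothetical API, not from this repo):

import torch.nn.functional as F

full = model.encode(texts)                        # [n, 1024] full-width embeddings
small = F.normalize(full[:, :256], p=2, dim=-1)   # 256-d embeddings, ~4x cheaper to store and search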