diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py
index b967c8f..165a0f6 100644
--- a/F2LLM/arguments.py
+++ b/F2LLM/arguments.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass, asdict
+from dataclasses import dataclass, asdict, field
 import argparse, json
 
@@ -28,7 +28,10 @@ class Args:
     checkpointing_steps: int = 100
     validation_steps: int = 100
     # just placeholder, for logging purpose
-    num_processes: int=0
+    num_processes: int = 0
+    matryoshka_dims: list = field(
+        default_factory=lambda: [64, 128, 256, 512, 1024]
+    )
 
     def dict(self):
         return asdict(self)
diff --git a/F2LLM/configs/config.json b/F2LLM/configs/config.json
index 2ac3708..8f07380 100644
--- a/F2LLM/configs/config.json
+++ b/F2LLM/configs/config.json
@@ -15,5 +15,6 @@
     "warmup_steps": 500,
     "train_epochs": 2,
     "log_interval": 100,
-    "num_hard_neg": 7
+    "num_hard_neg": 7,
+    "matryoshka_dims": [64, 128, 256, 512, 1024]
 }
diff --git a/F2LLM/run.py b/F2LLM/run.py
index e40b707..5a6ac0c 100644
--- a/F2LLM/run.py
+++ b/F2LLM/run.py
@@ -121,6 +121,11 @@ def __iter__(self):
 accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************")
 model = F2LLM(args.model_path, args.max_seq_length, args=args)
 model.lm.gradient_checkpointing_enable()
+model_embedding_dim = model.lm.config.hidden_size
+if any(d > model_embedding_dim for d in args.matryoshka_dims):
+    raise ValueError(
+        f"Dimensions in matryoshka_dims cannot exceed the model's embedding dimension of {model_embedding_dim}."
+    )
 
 # set seed again to make sure that different models share the same seed
 set_seed(0)
diff --git a/F2LLM/utils.py b/F2LLM/utils.py
index b167d3c..88790f3 100644
--- a/F2LLM/utils.py
+++ b/F2LLM/utils.py
@@ -36,24 +36,25 @@ def inbatch_loss(
     query_embeddings, # [bs, d]
     context_embeddings, # [bs, d]
     criterion,
     accelerator,
-    temperature=0.05,
+    matryoshka_dims=[64, 128, 256, 512, 1024],
+    temperature=0.05
 ):
     bs = query_embeddings.size(0)
-    a_norm = F.normalize(query_embeddings, p=2, dim=-1)
-    # b_norm = torch.nn.functional.normalize(context_embeddings, p=2, dim=-1)
     b_cross_gpus = accelerator.gather(context_embeddings) # [bs*process, d]
-    # print((context_embeddings - b_cross_gpus[bs * accelerator.process_index : bs * accelerator.process_index+bs]).abs().sum())
-    b_norm_cross_gpus = F.normalize(b_cross_gpus, p=2, dim=-1) # ()
-
-    student_logits = torch.matmul(a_norm, b_norm_cross_gpus.t()) / temperature # [bs, bs*process]
-
-    labels = torch.arange(bs, device=student_logits.device) + bs * accelerator.process_index
-    loss_bs = criterion(student_logits, labels) # (bs)
-
-    loss = loss_bs.mean()
-
-    return loss
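+    # Matryoshka-style objective: recompute the in-batch InfoNCE loss on each truncated
+    # prefix of the embeddings so the shorter vectors remain usable at inference time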
+    losses_in_batch = []
+    for dim in matryoshka_dims:
+        query_embedding_slice = query_embeddings[:, :dim]
+        a_norm = F.normalize(query_embedding_slice, p=2, dim=-1)
+        b_cross_gpus_slice = b_cross_gpus[:, :dim]
+        b_norm_cross_gpus = F.normalize(b_cross_gpus_slice, p=2, dim=-1)
+        student_logits = torch.matmul(a_norm, b_norm_cross_gpus.t()) / temperature # [bs, bs*process]
+        labels = torch.arange(bs, device=student_logits.device) + bs * accelerator.process_index
+        loss_bs = criterion(student_logits, labels) # (bs)
+        loss = loss_bs.mean()
+        losses_in_batch.append(loss)
+
+    return losses_in_batch
 
 def hard_loss(
     query_embeddings, # [bs, d]
@@ -61,56 +62,67 @@
     context_embeddings, # [bs, d]
     hard_neg_embeddings, # [bs, num, d]
     criterion,
     accelerator,
-    temperature=0.05,
+    matryoshka_dims=[64, 128, 256, 512, 1024],
+    temperature=0.05
 ):
     if hard_neg_embeddings is None:
         return 0.0
     bs = query_embeddings.size(0)
-    a_norm = F.normalize(query_embeddings, p=2, dim=-1)
     hard_neg_embeddings = torch.concat([
         context_embeddings.unsqueeze(1),
         hard_neg_embeddings
     ], dim=1) # [bs, num_hard+1, d]
-
-    hard_norm = F.normalize(hard_neg_embeddings, p=2, dim=-1)
-    logits = (a_norm.unsqueeze(1) * hard_norm).sum(-1) / temperature # [bs, num_hard+1]
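+    # same per-dimension slicing for the hard-negative loss; index 0 of the
+    # concatenated candidate tensor is the positive passage, hence the zero labels below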
+    losses_hard = []
+    for dim in matryoshka_dims:
+        query_embeddings_slice = query_embeddings[:, :dim]
+        a_norm = F.normalize(query_embeddings_slice, p=2, dim=-1)
+        hard_neg_embeddings_slice = hard_neg_embeddings[:, :, :dim]
+        hard_norm = F.normalize(hard_neg_embeddings_slice, p=2, dim=-1)
+        logits = (a_norm.unsqueeze(1) * hard_norm).sum(-1) / temperature # [bs, num_hard+1]
 
-    loss_hard = criterion(logits, torch.zeros((bs), dtype=torch.long, device=logits.device)).mean()
+        loss_hard = criterion(logits, torch.zeros((bs), dtype=torch.long, device=logits.device)).mean()
+        losses_hard.append(loss_hard)
 
-    return loss_hard
+    return losses_hard
 
-def validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer):
+def validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer, matryoshka_dims):
     eval_log_dict = {}
     for dataset_name, valid_dataloader in valid_loader_dict.items():
-        loss_ls, loss_hard_ls = [], []
+        loss_ls, loss_hard_ls = [[] for _ in range(len(matryoshka_dims))], [[] for _ in range(len(matryoshka_dims))]
         for batch in valid_dataloader:
             with torch.no_grad():
                 outputs = model.forward(batch)
-            loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator)
-            loss_hard_ls.append(accelerator.gather(loss_hard).float())
+            losses_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator, matryoshka_dims)
+            for i in range(len(matryoshka_dims)):
+                loss_hard_ls[i].append(accelerator.gather(losses_hard[i]).float())
             if dataset_name in RETRIEVAL_DATASETS:
-                loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator)
-                loss_ls.append(accelerator.gather(loss).float())
+                losses_in_batch = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator, matryoshka_dims)
+                for i in range(len(matryoshka_dims)):
+                    loss_ls[i].append(accelerator.gather(losses_in_batch[i]).float())
         accelerator.wait_for_everyone()
-        loss_hard_ls = torch.cat(loss_hard_ls)
-        eval_log_dict[f'{dataset_name}/valid_loss_hard'] = loss_hard_ls.mean()
-        if dataset_name in RETRIEVAL_DATASETS:
-            loss_ls = torch.cat(loss_ls)
-            eval_log_dict[f"{dataset_name}/valid_loss_in_batch"] = loss_ls.mean()
+        for i, dim in enumerate(matryoshka_dims):
+            loss_hard_ls[i] = torch.stack(loss_hard_ls[i])
+            eval_log_dict[f'{dataset_name}/valid_loss_hard_dim_{dim}'] = loss_hard_ls[i].mean()
+            if dataset_name in RETRIEVAL_DATASETS:
+                loss_ls[i] = torch.stack(loss_ls[i])
+                eval_log_dict[f"{dataset_name}/valid_loss_in_batch_dim_{dim}"] = loss_ls[i].mean()
+    for i, dim in enumerate(matryoshka_dims):
+        eval_log_dict[f'Avg/retrieval/valid_loss_in_batch_dim_{dim}'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith(f'valid_loss_in_batch_dim_{dim}')]).mean()
+        eval_log_dict[f'Avg/retrieval/valid_loss_hard_dim_{dim}'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith(f'valid_loss_hard_dim_{dim}')]).mean()
+        eval_log_dict[f'Avg/classification/valid_loss_hard_dim_{dim}'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS and k.endswith(f'valid_loss_hard_dim_{dim}')]).mean()
+        eval_log_dict[f'Avg/clustering/valid_loss_hard_dim_{dim}'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS and k.endswith(f'valid_loss_hard_dim_{dim}')]).mean()
+
+    eval_log_dict = {k: v for k, v in eval_log_dict.items() if not is_nan_tensor(v)}
 
-    eval_log_dict['Avg/retrieval/valid_loss_in_batch'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_in_batch')]).mean()
-    eval_log_dict['Avg/retrieval/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_hard')]).mean()
-    eval_log_dict['Avg/classification/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean()
-    eval_log_dict['Avg/clustering/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean()
     if accelerator.is_main_process:
         write_tensorboard(summary_writer, eval_log_dict, completed_steps)
     accelerator.print(f"[Validation] Step = {completed_steps}")
-
+
+def is_nan_tensor(x):
+    return isinstance(x, torch.Tensor) and x.numel() == 1 and x.isnan().item()
 
 def accelerate_train(args,
                      accelerator,
@@ -137,11 +149,11 @@
     criterion = CrossEntropyLoss(reduction='none')
     pbar = tqdm(range(args.train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
-    loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
-    loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
+    loss_dict = {ds_name: torch.zeros(len(args.matryoshka_dims), device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
+    loss_hard_dict = {ds_name: torch.zeros(len(args.matryoshka_dims), device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
     count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
     count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
-
+    matryoshka_dims = args.matryoshka_dims
     model.lm.train()
     for epoch in range(args.train_epochs):
         accelerator.print(f"*************** Starting epoch {epoch+1} ***************")
@@ -152,14 +164,16 @@
             # passage features: [bs, 1, d]
             # hard_neg_features: [bs, num_hard_neg, d]
 
-            loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator)
+            losses_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator, matryoshka_dims)
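+            # the training objective is the sum of the per-dimension losses, so every
+            # truncated embedding size is optimized jointly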
+            loss_hard = torch.sum(torch.stack(losses_hard))
             dataset_name = batch['dataset_name']
             count_hard_dict[dataset_name] += 1
-            loss_hard_dict[dataset_name] += loss_hard.detach().float()
+            loss_hard_dict[dataset_name] += torch.stack([l.detach().float() for l in losses_hard])
             if dataset_name in RETRIEVAL_DATASETS:
-                loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator)
+                losses_in_batch = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator, matryoshka_dims)
+                loss = torch.sum(torch.stack(losses_in_batch))
                 count_dict[dataset_name] += 1
-                loss_dict[dataset_name] += loss.detach().float()
+                loss_dict[dataset_name] += torch.stack([l.detach().float() for l in losses_in_batch])
             else:
                 loss = 0.0
 
@@ -183,28 +197,33 @@
                 for k in loss_dict.keys():
                     count = accelerator.gather(count_dict[k]).sum()
                     if count > 0:
-                        train_log_dict[f"{k}/training_loss_in_batch"] = accelerator.gather(loss_dict[k]).sum() / count
+                        for i, dim in enumerate(matryoshka_dims):
+                            train_log_dict[f"{k}/training_loss_in_batch_dim_{dim}"] = accelerator.gather(loss_dict[k][i]).sum() / count
                 for k in loss_hard_dict.keys():
                     count = accelerator.gather(count_hard_dict[k]).sum()
                     if count > 0:
-                        train_log_dict[f"{k}/training_loss_hard"] = accelerator.gather(loss_hard_dict[k]).sum() / count
-                train_log_dict['Avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_in_batch')]).mean()
-                train_log_dict['Avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_hard')]).mean()
-                train_log_dict['Avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean()
-                train_log_dict['Avg/clustering/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS]).mean()
-
+                        for i, dim in enumerate(matryoshka_dims):
+                            train_log_dict[f"{k}/training_loss_hard_dim_{dim}"] = accelerator.gather(loss_hard_dict[k][i]).sum() / count
+                for dim in matryoshka_dims:
+                    train_log_dict[f'dim_{dim}/avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith(f'training_loss_in_batch_dim_{dim}')]).mean()
+                    train_log_dict[f'dim_{dim}/avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith(f'training_loss_hard_dim_{dim}')]).mean()
+                    train_log_dict[f'dim_{dim}/avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS and k.endswith(f'training_loss_hard_dim_{dim}')]).mean()
+                    train_log_dict[f'dim_{dim}/avg/clustering/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLUSTERING_DATASETS and k.endswith(f'training_loss_hard_dim_{dim}')]).mean()
+                train_log_dict['Avg/loss_total'] = accelerator.gather(loss_total).sum() / accelerator.num_processes
+
+                train_log_dict = {k: v for k, v in train_log_dict.items() if not is_nan_tensor(v)}
                 accelerator.print(f"[Train] Step = {completed_steps}")
                 if accelerator.is_main_process:
                     write_tensorboard(summary_writer, train_log_dict, completed_steps)
-                loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
-                loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
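+                # reset the per-dimension loss accumulators for the next logging window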
+                loss_dict = {ds_name: torch.zeros(len(matryoshka_dims), device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
+                loss_hard_dict = {ds_name: torch.zeros(len(matryoshka_dims), device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
                 count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
                 count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
 
             # validation
             if completed_steps % args.validation_steps == 0:
                 model.lm.eval()
-                validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer)
+                validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer, matryoshka_dims)
                 model.lm.train()
 
             # step checkpoint
@@ -220,7 +239,7 @@
     save_checkpoint(args, accelerator, model, output_dir, lr_scheduler)
     if completed_steps % args.validation_steps != 0:
         model.lm.eval()
-        validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer)
+        validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer, matryoshka_dims)
         model.lm.train()
 
     if summary_writer: