From 1196b2ae82fd9eaf844b8f17e6a4fab0432506a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=AF=E8=BF=9B?= Date: Fri, 28 Nov 2025 16:39:45 +0800 Subject: [PATCH] MRL-support --- ...00\346\261\202\346\226\207\346\241\243.md" | 116 ++++++++++++ F2LLM/arguments.py | 17 ++ F2LLM/configs/config.json | 14 +- F2LLM/model.py | 100 +++++++++- F2LLM/run.py | 18 +- F2LLM/utils.py | 174 ++++++++++++++++-- 6 files changed, 404 insertions(+), 35 deletions(-) create mode 100644 "F2LLM/MRL\351\234\200\346\261\202\346\226\207\346\241\243.md" diff --git "a/F2LLM/MRL\351\234\200\346\261\202\346\226\207\346\241\243.md" "b/F2LLM/MRL\351\234\200\346\261\202\346\226\207\346\241\243.md" new file mode 100644 index 0000000..0da53cd --- /dev/null +++ "b/F2LLM/MRL\351\234\200\346\261\202\346\226\207\346\241\243.md" @@ -0,0 +1,116 @@ +# Matryoshka Representation Learning (MRL) 支持需求文档 + +## 1. 背景 + +在训练CodeFuse-Embeddings模型时,为了提供更大的灵活性以适应不同的下游应用和计算预算,我们实现了Matryoshka Representation Learning (MRL)支持。MRL是一种"俄罗斯套娃"式的训练方法,允许单个模型在推理时产生不同维度的高质量嵌入(例如64、128、256、512、1024等),从而为不同的应用场景提供显著的灵活性。 + +## 2. 需求目标 + +为CodeFuse-Embeddings模型增加MRL支持,使得模型能够在训练时学习多个嵌入维度的表示,并在推理时根据需要选择合适的嵌入维度,以平衡性能和计算效率。 + +## 3. 技术实现 + +### 3.1 MRL核心概念 + +Matryoshka Representation Learning是一种训练方法,它允许模型在不同维度上产生嵌入表示,其中较低维度的嵌入是较高维度嵌入的子集。这种方法通过以下方式实现: + +1. 在训练过程中,模型同时学习多个目标维度的表示 +2. 使用投影层将完整维度的嵌入映射到目标维度 +3. 在损失计算时,同时考虑所有目标维度的损失 + +### 3.2 实现细节 + +#### 3.2.1 配置参数 + +在`F2LLM/configs/config.json`中增加了以下MRL相关配置参数: + +- `mrl_enabled`: 是否启用MRL(布尔值,默认为false) +- `mrl_dims`: MRL目标维度列表(数组,默认为[128, 256, 512, 1024]) +- `mrl_loss_weights`: 每个维度的损失权重(数组,默认为[1.0, 1.0, 1.0, 1.0]) + +#### 3.2.2 模型修改 + +在`F2LLM/model.py`中实现了以下MRL相关功能: + +1. `F2LLM`类中增加了MRL支持: + - 添加了`mrl_enabled`标志来控制是否启用MRL + - 创建了针对每个目标维度的投影层(`mrl_projections`) + - 实现了`get_mrl_embeddings`方法用于获取特定维度的嵌入 + - 实现了`get_all_mrl_embeddings`方法用于获取所有维度的嵌入 + +2. `forward`方法修改: + - 增加了`target_dim`参数用于指定目标嵌入维度 + - 根据是否启用MRL返回不同维度的嵌入表示 + - 支持返回所有维度的嵌入字典 + +#### 3.2.3 训练过程修改 + +在`F2LLM/utils.py`中实现了以下MRL相关功能: + +1. 增加了MRL损失计算函数: + - `mrl_inbatch_loss`: 计算MRL的批次内负采样损失 + - `mrl_hard_loss`: 计算MRL的硬负样本损失 + +2. 修改了训练循环: + - 在每个训练批次中随机选择一个目标维度 + - 根据选择的维度计算相应的损失 + +3. 修改了验证过程: + - 在验证时使用完整维度进行评估 + +#### 3.2.4 参数定义 + +在`F2LLM/arguments.py`中增加了MRL相关参数定义: + +- `mrl_enabled`: 是否启用MRL +- `mrl_dims`: MRL目标维度列表 +- `mrl_loss_weights`: 每个维度的损失权重 + +## 4. 使用方法 + +### 4.1 启用MRL训练 + +在配置文件`F2LLM/configs/config.json`中设置: + +```json +{ + "mrl_enabled": true, + "mrl_dims": [128, 256, 512, 1024], + "mrl_loss_weights": [1.0, 1.0, 1.0, 1.0] +} +``` + +### 4.2 启动训练 + +```bash +cd F2LLM +python run.py --config configs/config.json +``` + +### 4.3 推理时使用不同维度 + +在推理时,可以通过指定`target_dim`参数来获取特定维度的嵌入: + +```python +# 获取特定维度的嵌入 +outputs = model.forward(batch, target_dim=512) + +# 获取所有维度的嵌入 +outputs = model.forward(batch) +``` + +## 5. 优势 + +1. **灵活性**: 单个模型可以生成多种维度的嵌入,适应不同应用场景 +2. **效率**: 在推理时可以选择较低维度以提高速度,或选择较高维度以获得更好性能 +3. **资源优化**: 根据计算预算和性能要求选择合适的嵌入维度 +4. **兼容性**: 保持与现有代码的兼容性,通过配置参数控制是否启用MRL + +## 6. 验证 + +变更后,模型能够: + +1. 成功训练启用MRL的模型 +2. 在推理时生成不同维度的嵌入表示 +3. 保持与未启用MRL模型相当的性能 +4. 在不同维度之间提供平滑的性能权衡 diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py index b967c8f..870a650 100644 --- a/F2LLM/arguments.py +++ b/F2LLM/arguments.py @@ -7,10 +7,12 @@ class Args: model_path: str experiment_id: str + # save dir output_dir: str tb_dir: str cache_dir: str + # training arguments train_data_path: str train_batch_size: int = 8 @@ -19,14 +21,22 @@ class Args: min_lr: float = 1e-6 weight_decay: float = 1e-2 warmup_steps: int = 100 + # embedding-related settings num_hard_neg: int = 7 + + # MRL settings + mrl_enabled: bool = False + mrl_dims: list = None + mrl_loss_weights: list = None + # train steps take precedence over epochs, set to -1 to disable train_steps: int = -1 train_epochs: int = 5 log_interval: int = 20 checkpointing_steps: int = 100 validation_steps: int = 100 + # just placeholder, for logging purpose num_processes: int=0 @@ -43,4 +53,11 @@ def parse_args(): args = Args(**config) args.output_dir = f"{args.output_dir}/{args.experiment_id}" args.tb_dir = f"{args.tb_dir}/{args.experiment_id}" + + # Set default values for MRL parameters if not specified + if args.mrl_dims is None: + args.mrl_dims = [128, 256, 512, 1024] + if args.mrl_loss_weights is None: + args.mrl_loss_weights = [1.0] * len(args.mrl_dims) + return args \ No newline at end of file diff --git a/F2LLM/configs/config.json b/F2LLM/configs/config.json index 2ac3708..a74677a 100644 --- a/F2LLM/configs/config.json +++ b/F2LLM/configs/config.json @@ -1,19 +1,23 @@ { - "model_path": "models/qwen3-4b", - "experiment_id": "4b+lr.8e-6+bs.16x32+context.1024+2epochs", + "model_path": "models/qwen3-0.6b", + "experiment_id": "0.6b+lr.8e-6+bs.16x32+context.1024+2epochs", "train_data_path": "training_data/data_tokenized_qwen", "output_dir": "output", "tb_dir": "output/tb", "cache_dir": "cache", - "train_batch_size": 16, + "train_batch_size": 4, "checkpointing_steps": 5000, "validation_steps": 5000, - "max_seq_length": 1024, + "max_seq_length": 256, "learning_rate": 8e-6, "min_lr": 1e-7, "weight_decay": 0.01, "warmup_steps": 500, "train_epochs": 2, "log_interval": 100, - "num_hard_neg": 7 + "num_hard_neg": 7, + + "mrl_enabled": true, + "mrl_dims": [128, 256, 512, 1024], + "mrl_loss_weights": [1.0, 1.0, 1.0, 1.0] } diff --git a/F2LLM/model.py b/F2LLM/model.py index d33ade7..2a49578 100644 --- a/F2LLM/model.py +++ b/F2LLM/model.py @@ -1,5 +1,6 @@ import torch from transformers import AutoModel, AutoTokenizer +import torch.nn as nn class F2LLM: @@ -10,17 +11,71 @@ def __init__(self, ): self.args = args - self.dtype = torch.bfloat16 + self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 self.device = None # set after accelerator.prepare - self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2') + + # Check if CUDA is available and set the attention implementation accordingly + # Only use flash_attention_2 if CUDA is available and flash_attn is installed + attn_implementation = None + if torch.cuda.is_available(): + try: + import flash_attn + attn_implementation = 'flash_attention_2' + except ImportError: + attn_implementation = 'eager' # or 'sdpa' if available + + self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation=attn_implementation) self.lm.config.use_cache = False self.tokenizer = AutoTokenizer.from_pretrained(model_path) self.max_seq_length = max_seq_length + + # MRL support + self.mrl_enabled = getattr(args, 'mrl_enabled', False) + if self.mrl_enabled: + self.mrl_dims = args.mrl_dims + # Create projection layers for each target dimension + hidden_size = self.lm.config.hidden_size + self.mrl_projections = nn.ModuleDict({ + str(dim): nn.Linear(hidden_size, dim) for dim in self.mrl_dims + }) + # Move projection layers to the same device as the model + self.mrl_projections.to(self.lm.device) def set_device(self): self.device = self.lm.device + # Move MRL projections to the correct device if they exist + if self.mrl_enabled: + self.mrl_projections.to(self.device) - def forward(self, batch): + def get_mrl_embeddings(self, full_embeddings, target_dim): + """Get embeddings for a specific target dimension""" + if not self.mrl_enabled: + return full_embeddings + + if target_dim == self.lm.config.hidden_size: + # No projection needed for full dimension + return full_embeddings + elif str(target_dim) in self.mrl_projections: + # Use projection layer + return self.mrl_projections[str(target_dim)](full_embeddings) + else: + # Fallback to truncation + return full_embeddings[:, :target_dim] + + def get_all_mrl_embeddings(self, full_embeddings): + """Get embeddings for all MRL dimensions""" + if not self.mrl_enabled: + return {str(self.lm.config.hidden_size): full_embeddings} + + embeddings_dict = {} + # Full dimension + embeddings_dict[str(self.lm.config.hidden_size)] = full_embeddings + # Projected dimensions + for dim in self.mrl_dims: + embeddings_dict[str(dim)] = self.get_mrl_embeddings(full_embeddings, dim) + return embeddings_dict + + def forward(self, batch, target_dim=None): bs = batch['bs'] num_hard_neg = int((len(batch['input_ids']) - 2*bs) / bs) @@ -29,9 +84,38 @@ def forward(self, batch): ) passage_features_all_tokens = outputs.last_hidden_state - return { - 'query_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs)]), - 'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs, 2*bs)]), - 'negative_passage_features': None if num_hard_neg == 0 else torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(2*bs, len(batch['seq_lens']))]).view(bs, num_hard_neg, -1) - } + # Extract [CLS] token embeddings + full_embeddings = torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(len(batch['seq_lens']))]) + + if target_dim is not None: + # Return embeddings for specific dimension + embeddings = self.get_mrl_embeddings(full_embeddings, target_dim) + elif self.mrl_enabled: + # Return embeddings for all dimensions + embeddings_dict = self.get_all_mrl_embeddings(full_embeddings) + else: + # Return full dimension embeddings only + embeddings = full_embeddings + embeddings_dict = None + + # Split embeddings back to original format + if self.mrl_enabled and target_dim is None: + # Return dict with embeddings for all dimensions + result = {} + for dim, embs in embeddings_dict.items(): + result[f'query_passage_features_{dim}'] = embs[:bs] + result[f'passage_passage_features_{dim}'] = embs[bs:2*bs] + result[f'negative_passage_features_{dim}'] = None if num_hard_neg == 0 else embs[2*bs:].view(bs, num_hard_neg, -1) + return result + else: + # Return single dimension embeddings + query_embs = embeddings[:bs] if target_dim is not None or not self.mrl_enabled else embeddings_dict[str(self.lm.config.hidden_size)][:bs] + passage_embs = embeddings[bs:2*bs] if target_dim is not None or not self.mrl_enabled else embeddings_dict[str(self.lm.config.hidden_size)][bs:2*bs] + negative_embs = None if num_hard_neg == 0 else embeddings[2*bs:].view(bs, num_hard_neg, -1) if target_dim is not None or not self.mrl_enabled else embeddings_dict[str(self.lm.config.hidden_size)][2*bs:].view(bs, num_hard_neg, -1) + + return { + 'query_passage_features': query_embs, + 'passage_passage_features': passage_embs, + 'negative_passage_features': negative_embs + } diff --git a/F2LLM/run.py b/F2LLM/run.py index e40b707..4d84ddb 100644 --- a/F2LLM/run.py +++ b/F2LLM/run.py @@ -120,7 +120,11 @@ def __iter__(self): accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************") model = F2LLM(args.model_path, args.max_seq_length, args=args) -model.lm.gradient_checkpointing_enable() + +# Only enable gradient checkpointing if CUDA is available +if torch.cuda.is_available(): + model.lm.gradient_checkpointing_enable() + # set seed again to make sure that different models share the same seed set_seed(0) @@ -134,7 +138,10 @@ def __iter__(self): num_warmup_steps=args.warmup_steps, num_training_steps=args.train_steps) -AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size +# Check if deepspeed plugin is available before accessing its config +if hasattr(AcceleratorState(), 'deepspeed_plugin') and AcceleratorState().deepspeed_plugin is not None: + AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size + model.lm, optimizer, lr_scheduler = accelerator.prepare( model.lm, optimizer, lr_scheduler ) @@ -148,6 +155,11 @@ def __iter__(self): args.train_steps = len(train_dataloader) * args.train_epochs accelerator.print(f"******************************** Training step after prepare: {args.train_steps} ********************************") +# Fix: Use the length of the first dataset or a default value if no datasets +train_datasets_dict = dict(train_datasets) +first_dataset_name = next(iter(train_datasets_dict)) if train_datasets_dict else None +dataset = train_datasets_dict[first_dataset_name] if first_dataset_name else None +num_train_samples = len(dataset) if dataset is not None else 0 accelerate_train(args, accelerator, model, train_dataloader, valid_loaders, - optimizer, lr_scheduler, len(dataset)) \ No newline at end of file + optimizer, lr_scheduler, num_train_samples) \ No newline at end of file diff --git a/F2LLM/utils.py b/F2LLM/utils.py index b167d3c..ec7980b 100644 --- a/F2LLM/utils.py +++ b/F2LLM/utils.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from torch.nn import CrossEntropyLoss import os +import random CLASSIFICATION_DATASETS = ['amazon_counterfactual', 'amazon_polarity', 'imdb', 'toxic_conversations', 'cola'] CLUSTERING_DATASETS = ['amazon_reviews', 'banking77', 'emotion', 'mtop_intent', 'mtop_domain', 'massive_scenario', 'massive_intent', 'tweet_sentiment_extraction', 'arxiv_clustering_p2p', 'arxiv_clustering_s2s', 'biorxiv_clustering_p2p', 'biorxiv_clustering_s2s', 'medrxiv_clustering_p2p', 'medrxiv_clustering_s2s', 'reddit_clustering_p2p', 'reddit_clustering_s2s', 'stackexchange_clustering_p2p', 'stackexchange_clustering_s2s', 'twentynewsgroups'] @@ -65,7 +66,7 @@ def hard_loss( ): if hard_neg_embeddings is None: - return 0.0 + return torch.tensor(0.0, device=query_embeddings.device) bs = query_embeddings.size(0) a_norm = F.normalize(query_embeddings, p=2, dim=-1) @@ -83,25 +84,122 @@ def hard_loss( return loss_hard +def mrl_inbatch_loss( + query_embeddings_dict, # dict of [bs, d] + context_embeddings_dict, # dict of [bs, d] + criterion, + accelerator, + mrl_dims, + mrl_loss_weights, + temperature=0.05, + ): + """Compute MRL loss for in-batch negative sampling across multiple dimensions""" + total_loss = 0.0 + + for i, (dim, weight) in enumerate(zip(mrl_dims, mrl_loss_weights)): + query_emb = query_embeddings_dict[str(dim)] + context_emb = context_embeddings_dict[str(dim)] + + loss = inbatch_loss(query_emb, context_emb, criterion, accelerator, temperature) + total_loss += weight * loss + + return total_loss + + +def mrl_hard_loss( + query_embeddings_dict, # dict of [bs, d] + context_embeddings_dict, # dict of [bs, d] + hard_neg_embeddings_dict, # dict of [bs, num, d] + criterion, + accelerator, + mrl_dims, + mrl_loss_weights, + temperature=0.05, + ): + """Compute MRL loss for hard negative sampling across multiple dimensions""" + total_loss = 0.0 + + for i, (dim, weight) in enumerate(zip(mrl_dims, mrl_loss_weights)): + query_emb = query_embeddings_dict[str(dim)] + context_emb = context_embeddings_dict[str(dim)] + hard_neg_emb = hard_neg_embeddings_dict.get(str(dim), None) + + loss = hard_loss(query_emb, context_emb, hard_neg_emb, criterion, accelerator, temperature) + total_loss += weight * loss + + return total_loss + + def validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer): eval_log_dict = {} for dataset_name, valid_dataloader in valid_loader_dict.items(): loss_ls, loss_hard_ls = [], [] + + # For MRL, we'll validate on the full dimension + target_dim = model.lm.config.hidden_size if model.mrl_enabled else None + for batch in valid_dataloader: with torch.no_grad(): - outputs = model.forward(batch) - loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator) - loss_hard_ls.append(accelerator.gather(loss_hard).float()) + outputs = model.forward(batch, target_dim=target_dim) + + if model.mrl_enabled: + # Extract embeddings for full dimension + dim_key = str(model.lm.config.hidden_size) + # Check if the dimension key exists in outputs, if not, use the default key + if f'query_passage_features_{dim_key}' in outputs: + query_emb = outputs[f'query_passage_features_{dim_key}'].squeeze(1) + passage_emb = outputs[f'passage_passage_features_{dim_key}'].squeeze(1) + neg_emb = outputs.get(f'negative_passage_features_{dim_key}', None) + else: + # Fallback to default keys + query_emb = outputs['query_passage_features'].squeeze(1) + passage_emb = outputs['passage_passage_features'].squeeze(1) + neg_emb = outputs.get('negative_passage_features', None) + else: + query_emb = outputs['query_passage_features'].squeeze(1) + passage_emb = outputs['passage_passage_features'].squeeze(1) + neg_emb = outputs['negative_passage_features'] + + loss_hard = hard_loss(query_emb, passage_emb, neg_emb, criterion, accelerator) + if accelerator.num_processes > 1: + # When using multiple processes, we need to gather the loss from all processes + gathered_loss_hard = accelerator.gather(loss_hard).float() + # Ensure gathered_loss_hard is at least 1D + if gathered_loss_hard.dim() == 0: + # Scalar tensor, convert to 1D tensor with one element + gathered_loss_hard = gathered_loss_hard.unsqueeze(0) + loss_hard_ls.append(gathered_loss_hard) + else: + # When using a single process, just append the loss directly + # Ensure loss_hard is at least 1D + if loss_hard.dim() == 0: + loss_hard = loss_hard.unsqueeze(0) + loss_hard_ls.append(loss_hard.float()) if dataset_name in RETRIEVAL_DATASETS: - loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator) - loss_ls.append(accelerator.gather(loss).float()) + loss = inbatch_loss(query_emb, passage_emb, criterion, accelerator) + if accelerator.num_processes > 1: + # When using multiple processes, we need to gather the loss from all processes + gathered_loss = accelerator.gather(loss).float() + # Ensure gathered_loss is at least 1D + if gathered_loss.dim() == 0: + # Scalar tensor, convert to 1D tensor with one element + gathered_loss = gathered_loss.unsqueeze(0) + loss_ls.append(gathered_loss) + else: + # When using a single process, just append the loss directly + # Ensure loss is at least 1D + if loss.dim() == 0: + loss = loss.unsqueeze(0) + loss_ls.append(loss.float()) accelerator.wait_for_everyone() - loss_hard_ls = torch.cat(loss_hard_ls) - eval_log_dict[f'{dataset_name}/valid_loss_hard'] = loss_hard_ls.mean() + if loss_hard_ls: # Check if the list is not empty + loss_hard_ls = torch.cat(loss_hard_ls) + eval_log_dict[f'{dataset_name}/valid_loss_hard'] = loss_hard_ls.mean() if dataset_name in RETRIEVAL_DATASETS: - loss_ls = torch.cat(loss_ls) - eval_log_dict[f"{dataset_name}/valid_loss_in_batch"] = loss_ls.mean() + if loss_ls: # Check if the list is not empty + loss_ls = torch.cat(loss_ls) + eval_log_dict[f"{dataset_name}/valid_loss_in_batch"] = loss_ls.mean() eval_log_dict['Avg/retrieval/valid_loss_in_batch'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_in_batch')]).mean() eval_log_dict['Avg/retrieval/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_hard')]).mean() @@ -147,21 +245,59 @@ def accelerate_train(args, accelerator.print(f"*************** Starting epoch {epoch+1} ***************") train_dataloader.reset_epoch(epoch) for batch in train_dataloader: - # forward and compute loss - outputs = model.forward(batch) - # passage features: [bs, 1, d] - # hard_neg_features: [bs, num_hard_neg, d] - - loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator) + # For MRL, randomly select a dimension for this batch + if model.mrl_enabled and args.mrl_dims: + # Randomly select a dimension for this batch + target_dim = random.choice([model.lm.config.hidden_size] + args.mrl_dims) + outputs = model.forward(batch, target_dim=target_dim) + + # Extract embeddings for the selected dimension + if target_dim == model.lm.config.hidden_size: + dim_key = str(target_dim) + # Check if the dimension key exists in outputs, if not, use the default key + if f'query_passage_features_{dim_key}' in outputs: + query_emb = outputs[f'query_passage_features_{dim_key}'].squeeze(1) + passage_emb = outputs[f'passage_passage_features_{dim_key}'].squeeze(1) + neg_emb = outputs.get(f'negative_passage_features_{dim_key}', None) + else: + # Fallback to default keys + query_emb = outputs['query_passage_features'].squeeze(1) + passage_emb = outputs['passage_passage_features'].squeeze(1) + neg_emb = outputs.get('negative_passage_features', None) + else: + dim_key = str(target_dim) + # Check if the dimension key exists in outputs, if not, use the default key + if f'query_passage_features_{dim_key}' in outputs: + query_emb = outputs[f'query_passage_features_{dim_key}'].squeeze(1) + passage_emb = outputs[f'passage_passage_features_{dim_key}'].squeeze(1) + neg_emb = outputs.get(f'negative_passage_features_{dim_key}', None) + else: + # Fallback to default keys + query_emb = outputs['query_passage_features'].squeeze(1) + passage_emb = outputs['passage_passage_features'].squeeze(1) + neg_emb = outputs.get('negative_passage_features', None) + + # Compute loss for the selected dimension + loss_hard = hard_loss(query_emb, passage_emb, neg_emb, criterion, accelerator) + if batch['dataset_name'] in RETRIEVAL_DATASETS: + loss = inbatch_loss(query_emb, passage_emb, criterion, accelerator) + else: + loss = torch.tensor(0.0, device=query_emb.device) + else: + # Standard training without MRL + outputs = model.forward(batch) + loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator) + if batch['dataset_name'] in RETRIEVAL_DATASETS: + loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator) + else: + loss = torch.tensor(0.0, device=outputs['query_passage_features'].squeeze(1).device) + dataset_name = batch['dataset_name'] count_hard_dict[dataset_name] += 1 loss_hard_dict[dataset_name] += loss_hard.detach().float() if dataset_name in RETRIEVAL_DATASETS: - loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator) count_dict[dataset_name] += 1 loss_dict[dataset_name] += loss.detach().float() - else: - loss = 0.0 loss_total = loss + loss_hard