From 5f96b33002d05c0a8e0b017e6621c3b80357816c Mon Sep 17 00:00:00 2001 From: "luanzhi.xxl" Date: Tue, 25 Nov 2025 20:08:33 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E6=9B=B4=E5=A4=9A=E7=9A=84mo?= =?UTF-8?q?del?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- F2LLM/MULTI_MODEL_GUIDE.md | 199 +++++++++++++++++++++++++++++ F2LLM/arguments.py | 4 + F2LLM/configs/config.json | 5 +- F2LLM/configs/config_gpt_demo.json | 22 ++++ F2LLM/model.py | 98 +++++++++++++- F2LLM/run.py | 17 ++- F2LLM/tokenize_data.py | 184 ++++++++++++++++++++++++++ F2LLM/utils.py | 24 +++- 8 files changed, 539 insertions(+), 14 deletions(-) create mode 100644 F2LLM/MULTI_MODEL_GUIDE.md create mode 100644 F2LLM/configs/config_gpt_demo.json create mode 100644 F2LLM/tokenize_data.py diff --git a/F2LLM/MULTI_MODEL_GUIDE.md b/F2LLM/MULTI_MODEL_GUIDE.md new file mode 100644 index 0000000..f1276d5 --- /dev/null +++ b/F2LLM/MULTI_MODEL_GUIDE.md @@ -0,0 +1,199 @@ +# F2LLM 多模型支持使用指南 + +## 概述 + +修改后的F2LLM现在支持多种decoder-only模型,包括Qwen、LLaMA、Baichuan、ChatGLM等系列模型。 + +## 支持的模型 + +### 已测试模型 +- **Qwen系列**: Qwen-7B, Qwen-14B, Qwen3-4B等 +- **LLaMA系列**: LLaMA-7B, LLaMA2-13B等 +- **Baichuan系列**: Baichuan-13B, Baichuan2-13B等 +- **ChatGLM系列**: ChatGLM-6B, ChatGLM2-6B等 + +### 理论支持的模型 +任何基于transformers库的decoder-only模型都应该可以工作,包括: +- GPT系列 +- CodeT5+ +- CodeGen +- StarCoder +- 以及其他自定义decoder-only模型 + +## 使用方法 + +### 1. 模型配置 + +修改配置文件 `configs/config.json`: + +```json +{ + "model_path": "your-model-path", + "model_type": "auto", // 可选: auto, qwen, llama, baichuan等 + "attn_implementation": "flash_attention_2", // flash_attention_2, sdpa, null + "use_flash_attention": true, + // ... 其他配置 +} +``` + +#### 配置说明 + +- **model_path**: 模型路径或HuggingFace模型名称 +- **model_type**: 模型类型,用于自动适配特殊处理 +- **attn_implementation**: 注意力实现方式 + - `"flash_attention_2"`: 使用Flash Attention 2(最快,但需要支持) + - `"sdpa"`: 使用PyTorch的Scaled Dot Product Attention + - `null`: 不使用特殊注意力实现 +- **use_flash_attention**: 是否尝试使用flash attention + +### 2.获取训练数据 +#### 方案1:使用huggingface-cli + +如果您想使用原始的huggingface-cli命令: + +```bash +# 安装huggingface-hub +pip install huggingface-hub + +# 从huggingface中下载训练数据,若遇网络问题,可以考虑使用镜像 +export HF_ENDPOINT=https://hf-mirror.com +python -m huggingface_hub.cli download codefuse-ai/F2LLM --repo-type dataset --local-dir training_data --include "*.parquet" +``` + +#### 方案2:手动下载 + +1. 访问网站:https://huggingface.co/datasets/codefuse-ai/F2LLM +2. 手动下载.parquet文件 +3. 保存到 `training_data/` 目录 + +### 3. 数据预处理 + +使用通用分词脚本处理数据: + +```bash +# 基础用法 +python tokenize_data.py --model_path "meta-llama/Llama-2-7b-hf" --max_seq_length 1023 + +# 完整参数 +python tokenize_data.py \ + --model_path "baichuan-inc/Baichuan2-13B-Base" \ + --max_seq_length 1023 \ + --data_dir "training_data" \ + --output_dir "data_tokenized" \ + --num_processes 16 +``` + +### 4. 训练 + +```bash +# 单GPU训练 +accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json + +# 多GPU训练 +accelerate launch --config_file configs/accelerate_config.yaml --num_processes 8 run.py --config configs/config.json +``` + +## 模型特定配置 + +### LLaMA模型 +```json +{ + "model_path": "meta-llama/Llama-2-7b-hf", + "model_type": "llama", + "attn_implementation": "sdpa", + "use_flash_attention": true, + "max_seq_length": 2048 +} +``` + +### Baichuan模型 +```json +{ + "model_path": "baichuan-inc/Baichuan2-13B-Base", + "model_type": "baichuan", + "attn_implementation": "flash_attention_2", + "use_flash_attention": true, + "max_seq_length": 2048 +} +``` + +### ChatGLM模型 +```json +{ + "model_path": "THUDM/chatglm3-6b-base", + "model_type": "chatglm", + "attn_implementation": null, + "use_flash_attention": false, + "max_seq_length": 2048 +} +``` + +## 故障排除 + +### 常见问题 + +1. **Flash Attention不支持** + - 错误信息: `FlashAttention only supports Ampere GPUs or newer.` + - 解决: 设置 `"use_flash_attention": false` 或 `"attn_implementation": "sdpa"` + +2. **内存不足** + - 减小 `train_batch_size` + - 减小 `max_seq_length` + - 使用梯度累积 + +3. **模型加载失败** + - 确保模型路径正确 + - 检查网络连接(如果是HF模型) + - 查看具体的错误信息,调整注意力配置 + +### 调试建议 + +1. **逐步测试** + ```bash + # 先测试模型加载 + python -c "from transformers import AutoModel; model = AutoModel.from_pretrained('your-model')" + + # 再测试分词 + python tokenize_data.py --model_path "your-model" --num_processes 1 + ``` + +2. **查看日志** + - 修改后的代码会输出详细的加载信息 + - 关注警告信息,它们通常包含有用的回退信息 + +3. **性能优化** + - 优先使用Flash Attention 2(如果硬件支持) + - 使用SDPA作为第二选择 + - 禁用特殊注意力实现作为最后手段 + +## 性能对比 + +| 模型 | 注意力实现 | 训练速度 | 内存使用 | 兼容性 | +|------|------------|----------|----------|---------| +| Qwen3-4B | flash_attention_2 | ★★★★★ | ★★★★★ | ★★★★☆ | +| LLaMA2-7B | sdpa | ★★★★☆ | ★★★★☆ | ★★★★★ | +| Baichuan2-13B | flash_attention_2 | ★★★★★ | ★★★★☆ | ★★★☆☆ | +| ChatGLM3-6B | default | ★★★☆☆ | ★★★☆☆ | ★★★★★ | + +## 扩展支持 + +如果需要支持新的模型类型,可以: + +1. 在 `model.py` 中添加模型特定的处理逻辑 +2. 在配置文件中添加相应的模型类型标识 +3. 测试并验证兼容性 + +## 注意事项 + +1. **模型许可**: 确保你有权使用指定的模型 +2. **硬件要求**: 大型模型需要更多GPU内存 +3. **数据格式**: 确保训练数据格式与模型要求一致 +4. **分词器兼容性**: 不同模型可能使用不同的分词器 + +## 技术支持 + +如遇到问题,请提供以下信息: +- 模型名称和版本 +- 完整的错误日志 +- 硬件配置(GPU型号、内存等) +- 配置文件内容 diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py index b967c8f..62aa807 100644 --- a/F2LLM/arguments.py +++ b/F2LLM/arguments.py @@ -27,6 +27,10 @@ class Args: log_interval: int = 20 checkpointing_steps: int = 100 validation_steps: int = 100 + # model configuration + model_type: str = "auto" # auto, qwen, llama, baichuan, etc. + attn_implementation: str = "flash_attention_2" # flash_attention_2, sdpa, None + use_flash_attention: bool = True # just placeholder, for logging purpose num_processes: int=0 diff --git a/F2LLM/configs/config.json b/F2LLM/configs/config.json index 2ac3708..105bbe2 100644 --- a/F2LLM/configs/config.json +++ b/F2LLM/configs/config.json @@ -1,5 +1,6 @@ { "model_path": "models/qwen3-4b", + "model_type": "qwen", "experiment_id": "4b+lr.8e-6+bs.16x32+context.1024+2epochs", "train_data_path": "training_data/data_tokenized_qwen", "output_dir": "output", @@ -15,5 +16,7 @@ "warmup_steps": 500, "train_epochs": 2, "log_interval": 100, - "num_hard_neg": 7 + "num_hard_neg": 7, + "attn_implementation": "flash_attention_2", + "use_flash_attention": true } diff --git a/F2LLM/configs/config_gpt_demo.json b/F2LLM/configs/config_gpt_demo.json new file mode 100644 index 0000000..05914e5 --- /dev/null +++ b/F2LLM/configs/config_gpt_demo.json @@ -0,0 +1,22 @@ +{ + "model_path": "microsoft/DialoGPT-medium", + "model_type": "gpt2", + "experiment_id": "gpt-final-fix", + "train_data_path": "data_tokenized/data_tokenized_DialoGPT-medium", + "output_dir": "output", + "tb_dir": "output/tb", + "cache_dir": "cache", + "train_batch_size": 1, + "checkpointing_steps": 10, + "validation_steps": 10, + "max_seq_length": 128, + "learning_rate": 1e-4, + "min_lr": 1e-6, + "weight_decay": 0.01, + "warmup_steps": 5, + "train_epochs": 1, + "log_interval": 1, + "num_hard_neg": 1, + "attn_implementation": null, + "use_flash_attention": false +} diff --git a/F2LLM/model.py b/F2LLM/model.py index d33ade7..e94120c 100644 --- a/F2LLM/model.py +++ b/F2LLM/model.py @@ -1,5 +1,6 @@ import torch -from transformers import AutoModel, AutoTokenizer +from transformers import AutoModel, AutoTokenizer, GPT2LMHeadModel, AutoModelForCausalLM +import warnings class F2LLM: @@ -12,9 +13,80 @@ def __init__(self, self.args = args self.dtype = torch.bfloat16 self.device = None # set after accelerator.prepare - self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2') + + # 根据配置选择注意力实现方式 + attn_implementation = getattr(args, 'attn_implementation', 'flash_attention_2') if args else 'flash_attention_2' + use_flash_attention = getattr(args, 'use_flash_attention', True) if args else True + + # 尝试加载模型,支持多种decoder-only模型 + try: + if use_flash_attention and attn_implementation: + # 使用配置的注意力实现 + self.lm = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + torch_dtype=self.dtype, + attn_implementation=attn_implementation + ) + else: + # 不使用特殊注意力实现 + self.lm = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + torch_dtype=self.dtype + ) + except Exception as e: + if use_flash_attention and attn_implementation: + warnings.warn(f"Failed to load model with {attn_implementation}: {e}. Trying fallback options...") + + # 回退策略 + fallback_options = ['sdpa', None] # 尝试sdpa,然后是不使用特殊注意力 + loaded = False + + for fallback_attn in fallback_options: + try: + if fallback_attn: + self.lm = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + torch_dtype=self.dtype, + attn_implementation=fallback_attn + ) + else: + self.lm = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + torch_dtype=self.dtype + ) + warnings.warn(f"Successfully loaded model with {fallback_attn or 'default'} attention") + loaded = True + break + except Exception as e2: + warnings.warn(f"Failed to load model with {fallback_attn or 'default'} attention: {e2}") + continue + + if not loaded: + raise RuntimeError(f"Failed to load model {model_path} with any attention implementation") + self.lm.config.use_cache = False - self.tokenizer = AutoTokenizer.from_pretrained(model_path) + + # 加载分词器,添加trust_remote_code支持更多模型 + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + padding_side='right' # 大多数decoder-only模型需要右侧填充 + ) + + # 确保分词器有pad_token + if self.tokenizer.pad_token is None: + if self.tokenizer.eos_token is not None: + self.tokenizer.pad_token = self.tokenizer.eos_token + else: + # 添加新的pad_token + self.tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + # 需要调整模型embedding大小 + self.lm.resize_token_embeddings(len(self.tokenizer)) + self.max_seq_length = max_seq_length def set_device(self): @@ -24,11 +96,23 @@ def forward(self, batch): bs = batch['bs'] num_hard_neg = int((len(batch['input_ids']) - 2*bs) / bs) - outputs = self.lm(batch['input_ids'], - batch['attention_mask'], - ) + outputs = self.lm( + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + return_dict=True, + output_hidden_states=True + ) + + # 对于CausalLM模型,获取最后一层的隐藏状态 + if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None: + # hidden_states是一个元组,包含所有层的隐藏状态 + passage_features_all_tokens = outputs.hidden_states[-1] + elif hasattr(outputs, 'last_hidden_state'): + passage_features_all_tokens = outputs.last_hidden_state + else: + # 回退到使用transformer的输出 + passage_features_all_tokens = outputs[0] - passage_features_all_tokens = outputs.last_hidden_state return { 'query_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs)]), 'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs, 2*bs)]), diff --git a/F2LLM/run.py b/F2LLM/run.py index e40b707..7fda26a 100644 --- a/F2LLM/run.py +++ b/F2LLM/run.py @@ -62,6 +62,8 @@ def collate_fn(batch_raw): train_datasets, valid_datasets = [], [] for f in sorted(os.listdir(args.train_data_path)): + if not f.endswith('.parquet'): # 只处理parquet文件 + continue dataset_name = f.split('.parquet')[0] dataset = load_dataset("parquet", data_files=os.path.join(args.train_data_path, f), cache_dir=args.cache_dir)['train'] dataset = dataset.add_column("dataset_name", [dataset_name]*len(dataset)) @@ -71,6 +73,16 @@ def collate_fn(batch_raw): tokenizer = AutoTokenizer.from_pretrained(args.model_path) +# 确保有pad_token +if tokenizer.pad_token is None: + if tokenizer.eos_token is not None: + tokenizer.pad_token = tokenizer.eos_token + else: + # 对于没有eos_token的模型,添加特殊token + tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + # 需要调整模型embedding大小 + model.lm.resize_token_embeddings(len(tokenizer)) + train_loaders = { name: DataLoader(ds, shuffle=True, batch_size=args.train_batch_size, collate_fn=collate_fn) for name, ds in train_datasets @@ -134,7 +146,10 @@ def __iter__(self): num_warmup_steps=args.warmup_steps, num_training_steps=args.train_steps) -AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size +# 检查是否使用DeepSpeed +if accelerator.state.deepspeed_plugin is not None: + accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size + model.lm, optimizer, lr_scheduler = accelerator.prepare( model.lm, optimizer, lr_scheduler ) diff --git a/F2LLM/tokenize_data.py b/F2LLM/tokenize_data.py new file mode 100644 index 0000000..444f49b --- /dev/null +++ b/F2LLM/tokenize_data.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +""" +通用的数据分词脚本,支持多种decoder-only模型 +使用方法: python tokenize_data.py --model_path <模型路径> --max_seq_length <最大序列长度> +""" + +import argparse +import os +from multiprocessing import Pool +import numpy as np +import pandas as pd +from transformers import AutoTokenizer +from tqdm.auto import tqdm + + +def parse_args(): + parser = argparse.ArgumentParser(description='Tokenize data for various decoder-only models') + parser.add_argument('--model_path', type=str, required=True, + help='Path to the model or model name on HuggingFace') + parser.add_argument('--max_seq_length', type=int, default=1023, + help='Maximum sequence length for tokenization') + parser.add_argument('--data_dir', type=str, default='training_data', + help='Directory containing training data') + parser.add_argument('--output_dir', type=str, default='training_data', + help='Directory to save tokenized data') + parser.add_argument('--num_processes', type=int, default=8, + help='Number of processes for parallel processing') + return parser.parse_args() + + +def create_tokenizer(model_path, max_seq_length): + """创建并配置分词器""" + print(f"Loading tokenizer from {model_path}...") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + padding_side='right' + ) + + # 确保分词器有pad_token + if tokenizer.pad_token is None: + if tokenizer.eos_token is not None: + tokenizer.pad_token = tokenizer.eos_token + print(f"Set pad_token to eos_token: {tokenizer.pad_token}") + else: + # 添加新的pad_token + tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + print(f"Added new pad_token: {tokenizer.pad_token}") + + print(f"Tokenizer loaded: {tokenizer.__class__.__name__}") + print(f"Vocab size: {tokenizer.vocab_size}") + print(f"EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})") + print(f"PAD token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})") + + return tokenizer, max_seq_length + + +def process_sent(sentence, tokenizer=None, max_seq_length=None): + """处理单个句子,添加eos token并截断""" + if tokenizer is None: + raise ValueError("Tokenizer is required") + + # 分词,不添加特殊token,因为我们手动添加eos + tokenizer_outputs = tokenizer( + sentence, + max_length=max_seq_length, + truncation=True, + add_special_tokens=False + ) + + # 添加eos token + input_ids = tokenizer_outputs.input_ids + [tokenizer.eos_token_id] + + return np.array(input_ids) + + +def process_sent_batch(s, tokenizer=None, max_seq_length=None): + """批量处理句子""" + return s.apply(lambda x: process_sent(x, tokenizer, max_seq_length)) + + +def parallelize(data, func, num_of_processes=8, **kwargs): + """并行处理数据""" + indices = np.array_split(data.index, num_of_processes) + data_split = [data.iloc[idx] for idx in indices] + + with Pool(num_of_processes) as pool: + # 使用starmap传递多个参数 + data = pd.concat(pool.starmap(func, [(d, ) + tuple(kwargs.values()) for d in data_split])) + return data + + +def main(): + args = parse_args() + + # 创建输出目录 + os.makedirs(args.output_dir, exist_ok=True) + + # 创建分词器 + tokenizer, max_seq_length = create_tokenizer(args.model_path, args.max_seq_length) + + # 获取模型名称用于输出目录 + model_name = os.path.basename(args.model_path.rstrip('/')) + output_subdir = os.path.join(args.output_dir, f"data_tokenized_{model_name}") + os.makedirs(output_subdir, exist_ok=True) + + print(f"Processing data from {args.data_dir}...") + print(f"Output directory: {output_subdir}") + + # 处理所有数据集 + for ds_name in tqdm(sorted(os.listdir(args.data_dir)), desc="Processing datasets"): + if not ds_name.endswith('.parquet'): + continue + + print(f"\nProcessing {ds_name}...") + + # 读取数据 + df = pd.read_parquet(os.path.join(args.data_dir, ds_name)) + + # 处理查询 + print("Processing queries...") + df['query_input_ids'] = parallelize( + df['query'], + process_sent_batch, + args.num_processes, + tokenizer=tokenizer, + max_seq_length=max_seq_length + ) + + # 确定负样本数量 + num_neg = 24 if 'negative_2' in df.columns else 1 + print(f"Number of negative samples: {num_neg}") + + # 收集所有passage和负样本 + all_passages = df['passage'].to_list() + for i in range(1, num_neg + 1): + if f'negative_{i}' in df.columns: + all_passages += df[f'negative_{i}'].to_list() + + # 去重 + all_passages = list(set(all_passages)) + + # 创建临时DataFrame处理passage + df_tmp = pd.DataFrame({'text': all_passages}) + print(f"Processing {len(all_passages)} unique passages...") + + df_tmp['input_ids'] = parallelize( + df_tmp['text'], + process_sent_batch, + args.num_processes, + tokenizer=tokenizer, + max_seq_length=max_seq_length + ) + + # 设置索引以便映射 + df_tmp = df_tmp.set_index('text') + + # 映射passage的input_ids + print("Mapping passages...") + df['passage_input_ids'] = df['passage'].map(df_tmp['input_ids']) + + # 映射负样本的input_ids + for i in range(1, num_neg + 1): + neg_col = f'negative_{i}' + neg_input_col = f'negative_{i}_input_ids' + if neg_col in df.columns: + df[neg_input_col] = df[neg_col].map(df_tmp['input_ids']) + + # 保存结果 + output_path = os.path.join(output_subdir, ds_name) + df.to_parquet(output_path, index=False) + print(f"Saved to {output_path}") + + # 打印统计信息 + print(f"Dataset size: {len(df)}") + print(f"Query avg length: {df['query_input_ids'].apply(len).mean():.1f}") + print(f"Passage avg length: {df['passage_input_ids'].apply(len).mean():.1f}") + + print(f"\nAll datasets processed successfully!") + print(f"Tokenized data saved to: {output_subdir}") + + +if __name__ == "__main__": + main() diff --git a/F2LLM/utils.py b/F2LLM/utils.py index b167d3c..c7797af 100644 --- a/F2LLM/utils.py +++ b/F2LLM/utils.py @@ -91,17 +91,31 @@ def validate(args, accelerator, model, valid_loader_dict, criterion, completed_s with torch.no_grad(): outputs = model.forward(batch) loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator) - loss_hard_ls.append(accelerator.gather(loss_hard).float()) + loss_hard = accelerator.gather(loss_hard).float() + # 确保loss_hard是至少一维的张量 + if loss_hard.dim() == 0: + loss_hard = loss_hard.unsqueeze(0) + loss_hard_ls.append(loss_hard) if dataset_name in RETRIEVAL_DATASETS: loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator) - loss_ls.append(accelerator.gather(loss).float()) + loss = accelerator.gather(loss).float() + # 确保loss是至少一维的张量 + if loss.dim() == 0: + loss = loss.unsqueeze(0) + loss_ls.append(loss) accelerator.wait_for_everyone() - loss_hard_ls = torch.cat(loss_hard_ls) - eval_log_dict[f'{dataset_name}/valid_loss_hard'] = loss_hard_ls.mean() - if dataset_name in RETRIEVAL_DATASETS: + if loss_hard_ls: + loss_hard_ls = torch.cat(loss_hard_ls) + eval_log_dict[f'{dataset_name}/valid_loss_hard'] = loss_hard_ls.mean() + else: + eval_log_dict[f'{dataset_name}/valid_loss_hard'] = torch.tensor(0.0) + + if dataset_name in RETRIEVAL_DATASETS and loss_ls: loss_ls = torch.cat(loss_ls) eval_log_dict[f"{dataset_name}/valid_loss_in_batch"] = loss_ls.mean() + elif dataset_name in RETRIEVAL_DATASETS: + eval_log_dict[f"{dataset_name}/valid_loss_in_batch"] = torch.tensor(0.0) eval_log_dict['Avg/retrieval/valid_loss_in_batch'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_in_batch')]).mean() eval_log_dict['Avg/retrieval/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_hard')]).mean()