116 changes: 116 additions & 0 deletions F2LLM/MRL需求文档.md
@@ -0,0 +1,116 @@
# Matryoshka Representation Learning (MRL) Support — Requirements Document

## 1. Background

To give the CodeFuse-Embeddings models greater flexibility across downstream applications and compute budgets, we add support for Matryoshka Representation Learning (MRL). MRL is a "Russian nesting doll" training method that lets a single model produce high-quality embeddings at multiple dimensions (e.g., 64, 128, 256, 512, or 1024) at inference time, offering substantial flexibility for different application scenarios.

## 2. Objective

Add MRL support to the CodeFuse-Embeddings model so that it learns representations at multiple embedding dimensions during training, and at inference time can serve whichever dimension best balances quality and computational cost.

## 3. Technical Design

### 3.1 MRL Core Concepts

Matryoshka Representation Learning trains a model to produce embeddings at multiple dimensions, where each lower-dimensional embedding is a subset of the higher-dimensional ones. The method works as follows:

1. During training, the model learns representations for all target dimensions simultaneously
2. Projection layers map the full-dimensional embedding to each target dimension
3. The training loss combines weighted losses from all target dimensions (see the sketch below)
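
Conceptually, the objective is a weighted sum of a contrastive loss computed at each target dimension. The following is a minimal sketch assuming an in-batch InfoNCE loss; the function name `mrl_multi_dim_loss` and the `temperature` value are illustrative, not the exact F2LLM API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def mrl_multi_dim_loss(query_emb, passage_emb, projections: nn.ModuleDict,
                       mrl_dims, mrl_loss_weights, temperature=0.05):
    """Weighted sum of an in-batch contrastive loss over all MRL target dims."""
    total = query_emb.new_zeros(())
    for dim, weight in zip(mrl_dims, mrl_loss_weights):
        # Project the full-dimensional embeddings down to the target dimension
        q = F.normalize(projections[str(dim)](query_emb), dim=-1)
        p = F.normalize(projections[str(dim)](passage_emb), dim=-1)
        # In-batch negatives: query i should match passage i
        logits = q @ p.T / temperature
        labels = torch.arange(q.size(0), device=q.device)
        total = total + weight * F.cross_entropy(logits, labels)
    return total
```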

### 3.2 Implementation Details

#### 3.2.1 Configuration Parameters

The following MRL options were added to `F2LLM/configs/config.json`:

- `mrl_enabled`: whether MRL is enabled (boolean, default `false`)
- `mrl_dims`: list of MRL target dimensions (array, default `[128, 256, 512, 1024]`)
- `mrl_loss_weights`: loss weight for each dimension (array, default `[1.0, 1.0, 1.0, 1.0]`)

#### 3.2.2 Model Changes

`F2LLM/model.py` implements the following MRL functionality:

1. MRL support in the `F2LLM` class:
   - an `mrl_enabled` flag controls whether MRL is active
   - a projection layer per target dimension (`mrl_projections`)
   - `get_mrl_embeddings` returns embeddings at one specific dimension
   - `get_all_mrl_embeddings` returns embeddings at every dimension

2. Changes to the `forward` method:
   - a new `target_dim` parameter selects the desired embedding dimension
   - the dimensions of the returned embeddings depend on whether MRL is enabled
   - when no dimension is specified, a dictionary of embeddings for all dimensions can be returned

#### 3.2.3 Training Changes

`F2LLM/utils.py` implements the following MRL functionality:

1. New MRL loss functions:
   - `mrl_inbatch_loss`: in-batch negative-sampling loss under MRL
   - `mrl_hard_loss`: hard-negative loss under MRL

2. Training-loop changes (see the sketch after this list):
   - each training batch randomly samples one target dimension
   - the loss is computed at the sampled dimension

3. Validation changes:
   - validation always evaluates at the full embedding dimension
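
A condensed sketch of that per-batch sampling scheme (the actual loop lives in `F2LLM/utils.py`; the `mrl_inbatch_loss(outputs, target_dim)` signature here is an assumption for illustration):

```python
import random

for batch in train_dataloader:
    outputs = model.forward(batch)              # per-dimension features when MRL is on
    target_dim = random.choice(args.mrl_dims)   # sample one dimension for this batch
    loss = mrl_inbatch_loss(outputs, target_dim)
    accelerator.backward(loss)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
```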

#### 3.2.4 Argument Definitions

`F2LLM/arguments.py` adds the corresponding MRL arguments:

- `mrl_enabled`: whether MRL is enabled
- `mrl_dims`: list of MRL target dimensions
- `mrl_loss_weights`: loss weight for each dimension

## 4. Usage

### 4.1 Enabling MRL Training

In the configuration file `F2LLM/configs/config.json`, set:

```json
{
"mrl_enabled": true,
"mrl_dims": [128, 256, 512, 1024],
"mrl_loss_weights": [1.0, 1.0, 1.0, 1.0]
}
```

### 4.2 Launching Training

```bash
cd F2LLM
python run.py --config configs/config.json
```

### 4.3 Choosing a Dimension at Inference Time

At inference time, pass the `target_dim` argument to obtain embeddings at a specific dimension:

```python
# Get embeddings at a specific dimension
outputs = model.forward(batch, target_dim=512)

# Get embeddings at every dimension (when MRL is enabled,
# the returned dict carries per-dimension feature keys)
outputs = model.forward(batch)
```
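
One caveat when comparing reduced-dimension embeddings: projection (or truncation) changes vector norms, so re-normalize before computing cosine similarity. A sketch, assuming the batched output of `forward` above (the singleton token axis and the `squeeze(1)` are inferred from the implementation):

```python
import torch.nn.functional as F

outputs = model.forward(batch, target_dim=256)
# Re-normalize after projection/truncation before cosine similarity
q = F.normalize(outputs['query_passage_features'].squeeze(1), dim=-1)
p = F.normalize(outputs['passage_passage_features'].squeeze(1), dim=-1)
scores = q @ p.T   # (bs, bs) cosine-similarity matrix
```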

## 5. Benefits

1. **Flexibility**: a single model generates embeddings at multiple dimensions, fitting different application scenarios
2. **Efficiency**: inference can use lower dimensions for speed or higher dimensions for better quality
3. **Resource optimization**: the embedding dimension can be chosen to match the compute budget and accuracy requirements
4. **Compatibility**: existing code keeps working; MRL is opt-in via configuration

## 6. Verification

After this change, the model can:

1. Train successfully with MRL enabled
2. Produce embeddings at different dimensions at inference time
3. Match the performance of a model trained without MRL
4. Offer a smooth quality/efficiency trade-off across dimensions (see the sketch below)
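
A hypothetical spot check for item 4, sweeping the configured dimensions on one held-out batch and using top-1 in-batch retrieval accuracy as an illustrative metric (`valid_batch` is assumed):

```python
import torch
import torch.nn.functional as F

for dim in [128, 256, 512, 1024]:
    outputs = model.forward(valid_batch, target_dim=dim)
    q = F.normalize(outputs['query_passage_features'].squeeze(1), dim=-1)
    p = F.normalize(outputs['passage_passage_features'].squeeze(1), dim=-1)
    top1 = (q @ p.T).argmax(dim=-1)   # index of the best-scoring passage per query
    acc = (top1 == torch.arange(len(top1), device=top1.device)).float().mean().item()
    print(f"dim={dim}: top-1 in-batch retrieval accuracy {acc:.3f}")
```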
17 changes: 17 additions & 0 deletions F2LLM/arguments.py
@@ -7,10 +7,12 @@ class Args:

model_path: str
experiment_id: str

# save dir
output_dir: str
tb_dir: str
cache_dir: str

# training arguments
train_data_path: str
train_batch_size: int = 8
@@ -19,14 +21,22 @@
min_lr: float = 1e-6
weight_decay: float = 1e-2
warmup_steps: int = 100

# embedding-related settings
num_hard_neg: int = 7

# MRL settings
mrl_enabled: bool = False
mrl_dims: list = None
mrl_loss_weights: list = None

# train steps take precedence over epochs, set to -1 to disable
train_steps: int = -1
train_epochs: int = 5
log_interval: int = 20
checkpointing_steps: int = 100
validation_steps: int = 100

# just placeholder, for logging purpose
num_processes: int=0

@@ -43,4 +53,11 @@ def parse_args():
args = Args(**config)
args.output_dir = f"{args.output_dir}/{args.experiment_id}"
args.tb_dir = f"{args.tb_dir}/{args.experiment_id}"

# Set default values for MRL parameters if not specified
if args.mrl_dims is None:
args.mrl_dims = [128, 256, 512, 1024]
if args.mrl_loss_weights is None:
args.mrl_loss_weights = [1.0] * len(args.mrl_dims)

return args
14 changes: 9 additions & 5 deletions F2LLM/configs/config.json
@@ -1,19 +1,23 @@
{
"model_path": "models/qwen3-4b",
"experiment_id": "4b+lr.8e-6+bs.16x32+context.1024+2epochs",
"model_path": "models/qwen3-0.6b",
"experiment_id": "0.6b+lr.8e-6+bs.16x32+context.1024+2epochs",
"train_data_path": "training_data/data_tokenized_qwen",
"output_dir": "output",
"tb_dir": "output/tb",
"cache_dir": "cache",
"train_batch_size": 16,
"train_batch_size": 4,
"checkpointing_steps": 5000,
"validation_steps": 5000,
"max_seq_length": 1024,
"max_seq_length": 256,
"learning_rate": 8e-6,
"min_lr": 1e-7,
"weight_decay": 0.01,
"warmup_steps": 500,
"train_epochs": 2,
"log_interval": 100,
"num_hard_neg": 7
"num_hard_neg": 7,

"mrl_enabled": true,
"mrl_dims": [128, 256, 512, 1024],
"mrl_loss_weights": [1.0, 1.0, 1.0, 1.0]
}
100 changes: 92 additions & 8 deletions F2LLM/model.py
@@ -1,5 +1,6 @@
import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn


class F2LLM:
@@ -10,17 +11,71 @@ def __init__(self,
):

self.args = args
self.dtype = torch.bfloat16
self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
self.device = None # set after accelerator.prepare
self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2')

# Check if CUDA is available and set the attention implementation accordingly
# Only use flash_attention_2 if CUDA is available and flash_attn is installed
attn_implementation = None
if torch.cuda.is_available():
try:
import flash_attn
attn_implementation = 'flash_attention_2'
except ImportError:
attn_implementation = 'eager' # or 'sdpa' if available

self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation=attn_implementation)
self.lm.config.use_cache = False
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.max_seq_length = max_seq_length

# MRL support
self.mrl_enabled = getattr(args, 'mrl_enabled', False)
if self.mrl_enabled:
self.mrl_dims = args.mrl_dims
# Create projection layers for each target dimension
hidden_size = self.lm.config.hidden_size
self.mrl_projections = nn.ModuleDict({
str(dim): nn.Linear(hidden_size, dim) for dim in self.mrl_dims
})
# Move projection layers to the same device as the model
self.mrl_projections.to(self.lm.device)

def set_device(self):
self.device = self.lm.device
# Move MRL projections to the correct device if they exist
if self.mrl_enabled:
self.mrl_projections.to(self.device)

def forward(self, batch):
def get_mrl_embeddings(self, full_embeddings, target_dim):
"""Get embeddings for a specific target dimension"""
if not self.mrl_enabled:
return full_embeddings

if target_dim == self.lm.config.hidden_size:
# No projection needed for full dimension
return full_embeddings
elif str(target_dim) in self.mrl_projections:
# Use projection layer
return self.mrl_projections[str(target_dim)](full_embeddings)
else:
# Fallback: truncate along the feature (last) axis
return full_embeddings[..., :target_dim]

def get_all_mrl_embeddings(self, full_embeddings):
"""Get embeddings for all MRL dimensions"""
if not self.mrl_enabled:
return {str(self.lm.config.hidden_size): full_embeddings}

embeddings_dict = {}
# Full dimension
embeddings_dict[str(self.lm.config.hidden_size)] = full_embeddings
# Projected dimensions
for dim in self.mrl_dims:
embeddings_dict[str(dim)] = self.get_mrl_embeddings(full_embeddings, dim)
return embeddings_dict

def forward(self, batch, target_dim=None):
bs = batch['bs']
num_hard_neg = int((len(batch['input_ids']) - 2*bs) / bs)

@@ -29,9 +84,38 @@ def forward(self, batch):
)

passage_features_all_tokens = outputs.last_hidden_state
return {
'query_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs)]),
'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs, 2*bs)]),
'negative_passage_features': None if num_hard_neg == 0 else torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(2*bs, len(batch['seq_lens']))]).view(bs, num_hard_neg, -1)
}
# Extract last-token embeddings (position seq_lens[i]-1) for every sequence in the batch
full_embeddings = torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(len(batch['seq_lens']))])

if target_dim is not None:
# Return embeddings for specific dimension
embeddings = self.get_mrl_embeddings(full_embeddings, target_dim)
elif self.mrl_enabled:
# Return embeddings for all dimensions
embeddings_dict = self.get_all_mrl_embeddings(full_embeddings)
else:
# Return full dimension embeddings only
embeddings = full_embeddings
embeddings_dict = None

# Split embeddings back to original format
if self.mrl_enabled and target_dim is None:
# Return dict with embeddings for all dimensions
result = {}
for dim, embs in embeddings_dict.items():
result[f'query_passage_features_{dim}'] = embs[:bs]
result[f'passage_passage_features_{dim}'] = embs[bs:2*bs]
result[f'negative_passage_features_{dim}'] = None if num_hard_neg == 0 else embs[2*bs:].view(bs, num_hard_neg, -1)
return result
else:
# Return single-dimension embeddings; `embeddings` is always set on this
# path (a target_dim was given, or MRL is disabled)
query_embs = embeddings[:bs]
passage_embs = embeddings[bs:2*bs]
negative_embs = None if num_hard_neg == 0 else embeddings[2*bs:].view(bs, num_hard_neg, -1)

return {
'query_passage_features': query_embs,
'passage_passage_features': passage_embs,
'negative_passage_features': negative_embs
}

18 changes: 15 additions & 3 deletions F2LLM/run.py
@@ -120,7 +120,11 @@ def __iter__(self):

accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************")
model = F2LLM(args.model_path, args.max_seq_length, args=args)
model.lm.gradient_checkpointing_enable()

# Only enable gradient checkpointing if CUDA is available
if torch.cuda.is_available():
model.lm.gradient_checkpointing_enable()

# set seed again to make sure that different models share the same seed
set_seed(0)

@@ -134,7 +138,10 @@ def __iter__(self):
num_warmup_steps=args.warmup_steps,
num_training_steps=args.train_steps)

AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
# Check if deepspeed plugin is available before accessing its config
if hasattr(AcceleratorState(), 'deepspeed_plugin') and AcceleratorState().deepspeed_plugin is not None:
AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size

model.lm, optimizer, lr_scheduler = accelerator.prepare(
model.lm, optimizer, lr_scheduler
)
@@ -148,6 +155,11 @@ def __iter__(self):
args.train_steps = len(train_dataloader) * args.train_epochs
accelerator.print(f"******************************** Training step after prepare: {args.train_steps} ********************************")

# Use the length of the first training dataset for logging; fall back to 0 if there are none
train_datasets_dict = dict(train_datasets)
first_dataset_name = next(iter(train_datasets_dict)) if train_datasets_dict else None
dataset = train_datasets_dict[first_dataset_name] if first_dataset_name else None
num_train_samples = len(dataset) if dataset is not None else 0

accelerate_train(args, accelerator, model, train_dataloader, valid_loaders,
optimizer, lr_scheduler, len(dataset))
optimizer, lr_scheduler, num_train_samples)