Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions F2LLM/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
**/__pycache__/
7 changes: 7 additions & 0 deletions F2LLM/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ In this repo we provide a streamlined and efficient script for training embeddin
- Modify model path, data path, and other arguments in `configs/config.json`.
- Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json`.

#### For More Decoder-Only Models
- Setup environment following `requirements.txt`. We note that transformers>=4.51.0 is required for training decoder-only models.
- Download data and backbone models from Hugging Face.
- Run `tokenize_data_universal.py` with model path, tokenized data_dir to tokenize the downloaded data
- Modify model path, train_data_path, and other arguments in `configs/config_universal.json`.
- Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config_universal.json`.

Note: we recommend setting `num_processes` to 1 in `configs/accelerate_config.yaml` and launch the training code once to generate cache for training data before starting the actual training.

For multi-node training, run on the main node:
Expand Down
32 changes: 24 additions & 8 deletions F2LLM/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,59 @@

@dataclass
class Args:
model_path: str = None
experiment_id: str = None

model_path: str
experiment_id: str
model_config: dict = None
tokenizer_config: dict = None

# save dir
output_dir: str
tb_dir: str
cache_dir: str
output_dir: str = None
tb_dir: str = None
cache_dir: str = None

# training arguments
train_data_path: str
train_data_path: str = None
train_batch_size: int = 8
max_seq_length: int = 2048
learning_rate: float = 1e-4
min_lr: float = 1e-6
weight_decay: float = 1e-2
warmup_steps: int = 100

# embedding-related settings
num_hard_neg: int = 7

# train steps take precedence over epochs, set to -1 to disable
train_steps: int = -1
train_epochs: int = 5
log_interval: int = 20
checkpointing_steps: int = 100
validation_steps: int = 100

# just placeholder, for logging purpose
num_processes: int=0
num_processes: int = 0

def dict(self):
return asdict(self)


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str)
parser.add_argument("--config", type=str, required=True,
help="Path to configuration file")
arg = parser.parse_args()

with open(arg.config) as f:
config = json.load(f)

args = Args(**config)

# 确保model_path正确设置
if not args.model_path and args.model_config:
args.model_path = args.model_config["model_path"]

args.output_dir = f"{args.output_dir}/{args.experiment_id}"
args.tb_dir = f"{args.tb_dir}/{args.experiment_id}"

return args
32 changes: 32 additions & 0 deletions F2LLM/configs/config_universal.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"model_config": {
"model_type": "qwen",
"model_path": "models/qwen3-4b",
"model_params": {
"torch_dtype": "bfloat16",
"attn_implementation": "flash_attention_2",
"trust_remote_code": true
}
},
"tokenizer_config": {
"add_special_tokens": false,
"padding_side": "right",
"pad_token": null
},
"experiment_id": "4b+lr.8e-6+bs.16x32+context.1024+2epochs",
"train_data_path": "training_data/data_tokenized",
"output_dir": "output",
"tb_dir": "output/tb",
"cache_dir": "cache",
"train_batch_size": 16,
"checkpointing_steps": 5000,
"validation_steps": 5000,
"max_seq_length": 1024,
"learning_rate": 8e-6,
"min_lr": 1e-7,
"weight_decay": 0.01,
"warmup_steps": 500,
"train_epochs": 2,
"log_interval": 100,
"num_hard_neg": 7
}
14 changes: 9 additions & 5 deletions F2LLM/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from transformers import AutoModel, AutoTokenizer
from model_factory import ModelFactory


class F2LLM:
Expand All @@ -11,10 +11,15 @@ def __init__(self,

self.args = args
self.dtype = torch.bfloat16
self.device = None # set after accelerator.prepare
self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2')
self.device = None # set after accelerator.prepare

# Use model factory to create adapter
self.adapter = ModelFactory.create_adapter(model_path, max_seq_length, args)

# Load model and tokenizer
self.lm = self.adapter.load_model()
self.lm.config.use_cache = False
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.tokenizer = self.adapter.load_tokenizer()
self.max_seq_length = max_seq_length

def set_device(self):
Expand All @@ -34,4 +39,3 @@ def forward(self, batch):
'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs, 2*bs)]),
'negative_passage_features': None if num_hard_neg == 0 else torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(2*bs, len(batch['seq_lens']))]).view(bs, num_hard_neg, -1)
}

66 changes: 66 additions & 0 deletions F2LLM/model_adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import torch
from transformers import AutoModel, AutoTokenizer
import os
import json
from abc import ABC, abstractmethod


class BaseModelAdapter(ABC):
"""Base model adapter interface"""

def __init__(self, model_path, max_seq_length=512, args=None):
self.model_path = model_path
self.max_seq_length = max_seq_length
self.args = args
self.dtype = torch.bfloat16
self.device = None

@abstractmethod
def load_model(self):
"""Load model"""
pass

@abstractmethod
def load_tokenizer(self):
"""Load tokenizer"""
pass

def get_model_config(self):
"""Get model configuration"""
config_path = os.path.join(self.model_path, 'config.json')
if os.path.exists(config_path):
with open(config_path) as f:
return json.load(f)
return {}


class QwenAdapter(BaseModelAdapter):
"""Qwen series model adapter (Qwen, Qwen2, Qwen3)"""

def load_model(self):
return AutoModel.from_pretrained(
self.model_path,
trust_remote_code=True,
torch_dtype=self.dtype,
attn_implementation='flash_attention_2'
)

def load_tokenizer(self):
return AutoTokenizer.from_pretrained(self.model_path)


class LlamaAdapter(BaseModelAdapter):
"""Llama series model adapter (Llama-2, Llama-3, CodeLlama)"""

def load_model(self):
return AutoModel.from_pretrained(
self.model_path,
torch_dtype=self.dtype,
attn_implementation='flash_attention_2'
)

def load_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
110 changes: 110 additions & 0 deletions F2LLM/model_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import os
import json
from typing import Dict, Type
from model_adapters import BaseModelAdapter, QwenAdapter, LlamaAdapter


class ModelFactory:
"""Model factory for creating adapters based on model type"""

# Mapping of model types to adapters
MODEL_ADAPTERS: Dict[str, Type[BaseModelAdapter]] = {
'qwen': QwenAdapter,
'qwen2': QwenAdapter,
'qwen3': QwenAdapter,
'llama': LlamaAdapter,
}

@classmethod
def create_adapter(cls, model_path: str, max_seq_length: int = 512, args=None) -> BaseModelAdapter:
"""Create adapter based on model path and type"""
model_type = cls.detect_model_type(model_path)
adapter_class = cls.MODEL_ADAPTERS.get(model_type)

if not adapter_class:
# Use LlamaAdapter as fallback for unknown model types
print(f"Warning: Unknown model type '{model_type}', using LlamaAdapter as fallback")
adapter_class = LlamaAdapter

return adapter_class(model_path, max_seq_length, args)

@classmethod
def detect_model_type(cls, model_path: str) -> str:
"""Detect model type"""
# Method 1: Detect via config file
config_path = os.path.join(model_path, 'config.json')
if os.path.exists(config_path):
try:
with open(config_path) as f:
config = json.load(f)
model_type = config.get('model_type', '').lower()
if model_type:
return model_type
except Exception:
pass

# Method 2: Infer from path name
path_lower = model_path.lower()
model_type_mappings = {
'qwen': ['qwen', 'qwen2', 'qwen3'],
'llama': ['llama', 'llama-2', 'llama-3', 'meta-llama', 'codellama'],
}

for model_type, keywords in model_type_mappings.items():
for keyword in keywords:
if keyword in path_lower:
return model_type

# Method 3: Detect via folder structure
if os.path.exists(os.path.join(model_path, 'tokenizer_config.json')):
try:
tokenizer_config_path = os.path.join(model_path, 'tokenizer_config.json')
with open(tokenizer_config_path) as f:
tokenizer_config = json.load(f)
tokenizer_class = tokenizer_config.get('tokenizer_class', '').lower()

if 'qwen' in tokenizer_class:
return 'qwen'
elif 'llama' in tokenizer_class:
return 'llama'
except Exception:
pass

return 'unknown'

@classmethod
def list_supported_models(cls) -> list:
"""Return list of supported model types"""
return list(cls.MODEL_ADAPTERS.keys())

@classmethod
def get_model_info(cls, model_path: str) -> dict:
"""Get model information"""
model_type = cls.detect_model_type(model_path)
adapter_class = cls.MODEL_ADAPTERS.get(model_type)

info = {
'model_path': model_path,
'detected_type': model_type,
'adapter_class': adapter_class.__name__ if adapter_class else 'Unknown',
'is_supported': model_type in cls.MODEL_ADAPTERS
}

# Try to get model configuration info
config_path = os.path.join(model_path, 'config.json')
if os.path.exists(config_path):
try:
with open(config_path) as f:
config = json.load(f)
info.update({
'model_name': config.get('_name_or_path', 'Unknown'),
'vocab_size': config.get('vocab_size', 0),
'hidden_size': config.get('hidden_size', 0),
'num_layers': config.get('num_hidden_layers', 0),
'num_attention_heads': config.get('num_attention_heads', 0),
'max_position_embeddings': config.get('max_position_embeddings', 0)
})
except Exception as e:
info['config_error'] = str(e)

return info
4 changes: 3 additions & 1 deletion F2LLM/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from model import F2LLM
from model_factory import ModelFactory

os.environ["TOKENIZERS_PARALLELISM"] = "false"

Expand Down Expand Up @@ -69,7 +70,8 @@ def collate_fn(batch_raw):
train_datasets.append((dataset_name, dataset['train']))
valid_datasets.append((dataset_name, dataset['test']))

tokenizer = AutoTokenizer.from_pretrained(args.model_path)
adapter = ModelFactory.create_adapter(args.model_path)
tokenizer = adapter.load_tokenizer()

train_loaders = {
name: DataLoader(ds, shuffle=True, batch_size=args.train_batch_size, collate_fn=collate_fn)
Expand Down
Loading