From 86cf9376d1f43456690e4c28d67f7e7d985c5789 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=AF=E8=BF=9B?= Date: Fri, 28 Nov 2025 17:15:38 +0800 Subject: [PATCH 1/2] lora-support --- ...00\346\261\202\346\226\207\346\241\243.md" | 165 ++++++++++++++++++ F2LLM/arguments.py | 12 ++ F2LLM/model.py | 28 ++- F2LLM/requirements.txt | 1 + F2LLM/run.py | 18 +- F2LLM/utils.py | 75 ++++++-- 6 files changed, 280 insertions(+), 19 deletions(-) create mode 100644 "F2LLM/LoRA\346\224\257\346\214\201\351\234\200\346\261\202\346\226\207\346\241\243.md" diff --git "a/F2LLM/LoRA\346\224\257\346\214\201\351\234\200\346\261\202\346\226\207\346\241\243.md" "b/F2LLM/LoRA\346\224\257\346\214\201\351\234\200\346\261\202\346\226\207\346\241\243.md" new file mode 100644 index 0000000..f90c0a7 --- /dev/null +++ "b/F2LLM/LoRA\346\224\257\346\214\201\351\234\200\346\261\202\346\226\207\346\241\243.md" @@ -0,0 +1,165 @@ +# LoRA支持需求文档 + +## 1. 需求背景 + +在当前的CodeFuse-Embeddings项目中,F2LLM模块支持将Decoder-only LLMs转换为Embedding模型,主要通过全模型微调或转换的方式实现。为了提高训练效率并降低计算成本,我们引入了LoRA (Low-Rank Adaptation) PEFT (Parameter-Efficient Fine-Tuning) 方法的支持,使用户能够通过最小参数更新来适配基础模型。 + +## 2. 需求目标 + +- 实现LoRA微调方法的支持,提高训练效率 +- 减少模型训练时的内存使用和计算成本 +- 保持模型性能的同时显著减少可训练参数数量 +- 提供灵活的LoRA配置选项,满足不同场景需求 + +## 3. 功能实现 + +### 3.1 核心实现 + +1. **LoRA集成** + - 使用Hugging Face的PEFT库实现LoRA功能 + - 支持在指定模型层应用LoRA适配器 + - 实现LoRA参数的配置化管理 + +2. 
**模型支持** + - 支持Qwen系列模型的LoRA微调 + - 可扩展支持其他Decoder-only架构的LLMs + +### 3.2 配置参数 + +LoRA功能通过以下参数进行配置: + +| 参数名 | 类型 | 默认值 | 说明 | +|--------|------|--------|------| +| `use_lora` | bool | false | 是否启用LoRA | +| `lora_r` | int | 8 | LoRA矩阵的秩 | +| `lora_alpha` | int | 32 | LoRA的缩放因子 | +| `lora_dropout` | float | 0.1 | LoRA层的Dropout率 | +| `lora_target_modules` | list | ["q_proj", "v_proj"] | 应用LoRA的模块列表 | + +### 3.3 配置文件示例 + +```json +{ + "model_path": "models/qwen3-0.6b", + "experiment_id": "0.6b_lora_test", + "train_data_path": "training_data/data_tokenized_qwen", + "output_dir": "output", + "tb_dir": "output/tb", + "cache_dir": "cache", + "train_batch_size": 16, + "checkpointing_steps": 5000, + "validation_steps": 5000, + "max_seq_length": 1024, + "learning_rate": 8e-6, + "min_lr": 1e-7, + "weight_decay": 0.01, + "warmup_steps": 500, + "train_epochs": 2, + "log_interval": 100, + "num_hard_neg": 7, + "use_lora": true, + "lora_r": 8, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"] +} +``` + +### 3.4 代码实现 + +#### 3.4.1 模型初始化 + +在`F2LLM/model.py`中,通过以下代码集成LoRA支持: + +```python +# 检查是否启用LoRA +if args and getattr(args, 'use_lora', False): + peft_config = LoraConfig( + task_type=TaskType.FEATURE_EXTRACTION, + inference_mode=False, + r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=args.lora_target_modules + ) + self.lm = get_peft_model(self.lm, peft_config) + print("LoRA enabled") + self.lm.print_trainable_parameters() +``` + +#### 3.4.2 模型保存 + +在`F2LLM/utils.py`中,实现了针对LoRA模型的特殊保存逻辑: + +```python +# Handle LoRA model saving +if getattr(args, 'use_lora', False): + # For LoRA models, we only save the adapter weights + unwrapped_model.save_pretrained( + output_dir, + is_main_process=accelerator.is_main_process, + save_function=accelerator.save + ) +else: + # For full fine-tuning, save the entire model + unwrapped_model.save_pretrained( + output_dir, + 
is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model.lm), + ) +``` + +### 3.5 使用方法 + +1. 创建LoRA配置文件,设置`use_lora: true`及其他相关参数 +2. 运行训练脚本: + ```bash + cd F2LLM + python run.py --config configs/config_lora.json + ``` +3. 训练完成后,LoRA适配器权重将保存在输出目录中 + +### 3.6 模型加载 + +对于LoRA模型,可以使用PEFT库加载: + +```python +from transformers import AutoModel, AutoTokenizer +from peft import PeftModel + +base_model = AutoModel.from_pretrained('base_model_path') +model = PeftModel.from_pretrained(base_model, 'output/{experiment_id}') +tokenizer = AutoTokenizer.from_pretrained('base_model_path') +``` + +## 4. 优势与效果 + +### 4.1 训练效率提升 + +- 显著减少可训练参数数量(通常减少99%以上) +- 降低内存使用,支持在资源受限设备上训练更大模型 +- 缩短训练时间 + +### 4.2 性能保持 + +- 在保持模型性能的同时实现参数高效微调 +- 支持与全参数微调相当的模型质量 + +### 4.3 灵活性 + +- 可配置的LoRA参数满足不同场景需求 +- 支持指定不同的模型层应用LoRA + +## 5. 注意事项 + +1. LoRA适配器需要与基础模型配合使用,单独加载适配器权重无法工作 +2. 不同的`lora_r`值会影响模型性能和训练效率的平衡 +3. `lora_target_modules`需要根据具体模型架构进行调整 + +## 6. 后续优化方向 + +1. 支持更多类型的PEFT方法(如AdaLoRA、IA³等) +2. 提供自动化的LoRA参数搜索功能 +3. 增加对更多模型架构的支持 +4. 
优化LoRA在推理阶段的性能 diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py index b967c8f..4447393 100644 --- a/F2LLM/arguments.py +++ b/F2LLM/arguments.py @@ -29,6 +29,18 @@ class Args: validation_steps: int = 100 # just placeholder, for logging purpose num_processes: int=0 + + # LoRA arguments + use_lora: bool = False + lora_r: int = 8 + lora_alpha: int = 32 + lora_dropout: float = 0.1 + lora_target_modules: list = None + + def __post_init__(self): + # Set default LoRA target modules if not provided + if self.lora_target_modules is None: + self.lora_target_modules = ["q_proj", "v_proj"] def dict(self): return asdict(self) diff --git a/F2LLM/model.py b/F2LLM/model.py index d33ade7..66a6687 100644 --- a/F2LLM/model.py +++ b/F2LLM/model.py @@ -1,5 +1,6 @@ import torch from transformers import AutoModel, AutoTokenizer +from peft import get_peft_model, LoraConfig, TaskType class F2LLM: @@ -12,11 +13,36 @@ def __init__(self, self.args = args self.dtype = torch.bfloat16 self.device = None # set after accelerator.prepare - self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2') + + # Check if CUDA is available and set the attention implementation accordingly + # Only use flash_attention_2 if CUDA is available and flash_attn is installed + attn_implementation = None + if torch.cuda.is_available(): + try: + import flash_attn + attn_implementation = 'flash_attention_2' + except ImportError: + attn_implementation = 'eager' # or 'sdpa' if available + + self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation=attn_implementation) self.lm.config.use_cache = False self.tokenizer = AutoTokenizer.from_pretrained(model_path) self.max_seq_length = max_seq_length + # Add LoRA support + if args and getattr(args, 'use_lora', False): + peft_config = LoraConfig( + task_type=TaskType.FEATURE_EXTRACTION, + inference_mode=False, + r=args.lora_r, + 
lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=args.lora_target_modules + ) + self.lm = get_peft_model(self.lm, peft_config) + print("LoRA enabled") + self.lm.print_trainable_parameters() + def set_device(self): self.device = self.lm.device diff --git a/F2LLM/requirements.txt b/F2LLM/requirements.txt index 82fb447..f69a0b8 100644 --- a/F2LLM/requirements.txt +++ b/F2LLM/requirements.txt @@ -5,3 +5,4 @@ flash-attn torch transformers tensorboard +peft>=0.4.0 diff --git a/F2LLM/run.py b/F2LLM/run.py index e40b707..46279c0 100644 --- a/F2LLM/run.py +++ b/F2LLM/run.py @@ -120,7 +120,11 @@ def __iter__(self): accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************") model = F2LLM(args.model_path, args.max_seq_length, args=args) -model.lm.gradient_checkpointing_enable() + +# Only enable gradient checkpointing if CUDA is available +if torch.cuda.is_available(): + model.lm.gradient_checkpointing_enable() + # set seed again to make sure that different models share the same seed set_seed(0) @@ -134,7 +138,10 @@ def __iter__(self): num_warmup_steps=args.warmup_steps, num_training_steps=args.train_steps) -AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size +# Check if deepspeed plugin is available before accessing its config +if hasattr(AcceleratorState(), 'deepspeed_plugin') and AcceleratorState().deepspeed_plugin is not None: + AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size + model.lm, optimizer, lr_scheduler = accelerator.prepare( model.lm, optimizer, lr_scheduler ) @@ -148,6 +155,11 @@ def __iter__(self): args.train_steps = len(train_dataloader) * args.train_epochs accelerator.print(f"******************************** Training step after prepare: {args.train_steps} ********************************") +# Fix: Use the length of the first dataset or 
a default value if no datasets +train_datasets_dict = dict(train_datasets) +first_dataset_name = next(iter(train_datasets_dict)) if train_datasets_dict else None +dataset = train_datasets_dict[first_dataset_name] if first_dataset_name else None +num_train_samples = len(dataset) if dataset is not None else 0 accelerate_train(args, accelerator, model, train_dataloader, valid_loaders, - optimizer, lr_scheduler, len(dataset)) \ No newline at end of file + optimizer, lr_scheduler, num_train_samples) \ No newline at end of file diff --git a/F2LLM/utils.py b/F2LLM/utils.py index b167d3c..d4690c2 100644 --- a/F2LLM/utils.py +++ b/F2LLM/utils.py @@ -22,12 +22,23 @@ def save_checkpoint(args, accelerator, model, output_dir, lr_scheduler): if accelerator.is_main_process: model.tokenizer.save_pretrained(output_dir) unwrapped_model = accelerator.unwrap_model(model.lm) - unwrapped_model.save_pretrained( - output_dir, - is_main_process=accelerator.is_main_process, - save_function=accelerator.save, - state_dict=accelerator.get_state_dict(model.lm), # this is required for zero 3 - ) + + # Handle LoRA model saving + if getattr(args, 'use_lora', False): + # For LoRA models, we only save the adapter weights + unwrapped_model.save_pretrained( + output_dir, + is_main_process=accelerator.is_main_process, + save_function=accelerator.save + ) + else: + # For full fine-tuning, save the entire model + unwrapped_model.save_pretrained( + output_dir, + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model.lm), # this is required for zero 3 + ) accelerator.wait_for_everyone() @@ -65,7 +76,7 @@ def hard_loss( ): if hard_neg_embeddings is None: - return 0.0 + return torch.tensor(0.0, device=query_embeddings.device) bs = query_embeddings.size(0) a_norm = F.normalize(query_embeddings, p=2, dim=-1) @@ -87,21 +98,55 @@ def validate(args, accelerator, model, valid_loader_dict, criterion, completed_s eval_log_dict = {} for 
dataset_name, valid_dataloader in valid_loader_dict.items(): loss_ls, loss_hard_ls = [], [] + for batch in valid_dataloader: with torch.no_grad(): outputs = model.forward(batch) - loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator) - loss_hard_ls.append(accelerator.gather(loss_hard).float()) + + query_emb = outputs['query_passage_features'].squeeze(1) + passage_emb = outputs['passage_passage_features'].squeeze(1) + neg_emb = outputs['negative_passage_features'] + + loss_hard = hard_loss(query_emb, passage_emb, neg_emb, criterion, accelerator) + if accelerator.num_processes > 1: + # When using multiple processes, we need to gather the loss from all processes + gathered_loss_hard = accelerator.gather(loss_hard).float() + # Ensure gathered_loss_hard is at least 1D + if gathered_loss_hard.dim() == 0: + # Scalar tensor, convert to 1D tensor with one element + gathered_loss_hard = gathered_loss_hard.unsqueeze(0) + loss_hard_ls.append(gathered_loss_hard) + else: + # When using a single process, just append the loss directly + # Ensure loss_hard is at least 1D + if loss_hard.dim() == 0: + loss_hard = loss_hard.unsqueeze(0) + loss_hard_ls.append(loss_hard.float()) if dataset_name in RETRIEVAL_DATASETS: - loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator) - loss_ls.append(accelerator.gather(loss).float()) + loss = inbatch_loss(query_emb, passage_emb, criterion, accelerator) + if accelerator.num_processes > 1: + # When using multiple processes, we need to gather the loss from all processes + gathered_loss = accelerator.gather(loss).float() + # Ensure gathered_loss is at least 1D + if gathered_loss.dim() == 0: + # Scalar tensor, convert to 1D tensor with one element + gathered_loss = gathered_loss.unsqueeze(0) + loss_ls.append(gathered_loss) + else: + # When using a 
single process, just append the loss directly + # Ensure loss is at least 1D + if loss.dim() == 0: + loss = loss.unsqueeze(0) + loss_ls.append(loss.float()) accelerator.wait_for_everyone() - loss_hard_ls = torch.cat(loss_hard_ls) - eval_log_dict[f'{dataset_name}/valid_loss_hard'] = loss_hard_ls.mean() + if loss_hard_ls: # Check if the list is not empty + loss_hard_ls = torch.cat(loss_hard_ls) + eval_log_dict[f'{dataset_name}/valid_loss_hard'] = loss_hard_ls.mean() if dataset_name in RETRIEVAL_DATASETS: - loss_ls = torch.cat(loss_ls) - eval_log_dict[f"{dataset_name}/valid_loss_in_batch"] = loss_ls.mean() + if loss_ls: # Check if the list is not empty + loss_ls = torch.cat(loss_ls) + eval_log_dict[f"{dataset_name}/valid_loss_in_batch"] = loss_ls.mean() eval_log_dict['Avg/retrieval/valid_loss_in_batch'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_in_batch')]).mean() eval_log_dict['Avg/retrieval/valid_loss_hard'] = torch.tensor([v for k, v in eval_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('valid_loss_hard')]).mean() From eed1c44195d1ea47c0c77d4c36eaee602a198f18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=AF=E8=BF=9B?= Date: Fri, 28 Nov 2025 17:23:59 +0800 Subject: [PATCH 2/2] lora-support --- F2LLM/configs/config_lora.json | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 F2LLM/configs/config_lora.json diff --git a/F2LLM/configs/config_lora.json b/F2LLM/configs/config_lora.json new file mode 100644 index 0000000..df7045d --- /dev/null +++ b/F2LLM/configs/config_lora.json @@ -0,0 +1,24 @@ +{ + "model_path": "models/qwen3-0.6b", + "experiment_id": "0.6b_lora_test", + "train_data_path": "training_data/data_tokenized_qwen", + "output_dir": "output", + "tb_dir": "output/tb", + "cache_dir": "cache", + "train_batch_size": 16, + "checkpointing_steps": 5000, + "validation_steps": 5000, + "max_seq_length": 1024, + "learning_rate": 8e-6, + 
"min_lr": 1e-7, + "weight_decay": 0.01, + "warmup_steps": 500, + "train_epochs": 2, + "log_interval": 100, + "num_hard_neg": 7, + "use_lora": true, + "lora_r": 8, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"] +}