diff --git a/.gitignore b/.gitignore index 5f29895be914..47510447a842 100644 --- a/.gitignore +++ b/.gitignore @@ -140,4 +140,5 @@ autogen/ #fp8 ops/csrc/fp8/deep_gemm/include/cutlass ops/csrc/fp8/deep_gemm/include/cute -.ccls-cache \ No newline at end of file +.ccls-cache +llm/log diff --git a/docs/zh/llm/benchmark/rl/README.md b/docs/zh/llm/benchmark/rl/README.md new file mode 120000 index 000000000000..c8ff8b971399 --- /dev/null +++ b/docs/zh/llm/benchmark/rl/README.md @@ -0,0 +1 @@ +../../../../../llm/benchmark/rl/README.md \ No newline at end of file diff --git a/llm/config/llama/dislora_argument.json b/llm/config/llama/dislora_argument.json new file mode 100644 index 000000000000..35dc6ed7bcd8 --- /dev/null +++ b/llm/config/llama/dislora_argument.json @@ -0,0 +1,37 @@ +{ + "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct", + "dataset_name_or_path": "/home/bjh/bjh/Dislora/cs_5_lite", + "output_dir": "./checkpoints/dislora_ckpts_3", + "dislora": true, + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 5, + "num_train_epochs": 1, + "learning_rate": 2e-05, + "lr_scheduler_type": "linear", + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "no", + "save_strategy": "steps", + "save_steps": 500, + "src_length": 256, + "max_length": 512, + "bf16": true, + "do_train": true, + "do_eval": false, + "disable_tqdm": false, + "load_best_model_at_end": false, + "eval_with_do_generation": false, + "recompute": false, + "save_total_limit": 5, + "fp16_opt_level": "O2", + "sharding": "stage3", + "zero_padding": false, + "use_flash_attention": false, + "unified_checkpoint": false, + "dislora_rank": 8, + "dislora_dropout": 0.05, + "target_modules": [".*q_proj.*", ".*v_proj.*", ".*k_proj.*", ".*o_proj.*"], + "s_tsd": 8, + "ortho_lambda": 1.0, + "prefer_small_sigma": true +} \ No newline at end of file diff --git a/llm/config/qwen/dislora_argument.json b/llm/config/qwen/dislora_argument.json new file mode 100644 index 000000000000..f1383adaa163 --- /dev/null +++ b/llm/config/qwen/dislora_argument.json @@ -0,0 +1,36 @@ +{ + "model_name_or_path": "Qwen/Qwen2.5-7B-Instruct", + "dataset_name_or_path": "/home/bjh/bjh/Dislora/cs_5_lite", + "output_dir": "./checkpoints/dislora_ckpts", + "dislora": true, + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 1, + "num_train_epochs": 1, + "learning_rate": 2e-05, + "lr_scheduler_type": "linear", + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "no", + "save_strategy": "steps", + "save_steps": 500, + "src_length": 256, + "max_length": 512, + "bf16": true, + "do_train": true, + "do_eval": false, + "disable_tqdm": false, + "load_best_model_at_end": false, + "eval_with_do_generation": false, + "recompute": false, + "save_total_limit": 5, + "fp16_opt_level": "O2", + "sharding": "stage3", + "zero_padding": false, + "use_flash_attention": false, + "unified_checkpoint": false, + "dislora_rank": 8, + "dislora_dropout": 0.05, + "s_tsd": 8, + "ortho_lambda": 1.0, + "prefer_small_sigma": true +} \ No newline at end of file diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 31427a516f2d..37afa9e4528f 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -31,6 +31,8 @@ ) from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL from paddlenlp.peft import ( + DisLoRAConfig, + DisLoRAModel, LoKrConfig, LoKrModel, LoRAConfig, @@ -311,6 +313,15 @@ def neft_post_hook(module, input, output): tokenizer.pad_token_id = tokenizer.eos_token_id train_ds, dev_ds, test_ds = create_dataset(data_args, 
training_args) + + train_dataset_size = None + if train_ds is not None and model_args.dislora: + train_dataset_size = get_dataset_size(train_ds) + if train_dataset_size is not None: + logger.info(f"Original training dataset size: {train_dataset_size}") + else: + logger.warning("Unable to determine training dataset size for dynamic dash_flag calculation") + # TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later. if training_args.resume_from_checkpoint is not None and data_args.lazy: logger.info( @@ -377,7 +388,9 @@ def neft_post_hook(module, input, output): if eval_zero_padding and test_ds is not None: test_ds = intoken_dataset(test_ds, tokenizer=tokenizer, max_length=data_args.max_length) - model = create_peft_model(model_args, reft_args, training_args, dtype, model_config, model, reft_layers) + model = create_peft_model( + model_args, reft_args, training_args, dtype, model_config, model, reft_layers, train_dataset_size + ) def compute_metrics_do_generation(eval_preds): rouge1 = Rouge1() @@ -441,6 +454,10 @@ def compute_metrics_do_generation(eval_preds): return_attention_mask=not model_args.flash_mask, pad_to_multiple_of=data_args.pad_to_multiple_of, ) + + if model_args.dislora and hasattr(model_args, "ortho_lambda"): + training_args.dislora_ortho_lambda = model_args.ortho_lambda + trainer = SFTTrainer( model=model, args=training_args, @@ -531,7 +548,9 @@ def save_to_aistudio(model_args, training_args, trainer): ) -def create_peft_model(model_args, reft_args, training_args, dtype, model_config, model, reft_layers): +def create_peft_model( + model_args, reft_args, training_args, dtype, model_config, model, reft_layers, train_dataset_size +): if model_args.prefix_tuning: if training_args.pipeline_parallel_degree > 1: raise NotImplementedError("Prefix tuning is not implemented for pipeline parallelism.") @@ -606,6 +625,53 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config, else: model = LoKrModel.from_pretrained(model=model, lokr_path=model_args.lokr_path) + if model_args.dislora: + # Calculate dynamic dash_flag based on training configuration + if train_dataset_size is not None and training_args.do_train: + # Calculate warmup steps: len(train_data) * num_epochs // (batch_size * gradient_accumulation_steps * 3) + effective_batch_size = ( + training_args.per_device_train_batch_size + * training_args.gradient_accumulation_steps + * training_args.dataset_world_size # Consider data parallel + ) + calculated_dash_flag = (train_dataset_size * training_args.num_train_epochs) // (effective_batch_size * 3) + + # Use calculated value if it's reasonable, otherwise fall back to model_args + if calculated_dash_flag > 0: + dash_flag = calculated_dash_flag + logger.info( + f"Calculated dynamic dash_flag: {dash_flag} based on dataset size: {train_dataset_size}, " + f"epochs: {training_args.num_train_epochs}, effective batch size: {effective_batch_size}" + ) + else: + dash_flag = model_args.dash_flag + logger.warning( + f"Calculated dash_flag was {calculated_dash_flag}, using model_args.dash_flag: {dash_flag}" + ) + else: + dash_flag = getattr(model_args, "dash_flag", 50) + if train_dataset_size is None: + logger.info( + f"Unable to calculate dynamic dash_flag (dataset size unknown), using configured dash_flag: {dash_flag}" + ) + else: + logger.info(f"Not in training mode, using configured dash_flag: {dash_flag}") + if model_args.dislora_path is None: + dislora_config = DisLoRAConfig( + target_modules=model_args.target_modules + if 
model_args.target_modules + else get_lora_target_modules(model), + r=model_args.dislora_rank, + dislora_alpha=1.5 * model_args.dislora_rank, + dislora_dropout=model_args.dislora_dropout, + dtype=dtype, + base_model_name_or_path=model_args.model_name_or_path, + s_tsd=model_args.s_tsd, + dash_flag=dash_flag, # Use calculated dash_flag + ortho_lambda=model_args.ortho_lambda, + ) + model = DisLoRAModel(model, dislora_config) + if model_args.reft: intervention_dtype = dtype intervention_params = { @@ -745,5 +811,24 @@ def create_dataset(data_args, training_args): return train_ds, dev_ds, test_ds +def get_dataset_size(dataset): + """Get the size of a dataset, handling both lazy and regular datasets""" + if dataset is None: + return None + + try: + if hasattr(dataset, "__len__"): + return len(dataset) + elif hasattr(dataset, "_length"): + return dataset._length + else: + # For lazy datasets, we might need to iterate once to count + logger.warning("Unable to determine dataset size directly for lazy loading dataset") + return None + except Exception as e: + logger.warning(f"Error getting dataset size: {e}") + return None + + if __name__ == "__main__": main() diff --git a/llm/tools/merge_dislora_params.py b/llm/tools/merge_dislora_params.py new file mode 100644 index 000000000000..f393b3d4971a --- /dev/null +++ b/llm/tools/merge_dislora_params.py @@ -0,0 +1,290 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import paddle + +from paddlenlp.peft import DisLoRAConfig, DisLoRAModel +from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from paddlenlp.utils.env import CONFIG_NAME + + +def parse_arguments(): + """解析命令行参数""" + parser = argparse.ArgumentParser() + parser.add_argument("--model_name_or_path", default=None, help="The directory of pretrained model.") + parser.add_argument("--dislora_path", default="", help="The directory of dislora parameters. Default to None") + parser.add_argument( + "--merge_dislora_model_path", + default="", + help="The directory of merged parameters. Default to None", + ) + parser.add_argument("--device", type=str, default="gpu", help="Device") + parser.add_argument( + "--low_gpu_mem", type=bool, default=True, help="Whether to use low gpu memory. Default to False" + ) + return parser.parse_args() + + +def weight_process(name, dislora_config, state_dict): + """ + Based on the DisLoRA algorithm for processing weight merging: + The final weight = W_prin + W_res + W_TSD + However, here we do not directly add the adapter to the base model; instead, we reconstruct the entire weight matrix. + Args: + name: Layer name (e.g. "model.layers.0.self_attn.q_proj") + dislora_config: DisLoRA configuration + state_dict: Model state dictionary + # Define the weight_process function to handle the DisLoRA weight merging. The parameters include the layer name, DisLoRA configuration, and the model state dictionary. 
+ """ + + weight_key = name + ".weight" + + if weight_key not in state_dict: + print(f"Warning: {weight_key} not found in state_dict") + return + + w_prin = state_dict[weight_key] + print(f"Processing layer: {name}") + print(f" W_prin shape: {w_prin.shape}") + + scaling = dislora_config.dislora_alpha / dislora_config.r + + final_weight = w_prin.clone() + + ur_key = name + ".Direc_Ur.weight" + sr_key = name + ".Direc_Sr" + vhr_key = name + ".Direc_Vhr.weight" + + w_res_added = False + + if all(key in state_dict for key in [ur_key, sr_key, vhr_key]): + + direc_ur = state_dict[ur_key] # [r, out_features] + direc_sr = state_dict[sr_key] # [r] + direc_vhr = state_dict[vhr_key] # [in_features, r] + + s_diag = paddle.diag(direc_sr) # [r, r] + + w_res = direc_vhr @ s_diag @ direc_ur * scaling # [in_features, out_features] + + if w_res.shape != w_prin.shape: + print(f" Error: W_res shape {w_res.shape} doesn't match W_prin shape {w_prin.shape}") + return + + final_weight += w_res + w_res_added = True + print(f" ✓ Added W_res with scaling factor: {scaling}") + else: + print(f" ⚠ W_res components not found for {name}") + + utsd_key = name + ".Direc_Utsd.weight" + stsd_key = name + ".Direc_Stsd" + vhtsd_key = name + ".Direc_Vhtsd.weight" + + w_tsd_added = False + if all(key in state_dict for key in [utsd_key, stsd_key, vhtsd_key]): + + direc_utsd = state_dict[utsd_key] # [s_tsd, out_features] + direc_stsd = state_dict[stsd_key] # [s_tsd] + direc_vhtsd = state_dict[vhtsd_key] # [in_features, s_tsd] + + if not paddle.all(direc_stsd == 0.0): + + s_diag_tsd = paddle.diag(direc_stsd) # [s_tsd, s_tsd] + + w_tsd = direc_vhtsd @ s_diag_tsd @ direc_utsd * scaling # [in_features, out_features] + + if w_tsd.shape != w_prin.shape: + print(f" Error: W_TSD shape {w_tsd.shape} doesn't match W_prin shape {w_prin.shape}") + return + + final_weight += w_tsd + w_tsd_added = True + print(f" ✓ Added W_TSD with scaling factor: {scaling}") + else: + print(f" ⚠ W_TSD parameters are uninitialized (all zeros) for {name}") + else: + print(f" ⚠ W_TSD components not found for {name}") + + state_dict[weight_key] = final_weight + + keys_to_remove = [] + for key in state_dict.keys(): + if key.startswith(name + ".Direc_") or key == name + ".step": + keys_to_remove.append(key) + + for key in keys_to_remove: + removed_param = state_dict.pop(key) + print(f" ✓ Removed DisLoRA parameter: {key} (shape: {removed_param.shape})") + + components = [] + if w_res_added: + components.append("W_res") + if w_tsd_added: + components.append("W_TSD") + + if components: + print(f" ✓ Successfully merged: W_prin + {' + '.join(components)}") + else: + print(" ✓ Kept original W_prin (no adaptations found)") + print() + + +def merge(): + + args = parse_arguments() + paddle.set_device(args.device) + + print("Loading DisLoRA configuration...") + dislora_config = DisLoRAConfig.from_pretrained(args.dislora_path) + if dislora_config.base_model_name_or_path is None: + if args.model_name_or_path is None: + raise ValueError("We can not find a valid model_name_or_path.") + else: + dislora_config.base_model_name_or_path = args.model_name_or_path + + print("Loading model configuration...") + if os.path.isfile(os.path.join(args.dislora_path, CONFIG_NAME)): + config = AutoConfig.from_pretrained(args.dislora_path) + elif args.model_name_or_path is not None: + config = AutoConfig.from_pretrained(args.model_name_or_path) + else: + raise ValueError( + f"We can not find config.json in dislora_path: {args.dislora_path} or find a valid model_name_or_path." 
+ ) + + config.dtype = dislora_config.dtype + + if ( + dislora_config.dtype == "bfloat16" + or ( + hasattr(config, "quantization_config") + and hasattr(config.quantization_config, "weight_quantize_algo") + and config.quantization_config.weight_quantize_algo in ["nf4", "fp4"] + ) + ) and args.device == "cpu": + raise ValueError("We can not apply bfloat16 or nf4/fp4 dislora merge on cpu.") + + print("Loading base model...") + model = AutoModelForCausalLM.from_pretrained( + dislora_config.base_model_name_or_path, + config=config, + low_cpu_mem_usage=args.low_gpu_mem, + ) + + print("Loading DisLoRA model...") + model = DisLoRAModel.from_pretrained(model=model, dislora_path=args.dislora_path, dislora_config=dislora_config) + + model.eval() + model_state_dict = model.model.state_dict() + + print(f"Total parameters in state_dict: {len(model_state_dict)}") + + step_keys = [key for key in model_state_dict.keys() if key.endswith(".step")] + if step_keys: + print(f"Found {len(step_keys)} step parameters in loaded model:") + for key in step_keys[:5]: + print(f" {key}") + if len(step_keys) > 5: + print(f" ... and {len(step_keys) - 5} more") + else: + print("No step parameters found in loaded model") + print() + + print("Identifying DisLoRA layers...") + dislora_name_set = set() + for key in model_state_dict.keys(): + if any( + dislora_param in key + for dislora_param in ["Direc_Ur", "Direc_Sr", "Direc_Vhr", "Direc_Utsd", "Direc_Stsd", "Direc_Vhtsd"] + ): + + for param_type in ["Direc_Ur", "Direc_Sr", "Direc_Vhr", "Direc_Utsd", "Direc_Stsd", "Direc_Vhtsd"]: + if f".{param_type}" in key: + layer_name = key.split(f".{param_type}")[0] + dislora_name_set.add(layer_name) + break + + dislora_name_list = sorted(list(dislora_name_set)) + + print(f"Found {len(dislora_name_list)} DisLoRA layers:") + for i, name in enumerate(dislora_name_list, 1): + print(f" {i:2d}. 
{name}") + print() + + print("Merging DisLoRA parameters...") + + for i, name in enumerate(dislora_name_list, 1): + print(f"[{i}/{len(dislora_name_list)}] Processing: {name}") + weight_process(name, dislora_config, model_state_dict) + + print("Cleaning up remaining step parameters...") + step_keys_to_remove = [key for key in model_state_dict.keys() if key.endswith(".step")] + for key in step_keys_to_remove: + removed_param = model_state_dict.pop(key) + print(f" ✓ Removed step parameter: {key} (shape: {removed_param.shape})") + + if step_keys_to_remove: + print(f"✓ Removed {len(step_keys_to_remove)} step parameters") + else: + print("✓ No step parameters found") + print() + + print("Verifying parameter cleanup...") + remaining_dislora_params = [] + remaining_step_params = [] + for key in model_state_dict.keys(): + if any( + dislora_param in key + for dislora_param in ["Direc_Ur", "Direc_Sr", "Direc_Vhr", "Direc_Utsd", "Direc_Stsd", "Direc_Vhtsd"] + ): + remaining_dislora_params.append(key) + if key.endswith(".step"): + remaining_step_params.append(key) + + if remaining_dislora_params: + print(f"Warning: {len(remaining_dislora_params)} DisLoRA parameters still remain:") + for param in remaining_dislora_params: + print(f" - {param}") + else: + print("✓ All DisLoRA parameters successfully removed") + + if remaining_step_params: + print(f"Warning: {len(remaining_step_params)} step parameters still remain:") + for param in remaining_step_params: + print(f" - {param}") + else: + print("✓ All step parameters successfully removed") + print() + + print("Saving merged model...") + os.makedirs(args.merge_dislora_model_path, exist_ok=True) + model.model.save_pretrained(args.merge_dislora_model_path, state_dict=model_state_dict) + + print("Saving tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(dislora_config.base_model_name_or_path) + tokenizer.save_pretrained(args.merge_dislora_model_path) + + print("=" * 80) + print("✓ DisLoRA merge completed successfully!") + print(f"✓ Merged model saved to: {args.merge_dislora_model_path}") + print(f"✓ Processed {len(dislora_name_list)} DisLoRA layers") + print("=" * 80) + + +if __name__ == "__main__": + merge() diff --git a/paddlenlp/peft/__init__.py b/paddlenlp/peft/__init__.py index 85c61ffc793b..331488e56dc0 100644 --- a/paddlenlp/peft/__init__.py +++ b/paddlenlp/peft/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .dislora import DisLoRAConfig, DisLoRALinear, DisLoRAModel from .lokr import LoKrConfig, LoKrModel from .lora import LoRAAutoConfig, LoRAAutoModel, LoRAConfig, LoRAModel from .prefix import PrefixConfig, PrefixModelForCausalLM diff --git a/paddlenlp/peft/dislora/__init__.py b/paddlenlp/peft/dislora/__init__.py new file mode 100644 index 000000000000..c1bdb6cd810e --- /dev/null +++ b/paddlenlp/peft/dislora/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .dislora_config import DisLoRAConfig +from .dislora_layer import DisLoRALinear +from .dislora_model import DisLoRAModel + +__all__ = ["DisLoRAConfig", "DisLoRAModel", "DisLoRALinear"] diff --git a/paddlenlp/peft/dislora/dislora_config.py b/paddlenlp/peft/dislora/dislora_config.py new file mode 100644 index 000000000000..b9ff8b1b36a7 --- /dev/null +++ b/paddlenlp/peft/dislora/dislora_config.py @@ -0,0 +1,160 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from dataclasses import asdict, dataclass, field +from typing import List, Optional, Union + +from ...utils.env import DISLORA_CONFIG_NAME + + +@dataclass +class DisLoRAConfig: + """ + This is the configuration class to store the configuration of a [`DisLoRAModel`]. + Args: + target_modules (`Union[List[str],str]`): The names of the modules to apply DisLoRA to. + trainable_modules (`List[str]`): The names of the modules to train when applying DisLoRA. + dislora_alpha (`float`): The alpha parameter for DisLoRA scaling. + merge_weights (`bool`): + Whether to merge the weights of the DisLoRA layers with the base transfoisrmer model in `eval` mode. + """ + + base_model_name_or_path: Optional[str] = field( + default=None, metadata={"help": "The name of the base model to use."} + ) + r: int = field(default=8, metadata={"help": "DisLoRA attention dimension"}) + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with DisLoRA." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + }, + ) + trainable_modules: Optional[List[str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to train when applying with DisLoRA." 
+ "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + }, + ) + dislora_alpha: int = field(default=12, metadata={"help": "DisLoRA alpha"}) + dislora_dropout: float = field(default=0.0, metadata={"help": "DisLoRA dropout"}) + merge_weights: bool = field( + default=False, metadata={"help": "Merge weights of the original model and the DisLoRA model"} + ) + trainable_bias: Optional[str] = field( + default=None, metadata={"help": "Define trainable bias parameters for the DisLoRA model."} + ) + + tensor_parallel_degree: int = field(default=-1, metadata={"help": "1 for not use tensor parallel"}) + dtype: Optional[str] = field(default=None, metadata={"help": "The data type of tensor"}) + + dash_flag: int = field( # characteristic + default=50, + metadata={"help": "The number of preheating steps before introducing additional low-rank updates"}, + ) + + s_tsd: int = field( # characteristic + default=8, + metadata={"help": "The number of top-k singular vectors dynamically selected after preheating"}, + ) + + ortho_lambda: float = field( # characteristic + default=1, + metadata={"help": "The weight of orthogonal regularization loss"}, + ) + prefer_small_sigma: bool = field( + default=True, + metadata={"help": "Whether to prioritize the smallest singular value in the top-k selection process"}, + ) + + def __post_init__(self): + + if self.target_modules is None: + raise ValueError("The target_modules must be specified as a string or a list of strings.") + if self.r <= 0: + raise ValueError("The rank r of LoRA must be greater than 0.") + if self.dislora_alpha <= 0: + raise ValueError("dislora_alpha must be greater than 0") + if self.r < self.s_tsd: + raise ValueError("The rank r of LoRA must be larger than the number of top-k singular values.") + + @property + def scaling(self): + return self.dislora_alpha / self.r + + @property + def __dict__(self): + return asdict(self) + + def to_dict(self): + return self.__dict__ + + def save_pretrained(self, save_directory): + r""" + This method saves the configuration of your adapter model in a directory. + Args: + save_directory (`str`): + The directory where the configuration will be saved. + """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + output_dict = self.__dict__ + output_dict["scaling"] = self.scaling + output_path = os.path.join(save_directory, DISLORA_CONFIG_NAME) + + # save it + with open(output_path, "w") as writer: + writer.write(json.dumps(output_dict, indent=2, sort_keys=True)) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" + This method loads the configuration of your adapter model from a directory. + Args: + pretrained_model_name_or_path (`str`): + The directory or the hub-id where the configuration is saved. + **kwargs: + Additional keyword arguments passed along to the child class initialization. 
+ """ + if os.path.isfile(os.path.join(pretrained_model_name_or_path, DISLORA_CONFIG_NAME)): + config_file = os.path.join(pretrained_model_name_or_path, DISLORA_CONFIG_NAME) + else: + raise ValueError(f"Can't find dislora_config.json at '{pretrained_model_name_or_path}'") + + loaded_attributes = cls.from_json_file(config_file) + loaded_attributes.pop("scaling", None) + + merged_kwargs = {**loaded_attributes, **kwargs} + config = cls(**merged_kwargs) + + return config + + @classmethod + def from_json_file(cls, path_json_file): + r""" + Loads a configuration file from a json file. + Args: + path_json_file (`str`): + The path to the json file. + """ + with open(path_json_file, "r") as file: + json_object = json.load(file) + + return json_object diff --git a/paddlenlp/peft/dislora/dislora_layer.py b/paddlenlp/peft/dislora/dislora_layer.py new file mode 100644 index 000000000000..990d7629816f --- /dev/null +++ b/paddlenlp/peft/dislora/dislora_layer.py @@ -0,0 +1,312 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + + +import warnings +from typing import Union + +import paddle +import paddle.nn as nn + + +class DisLoRALinear(nn.Linear): + """ + Paddle implementation of Direct Low-Rank Adaptation (DisLoRA) layer. + DisLoRA decomposes W into backbone (W_prin) and task-specific (W_res) subspaces via SVD, + further identifying task-specific directions (W_TSD) for fine tuning. + """ + + def __init__( + self, + in_features: int, + out_features: int, + r: int = 8, + dislora_alpha: int = 8, + dislora_dropout: float = 0.0, + dash_flag: int = 50, + s_tsd: int = 8, + prefer_small_sigma: bool = True, + merge_weights: bool = False, + init_lora_weights: Union[bool, str] = True, + **kwargs + ): + + if r <= 0: + raise ValueError(f"`r` must be a positive integer, got {r}") + if s_tsd <= 0: + raise ValueError(f"`s_tsd` must be a positive integer, got {s_tsd}") + + nn.Linear.__init__(self, in_features, out_features, **kwargs) + + original_weight = self.weight.clone() + original_bias = self.bias.clone() if self.bias is not None else None + + self.base_dtype = original_weight.dtype + + delattr(self, "weight") + if hasattr(self, "bias") and self.bias is not None: + delattr(self, "bias") + + self.weight = self.create_parameter( + shape=[in_features, out_features], + default_initializer=nn.initializer.Assign(original_weight), + dtype=self.base_dtype, + attr=paddle.ParamAttr(trainable=False), + ) + + if original_bias is not None: + self.bias = self.create_parameter( + shape=[out_features], + default_initializer=nn.initializer.Assign(original_bias), + dtype=self.base_dtype, + attr=paddle.ParamAttr(trainable=True), + ) + else: + self.bias = None + + self.r = r + self.dislora_alpha = dislora_alpha + self.scaling = dislora_alpha / r + self.dislora_dropout = nn.Dropout(p=dislora_dropout) if dislora_dropout > 0.0 else nn.Identity() + self.dash_flag = dash_flag + self.s_tsd = s_tsd + self.prefer_small_sigma = prefer_small_sigma + self.merge_weights = merge_weights + self.init_lora_weights = init_lora_weights + + self._disable_adapters = False + self.merged = False + + self.register_buffer("step", paddle.to_tensor(0, dtype="int64")) + + self.U = None + self.S = None + self.Vh = None + + self.Direc_Ur = nn.Linear(r, out_features, bias_attr=False) + self.Direc_Sr = 
self.create_parameter( + shape=[r], default_initializer=nn.initializer.Constant(0.0), dtype=self.base_dtype + ) + self.Direc_Vhr = nn.Linear(in_features, r, bias_attr=False) + self.Direc_Ur.weight.stop_gradient = False + self.Direc_Sr.stop_gradient = False + self.Direc_Vhr.weight.stop_gradient = False + + self.Direc_Utsd = nn.Linear(s_tsd, out_features, bias_attr=False) + self.Direc_Stsd = self.create_parameter( + shape=[s_tsd], default_initializer=nn.initializer.Constant(0.0), dtype=self.base_dtype + ) + self.Direc_Vhtsd = nn.Linear(in_features, s_tsd, bias_attr=False) + + self.Direc_Utsd.weight.stop_gradient = True + self.Direc_Vhtsd.weight.stop_gradient = True + + self._align_dtypes() + + if init_lora_weights: + self._init_lora_weights() + + def _align_dtypes(self): + """Ensure that the data types of all parameters are consistent with those of the base layer.""" + target_dtype = self.base_dtype + + if self.Direc_Ur.weight.dtype != target_dtype: + self.Direc_Ur.weight.set_value(self.Direc_Ur.weight.astype(target_dtype)) + if self.Direc_Vhr.weight.dtype != target_dtype: + self.Direc_Vhr.weight.set_value(self.Direc_Vhr.weight.astype(target_dtype)) + if self.Direc_Utsd.weight.dtype != target_dtype: + self.Direc_Utsd.weight.set_value(self.Direc_Utsd.weight.astype(target_dtype)) + if self.Direc_Vhtsd.weight.dtype != target_dtype: + self.Direc_Vhtsd.weight.set_value(self.Direc_Vhtsd.weight.astype(target_dtype)) + if self.Direc_Sr.dtype != target_dtype: + self.Direc_Sr.set_value(self.Direc_Sr.astype(target_dtype)) + if self.Direc_Stsd.dtype != target_dtype: + self.Direc_Stsd.set_value(self.Direc_Stsd.astype(target_dtype)) + + def _init_lora_weights(self): + """ + Initialize LoRA weights using SVD + Decompose the original weight W into W_prin (frozen backbone) + W_res (trainable residual) + Note: The shape of the Linear weight in PaddlePaddle is [in_features, out_features] + """ + weight_float32 = self.weight.astype("float32") + + weight_transposed = weight_float32.T + + U, S, Vh = paddle.linalg.svd(weight_transposed, full_matrices=False) + + self.U = U.astype(self.base_dtype) + self.S = S.astype(self.base_dtype) + self.Vh = Vh.astype(self.base_dtype) + + if self.prefer_small_sigma: + _, indices = paddle.topk(S, self.r, largest=False) + else: + _, indices = paddle.topk(S, self.r, largest=True) + + self.Direc_Ur.weight.set_value(U[:, indices].T.astype(self.base_dtype)) + self.Direc_Sr.set_value(S[indices].astype(self.base_dtype)) + + self.Direc_Vhr.weight.set_value(Vh[indices, :].T.astype(self.base_dtype)) + self.Direc_Ur.weight.stop_gradient = False + self.Direc_Sr.stop_gradient = False + self.Direc_Vhr.weight.stop_gradient = False + self.Direc_Stsd.stop_gradient = False + + S_diag = paddle.diag(self.Direc_Sr) # [r, r] + W_res_T = self.Direc_Ur.weight.T @ S_diag @ self.Direc_Vhr.weight.T # [out_features, in_features] + W_res = W_res_T.T * self.scaling # [in_features, out_features] + + if W_res.shape != self.weight.shape: + raise ValueError(f"Expected W_res shape {self.weight.shape}, but got {W_res.shape}.") + + self.weight.set_value(self.weight - W_res.astype(self.base_dtype)) + self.weight.stop_gradient = True + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + """ + Forward propagation: W_prin @ x + W_res @ x + W_TSD @ x + - W_prin is calculated through the base_layer + - W_res is calculated through the trainable LoRA structure + - W_TSD is calculated through the frozen dynamic vector (after warmup) + """ + if self._disable_adapters: + if self.merged: + self.unmerge() + return 
super().forward(x) + + if self.merged: + return super().forward(x) + + result = super().forward(x) + + temp = self.dislora_dropout(x) + temp = self.Direc_Vhr(temp) + temp = temp * self.Direc_Sr + temp = self.Direc_Ur(temp) + result += temp * self.scaling + + if self.step < self.dash_flag: + pass + elif self.step == self.dash_flag: + self._initialize_dynamic_vectors() + else: + temp = self.dislora_dropout(x) + temp = self.Direc_Vhtsd(temp) + temp = temp * self.Direc_Stsd + temp = self.Direc_Utsd(temp) + result += temp * self.scaling + + if self.training: + with paddle.no_grad(): + self.step += 1 + + return result + + def _initialize_dynamic_vectors(self): + """ + After the warm-up steps, initialize the dynamic singular vector W_TSD. + Based on the current change of W_res, select the most important s_tsd directions. + """ + with paddle.no_grad(): + + S_diag = paddle.diag(self.Direc_Sr) # [r, r] + deltaW_T = self.Direc_Ur.weight.T @ S_diag @ self.Direc_Vhr.weight.T # [out_features, in_features] + + delta_sigma = paddle.diag(self.U.T @ deltaW_T @ self.Vh.T) + + top_indices = self.calculate_change_rate( + self.S, delta_sigma, self.s_tsd, largest=not self.prefer_small_sigma + ) + + self.Direc_Utsd.weight.set_value(self.U[:, top_indices].T.astype(self.base_dtype)) + self.Direc_Stsd.set_value(self.S[top_indices].astype(self.base_dtype)) + self.Direc_Vhtsd.weight.set_value(self.Vh[top_indices, :].T.astype(self.base_dtype)) + + self.Direc_Utsd.weight.stop_gradient = True + self.Direc_Vhtsd.weight.stop_gradient = True + + def calculate_change_rate(self, a: paddle.Tensor, b: paddle.Tensor, s: int, largest: bool = True) -> paddle.Tensor: + """ + Calculate the rate of change of singular values and + select the top-s index change_rate = |b| / (|a| + eps) + """ + with paddle.no_grad(): + + change_rate = paddle.abs(b) / (paddle.abs(a) + 1e-8) + + _, top_s_indices = paddle.topk(change_rate, s, largest=largest) + return top_s_indices + + def merge(self): + """ + Merge the trainable W_res into the base weights. + After merging: base_layer.weight = W_prin + W_res + Note: W_TSD remains frozen and does not participate in the merge. + """ + if self.merged: + warnings.warn("Already merged. Nothing to do.") + return + + if self.r > 0: + + delta_weight = self.get_delta_weight() + orig_weights = self.weight.clone() + orig_weights += delta_weight + self.weight.set_value(orig_weights) + + self.merged = True + + def unmerge(self): + """ + Remove the merging of W_res from the base weights. + After the merging is removed: base_layer.weight = W_prin + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + if self.r > 0: + delta_weight = self.get_delta_weight() + self.weight.set_value(self.weight - delta_weight) + + self.merged = False + + def get_delta_weight(self) -> paddle.Tensor: + """ + Calculate the trainable LoRA incremental weights + It consists of two parts: + 1. W_res = Ur @ diag(Sr) @ Vhr * scaling (transposed) + 2. 
W_tsd = Utsd @ diag(Stsd) @ Vhtsd * scaling (transposed) + Return the incremental weights with the shape of [in_features, out_features] + """ + + S_diag_r = paddle.diag(self.Direc_Sr) # [r, r] + delta_weight_T = self.Direc_Ur.weight.T @ S_diag_r @ self.Direc_Vhr.weight.T # [out_features, in_features] + delta_weight = delta_weight_T.T * self.scaling # [in_features, out_features] + + if not paddle.all(self.Direc_Stsd == 0.0): + S_diag_tsd = paddle.diag(self.Direc_Stsd) # [s_tsd, s_tsd] + delta_weight_tsd_T = ( + self.Direc_Utsd.weight.T @ S_diag_tsd @ self.Direc_Vhtsd.weight.T + ) # [out_features, in_features] + delta_weight += delta_weight_tsd_T.T * self.scaling # [in_features, out_features] + + return delta_weight.astype(self.base_dtype) + + def enable_adapters(self): + """Enable the adapter""" + self._disable_adapters = False + + def disable_adapters(self): + """Disable adapter""" + self._disable_adapters = True + + def __repr__(self) -> str: + rep = super().__repr__() + return rep diff --git a/paddlenlp/peft/dislora/dislora_model.py b/paddlenlp/peft/dislora/dislora_model.py new file mode 100644 index 000000000000..9be94c0e90bc --- /dev/null +++ b/paddlenlp/peft/dislora/dislora_model.py @@ -0,0 +1,446 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
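The warm-up handoff in `forward()` / `_initialize_dynamic_vectors()` above selects the `s_tsd` singular directions whose singular values changed the most (or least, when `prefer_small_sigma` is set, since the layer passes `largest=not prefer_small_sigma`) relative to the original spectrum. A small numeric sketch of `calculate_change_rate()` with made-up values, not part of the diff:

```python
import paddle

# Made-up singular values; change_rate = |delta_sigma| / (|sigma| + eps), as in calculate_change_rate().
sigma = paddle.to_tensor([4.0, 2.0, 1.0, 0.5])          # original spectrum S from the initial SVD
delta_sigma = paddle.to_tensor([0.1, 0.4, 0.3, 0.25])   # stand-in for diag(U^T @ deltaW^T @ V^T) after warm-up

change_rate = paddle.abs(delta_sigma) / (paddle.abs(sigma) + 1e-8)
# change_rate ~= [0.025, 0.2, 0.3, 0.5]; with largest=True the two most-changed
# directions are indices 3 and 2, and those rows/columns of U, S, Vh seed W_TSD.
_, top_idx = paddle.topk(change_rate, k=2, largest=True)
print(top_idx.numpy())  # [3 2]
```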
+ +import copy +import os +import re +from collections import OrderedDict +from typing import Dict, Union + +import numpy as np +import paddle +import paddle.nn as nn +from paddle.distributed.fleet.meta_parallel import PipelineLayer + +from paddlenlp.transformers import AutoConfig, PretrainedModel +from paddlenlp.transformers.model_utils import _add_variant, dtype_guard +from paddlenlp.utils.log import logger + +from ...utils.env import DISLORA_WEIGHTS_NAME +from .dislora_config import DisLoRAConfig + + +def get_dislora_layers(): + from .dislora_layer import DisLoRALinear + + return { + "DisLoRALinear": DisLoRALinear, + } + + +dislora_layers = get_dislora_layers() +DisLoRALinear = dislora_layers["DisLoRALinear"] +AVAILABLE_LAYERS = [ + DisLoRALinear, +] + + +class DisLoRAModel(nn.Layer): + restore_layer_map: Dict[nn.Layer, nn.Layer] = { + DisLoRALinear: nn.Linear, + } + + def __init__(self, model, dislora_config: DisLoRAConfig) -> None: + super().__init__() + self.model_config = AutoConfig.from_pretrained(dislora_config.base_model_name_or_path) + self.quantized = False + self.dislora_config = dislora_config + self.dislora_split_mapping = {} + if self.dislora_config.dtype is None: + self.dislora_config.dtype = paddle.get_default_dtype() + with dtype_guard(self.dislora_config.dtype): + self.model = self.get_dislora_model(model, dislora_config) + self.is_pipelinemodel = False + if issubclass(type(self.model), PipelineLayer): + raise NotImplementedError("dislora don't support pipeline parallel now") + if dislora_config.tensor_parallel_degree > 1: + self.dislora_config.tensor_parallel_degree = -1 + self.model.config.tensor_parallel_degree = -1 + raise NotImplementedError("dislora don't support tensor parallel now") + # currently tensor_parallel_degree should all be set to -1. + self.forward = self.model.forward + + logger.info("Mark only dislora and trainable_module as trainable.") + self.mark_only_dislora_as_trainable() + + @classmethod + def from_pretrained(cls, model, dislora_path, **kwargs): + dislora_config = kwargs.pop("dislora_config", None) + # init dislora config & dislora model + if not isinstance(dislora_config, DisLoRAConfig): + dislora_config = DisLoRAConfig.from_pretrained(dislora_path) + # define a new variable to conserve original lora_config.tensor_parallel_degree value which will update while initializing lora model + dislora_config_tensor_parallel_degree = dislora_config.tensor_parallel_degree + dislora_model = cls(model, dislora_config) + + # define dislora weight name + dislora_weight_name = DISLORA_WEIGHTS_NAME + + # load and set dislora weight parameter + dislora_weight_path = os.path.join(dislora_path, dislora_weight_name) + if os.path.exists(dislora_weight_path): + # load dislora weight parameter + dislora_state_dict = paddle.load(dislora_weight_path, return_numpy=True) + logger.info(f"Loading the DisLoRA weights from {dislora_weight_path}") + + if ( + dislora_config_tensor_parallel_degree > 1 + and dislora_config_tensor_parallel_degree != model.config.tensor_parallel_degree + ): + raise NotImplementedError( + f"{dislora_config_tensor_parallel_degree} is not equal to {model.config.tensor_parallel_degree}. Please merge DisLoRA weights first." 
+ ) + # set dislora state dict + dislora_model.set_state_dict(dislora_state_dict) + else: + logger.error(f"DisLoRA weights not found under {dislora_path}, creating DisLoRA weights from scratch") + + return dislora_model + + def set_state_dict(self, state_dict): + import warnings + + warnings.filterwarnings( + action="ignore", message=".*Skip loading for.*", category=Warning, lineno=0, append=False + ) + self.model.set_state_dict(state_dict) + logger.info("Load dislora weight successfully") + + def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False, **kwargs): + logger.info("save dislora pretrained") + save_model_config = kwargs.get("save_model_config", True) + + variant = kwargs.get("variant", None) + is_main_process = kwargs.get("is_main_process", paddle.distributed.get_rank() == 0) + + assert not os.path.isfile( + save_directory + ), f"Saving directory ({save_directory}) should be a directory, not a file" + os.makedirs(save_directory, exist_ok=True) + + dislora_config_to_save = DisLoRAConfig(**self.dislora_config.to_dict()) + trainable_state_dict = self.get_trainable_state_dict() + + # save dislora weight + dislora_weight_name = _add_variant(DISLORA_WEIGHTS_NAME, variant) + weight_filename = os.path.join(save_directory, dislora_weight_name) + paddle.save(trainable_state_dict, weight_filename) + + # save dislora config + if is_main_process: + dislora_config_to_save.save_pretrained(save_directory) + if save_model_config: + model_config_to_save = copy.deepcopy(self.model.config) + if merge_tensor_parallel: + model_config_to_save.tensor_parallel_degree = -1 + model_config_to_save.save_pretrained(save_directory) + + def _find_and_replace_module(self, model, module_name, dislora_config): + + if any(dislora_keyword in module_name.lower() for dislora_keyword in ["dislora", "direc_"]): + logger.debug(f"Skipping {module_name} - appears to be a DisLoRA submodule") + return + + try: + parent_module = model + attribute_chain = module_name.split(".") + for name in attribute_chain[:-1]: + parent_module = getattr(parent_module, name) + module = getattr(parent_module, attribute_chain[-1]) + except AttributeError as e: + logger.error(f"Cannot access module {module_name}: {e}") + raise ValueError(f"Cannot access target module {module_name}: {e}") + + if isinstance(module, nn.Linear): + logger.debug(f"Converting {module_name} from nn.Linear to DisLoRALinear") + + try: + dislora_module = DisLoRALinear( + in_features=module.weight.shape[0], + out_features=module.weight.shape[1], + r=dislora_config.r, + dislora_alpha=dislora_config.dislora_alpha, + dislora_dropout=dislora_config.dislora_dropout, + dash_flag=dislora_config.dash_flag, + s_tsd=dislora_config.s_tsd, + prefer_small_sigma=dislora_config.prefer_small_sigma, + merge_weights=dislora_config.merge_weights, + bias_attr=False if module.bias is None else None, + init_lora_weights=False, + ) + + dislora_module.weight.set_value(module.weight) + if module.bias is not None: + dislora_module.bias.set_value(module.bias) + + dislora_module._init_lora_weights() + + setattr(parent_module, attribute_chain[-1], dislora_module) + logger.debug(f"Successfully replaced {module_name}") + + except Exception as e: + logger.error(f"Failed to create DisLoRALinear for {module_name}: {e}") + raise ValueError(f"Failed to create DisLoRALinear for {module_name}: {e}") + + elif isinstance(module, DisLoRALinear): + logger.debug(f"Module {module_name} is already a DisLoRALinear, skipping") + + else: + + module_type = type(module).__name__ + if any(keyword 
in module_name.lower() for keyword in ["dislora_dropout", "direc_"]): + logger.debug(f"Skipping DisLoRA submodule {module_name} ({module_type})") + return + else: + + error_msg = f"Target module {module_name} is {module_type}, not nn.Linear. DisLoRA can only replace nn.Linear modules." + logger.error(f"Cannot replace {module_name}: expected nn.Linear, got {module_type}") + raise ValueError(error_msg) + + def _find_and_restore_module(self, module_name): + parent_module = self.model + attribute_chain = module_name.split(".") + for name in attribute_chain[:-1]: + parent_module = getattr(parent_module, name) + module = getattr(parent_module, attribute_chain[-1]) + original_model_class = self.restore_layer_map[module.__class__] + original_module = original_model_class(in_features=module.weight.shape[0], out_features=module.weight.shape[1]) + original_module.weight = module.weight + + if isinstance(module, DisLoRALinear): + if not module.merged: + complete_weight = module.weight + module.get_delta_weight() + original_module.weight.set_value(complete_weight) + else: + original_module.weight.set_value(module.weight) + else: + original_module.weight.set_value(module.weight) + + if module.bias is not None: + original_module.bias.set_value(module.bias) + + setattr(parent_module, attribute_chain[-1], original_module) + + def get_trainable_state_dict(self): + """ + Obtain the required state dictionary to be saved, including: + 1. Trainable parameters (stop_gradient = False) + 2. Main weight W_prin (although frozen, must be saved) + 3. TSD direction parameters (although frozen, must be saved) + 4. QAT-related parameters + """ + trainable_state_dict = OrderedDict() + for name, weight in self.model.state_dict().items(): + # Save trainable parameters and QAT parameters + if not weight.stop_gradient or "activation_quanter" in name or "weight_quanter" in name: + trainable_state_dict[name] = weight + # Save the main branch weight W_prin (for critical fixes) + elif "weight" in name and any(layer_name in name for layer_name in [".weight"]) and "Direc_" not in name: + trainable_state_dict[name] = weight + logger.debug(f"Saving backbone weight: {name}") + # Save all TSD parameters (excluding Direc_Stsd) + elif any(tsd_param in name for tsd_param in ["Direc_Utsd", "Direc_Vhtsd"]): + trainable_state_dict[name] = weight + logger.debug(f"Saving TSD parameter: {name}") + # Save the bias parameters (if any) + elif "bias" in name and "Direc_" not in name: + trainable_state_dict[name] = weight + logger.debug(f"Saving bias parameter: {name}") + + return trainable_state_dict + + def print_trainable_parameters(self) -> None: + freeze_numel = 0 + trainable_numel = 0 + for _, weight in self.model.state_dict().items(): + if weight.stop_gradient: + freeze_numel += np.prod(weight.shape) + else: + trainable_numel += np.prod(weight.shape) + logger.debug( + f"Frozen parameters: {freeze_numel:.2e} || Trainable parameters:{trainable_numel:.2e} || Total parameters:{freeze_numel+trainable_numel:.2e}|| Trainable:{trainable_numel / (freeze_numel+trainable_numel):.2%}" + ) + + def mark_only_dislora_as_trainable(self) -> None: + """ + Mark only the parameters related to DisLoRA as trainable, while ensuring that the TSD parameters remain in a frozen state. 
+ """ + + for full_param_name, weight in self.model.state_dict().items(): + + is_dislora_layer = any( + re.fullmatch(target_module, full_param_name.rsplit(".", 1)[0]) + for target_module in self.dislora_config.target_modules + ) + + if is_dislora_layer: + param_name = full_param_name.split(".")[-1] + + if param_name == "weight" and "Direc_" not in full_param_name: + weight.stop_gradient = True + logger.debug(f"Freezing backbone weight: {full_param_name}") + + elif param_name == "bias" and "Direc_" not in full_param_name: + if self.dislora_config.trainable_bias in ["dislora", "all"]: + weight.stop_gradient = False + logger.debug(f"Setting bias as trainable: {full_param_name}") + else: + weight.stop_gradient = True + logger.debug(f"Freezing bias: {full_param_name}") + + elif any(tsd_param in full_param_name for tsd_param in ["Direc_Utsd", "Direc_Vhtsd"]): + weight.stop_gradient = True + logger.debug(f"Keeping TSD parameter frozen: {full_param_name}") + + elif any( + trainable_param in full_param_name + for trainable_param in ["Direc_Ur", "Direc_Sr", "Direc_Vhr", "Direc_Stsd"] + ): + weight.stop_gradient = False + logger.debug(f"Setting DisLoRA parameter as trainable: {full_param_name}") + + else: + weight.stop_gradient = True + logger.debug(f"Freezing other parameter: {full_param_name}") + + else: + param_name = full_param_name.split(".")[-1] + if self.dislora_config.trainable_bias == "all" and param_name == "bias": + weight.stop_gradient = False + logger.debug(f"Setting bias as trainable in non-DisLoRA layer: {full_param_name}") + else: + weight.stop_gradient = True + logger.debug(f"Freezing parameter in non-DisLoRA layer: {full_param_name}") + + if self.dislora_config.trainable_modules is not None: + for full_param_name, weight in self.model.state_dict().items(): + if any( + re.fullmatch(trainable_module, full_param_name) + for trainable_module in self.dislora_config.trainable_modules + ): + + if not any(tsd_param in full_param_name for tsd_param in ["Direc_Utsd", "Direc_Vhtsd"]): + weight.stop_gradient = False + logger.debug(f"Setting additional trainable module parameter: {full_param_name}") + else: + logger.warning( + f"TSD parameter {full_param_name} matched trainable_modules pattern but kept frozen" + ) + + def get_dislora_model(self, model: Union[PretrainedModel, nn.Layer], dislora_config: DisLoRAConfig): + """ + Iterate all base model layers, change target modules to DisLoRALayer. 
+ """ + if dislora_config.target_modules is None: + return model + else: + target_modules = dislora_config.target_modules + + target_module_names = [] + + existing_dislora_paths = set() + for module_name, module in model.named_sublayers(): + if isinstance(module, DisLoRALinear): + existing_dislora_paths.add(module_name) + + for target_module in target_modules: + for module_name, module in model.named_sublayers(): + + if re.fullmatch(target_module, module_name): + + if not isinstance(module, DisLoRALinear): + + is_submodule = any( + module_name.startswith(dislora_path + ".") for dislora_path in existing_dislora_paths + ) + + if not is_submodule: + target_module_names.append(module_name) + else: + logger.debug(f"Skipping {module_name} - it's a submodule of existing DisLoRA module") + else: + logger.debug(f"Skipping {module_name} - already a DisLoRA module") + + for module_name in target_module_names: + try: + self._find_and_replace_module(model, module_name, dislora_config) + logger.debug(f"Replaced {module_name} with DisLoRALinear") + except ValueError as e: + raise e + except Exception as e: + + logger.warning(f"Failed to replace {module_name}: {e}") + + return model + + def restore_original_model(self): + # make sure W and dislora weights are not merged before we restore the original model + for layer_name, layer in self.model.named_sublayers(): + if isinstance(layer, DisLoRALinear): + self._find_and_restore_module(layer_name) + return self.model + + def __getattr__(self, name: str): + """ + Forward missing attributes to the wrapped module. + """ + try: + return super().__getattr__(name) # defer to nn.Layer's logic + except AttributeError: + return getattr(self.model, name) + + def train(self): + self.training = True + self.model.training = True + for layer in self.model.sublayers(): + layer.training = True + layer.train() + + def eval(self): + self.training = False + self.model.training = False + for layer in self.model.sublayers(): + layer.training = False + layer.eval() + + def disable_dislora(self): + """ + Disable the DisLoRA adapter + """ + for _, layer in self.model.named_sublayers(): + if isinstance(layer, DisLoRALinear): + layer.disable_adapters() + + def enable_dislora(self): + """ + Enable the DisLoRA adapter + """ + for _, layer in self.model.named_sublayers(): + if isinstance(layer, DisLoRALinear): + layer.enable_adapters() + + def merge(self): + for _, layer in self.model.named_sublayers(): + if any(isinstance(layer, dislora_layer) for dislora_layer in AVAILABLE_LAYERS): + layer.merge() + + def unmerge(self): + for _, layer in self.model.named_sublayers(): + if any(isinstance(layer, dislora_layer) for dislora_layer in AVAILABLE_LAYERS): + layer.unmerge() + + def get_model_config( + self, + ): + return self.model_config.to_dict() diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index b48679dbf26a..cb7ff1454c96 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -82,7 +82,14 @@ default_data_collator, init_dataloader_comm_group, ) -from ..peft import LoKrModel, LoRAModel, PrefixModelForCausalLM, ReFTModel, VeRAModel +from ..peft import ( + DisLoRAModel, + LoKrModel, + LoRAModel, + PrefixModelForCausalLM, + ReFTModel, + VeRAModel, +) from ..quantization.quantization_linear import ( ColumnParallelQuantizationLinear, QuantizationLinear, @@ -107,6 +114,7 @@ from ..transformers.tokenizer_utils import PretrainedTokenizer from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler from ..utils.env import ( 
+ DISLORA_WEIGHTS_NAME, LOKR_WEIGHTS_NAME, LORA_WEIGHTS_NAME, MODEL_META_NAME, @@ -464,6 +472,7 @@ def _save_ckpt_func(state_dict, path, signal_path=None): or isinstance(self.model, PrefixModelForCausalLM) or isinstance(self.model, VeRAModel) or isinstance(self.model, LoKrModel) + or isinstance(self.model, DisLoRAModel) or isinstance(self.model, ReFTModel) ): if self.args.unified_checkpoint and "skip_save_model_weight" in self.args.unified_checkpoint_config: @@ -616,6 +625,8 @@ def _load_from_peft_checkpoint(self, resume_from_checkpoint=None): weights_file = os.path.join(resume_from_checkpoint, VERA_WEIGHTS_NAME) elif isinstance(self.model, LoKrModel): weights_file = os.path.join(resume_from_checkpoint, LOKR_WEIGHTS_NAME) + elif isinstance(self.model, DisLoRAModel): + weights_file = os.path.join(resume_from_checkpoint, DISLORA_WEIGHTS_NAME) elif isinstance(self.model, ReFTModel): self.model.from_pretrained(resume_from_checkpoint, self.model.model) return @@ -681,6 +692,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): or isinstance(self.model, PrefixModelForCausalLM) or isinstance(self.model, VeRAModel) or isinstance(self.model, LoKrModel) + or isinstance(self.model, DisLoRAModel) or isinstance(self.model, ReFTModel) ): self._load_from_peft_checkpoint(resume_from_checkpoint) @@ -2996,6 +3008,7 @@ def _save( or isinstance(self.model, PrefixModelForCausalLM) or isinstance(self.model, VeRAModel) or isinstance(self.model, LoKrModel) + or isinstance(self.model, DisLoRAModel) or isinstance(self.model, ReFTModel) ): self.model.save_pretrained( diff --git a/paddlenlp/trl/model_config.py b/paddlenlp/trl/model_config.py index 2e244d211158..66a3be6a38d6 100644 --- a/paddlenlp/trl/model_config.py +++ b/paddlenlp/trl/model_config.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Optional +from typing import List, Optional __all__ = ["ModelConfig"] @@ -90,6 +90,27 @@ class ModelConfig: ) lokr_dim: int = field(default=8, metadata={"help": "Lora dimension in LoKr dimension for adapter matrix"}) + # dislora related parameters + dislora: bool = field(default=False, metadata={"help": "Whether to use dislora technique"}) + dislora_path: str = field(default=None, metadata={"help": "Initialize dislora state dict."}) + dislora_rank: int = field(default=8, metadata={"help": "DisLoRA attention dimension"}) + dislora_dropout: float = field(default=0.05, metadata={"help": "DisLoRA dropout"}) + target_modules: Optional[List[str]] = field( + default=None, + metadata={"help": "Custom target modules for DisLoRA. 
If None, will use default modules based on model type."}, + ) + dash_flag: int = field( + default=50, metadata={"help": "The number of preheating steps before introducing additional low-rank updates"} + ) + s_tsd: int = field( + default=8, metadata={"help": "The number of top-k singular vectors dynamically selected after preheating"} + ) + ortho_lambda: float = field(default=1, metadata={"help": "The weight of orthogonal regularization loss"}) + prefer_small_sigma: bool = field( + default=True, + metadata={"help": "Whether to prioritize the smallest singular value in the top-k selection process"}, + ) + # prefix tuning related parameters prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"}) prefix_path: str = field(default=None, metadata={"help": "Initialize prefix state dict."}) diff --git a/paddlenlp/trl/sft_config.py b/paddlenlp/trl/sft_config.py index f759bc68a1aa..be8152e4bb90 100644 --- a/paddlenlp/trl/sft_config.py +++ b/paddlenlp/trl/sft_config.py @@ -75,6 +75,10 @@ class SFTConfig(TrainingArguments): "help": "The ratio parameter for grouping in SSA, controlling the number of tokens considered in each group for sparse attention calculation." }, ) + dislora_ortho_lambda: float = field( + default=0.0, + metadata={"help": "Orthogonal regularization weight for DisLoRA. Set to 1 for Pareto optimization."}, + ) def __post_init__(self): super().__post_init__() diff --git a/paddlenlp/trl/sft_trainer.py b/paddlenlp/trl/sft_trainer.py index 8466bfc7abba..fd1d6ce6be65 100644 --- a/paddlenlp/trl/sft_trainer.py +++ b/paddlenlp/trl/sft_trainer.py @@ -419,3 +419,70 @@ def ptq_loop( self.prediction_step(model=self.model, inputs=inputs, prediction_loss_only=True, ignore_keys=None) if max_eval_iters > 0 and step >= max_eval_iters - 1: break + + def _calc_ortho_loss(self, model): + """Calculate the orthogonal constraint loss of DisLoRA""" + import paddle + + ortho_loss = 0.0 + den = 0 + + for name, param in model.named_parameters(): + if "Direc_Ur" in name and "weight" in name: + u = param + iu = paddle.eye(u.shape[0], dtype=u.dtype) + u_loss = paddle.norm(u @ u.T - iu, p="fro") + ortho_loss += u_loss + den += 1 + + elif "Direc_Vhr" in name and "weight" in name: + vh = param + ivh = paddle.eye(vh.shape[1], dtype=vh.dtype) + vh_loss = paddle.norm(vh.T @ vh - ivh, p="fro") + ortho_loss += vh_loss + den += 1 + + if den > 0: + return ortho_loss / den + else: + return None + + def compute_loss(self, model, inputs, return_outputs=False): + """Override compute_loss to add DisLoRA orthogonal regularization""" + import paddle + + result = super().compute_loss(model, inputs, return_outputs=False) + + if isinstance(result, tuple): + loss = result[0] + outputs = result[1] if len(result) > 1 else None + else: + loss = result + outputs = None + + if isinstance(loss, tuple): + loss = loss[0] + + if hasattr(self.args, "dislora_ortho_lambda") and self.args.dislora_ortho_lambda > 0: + ortho_loss = self._calc_ortho_loss(model) + + if ortho_loss is not None and loss is not None: + + if loss.numel() > 1: + loss = loss.mean() + if ortho_loss.numel() > 1: + ortho_loss = ortho_loss.mean() + + if abs(self.args.dislora_ortho_lambda - 1.0) < 1e-6: + + with paddle.no_grad(): + ratio = ortho_loss / (loss + 1e-8) + alpha_task = paddle.exp(-ratio) / (paddle.exp(-ratio) + paddle.exp(-1 / ratio)) + alpha_ortho = 1.0 - alpha_task + + loss = alpha_task * loss + alpha_ortho * ortho_loss + else: + + loss = loss + self.args.dislora_ortho_lambda * ortho_loss + + return (loss, outputs) if 
return_outputs else loss diff --git a/paddlenlp/utils/env.py b/paddlenlp/utils/env.py index 62503e09a39e..0489d6b6d6cc 100644 --- a/paddlenlp/utils/env.py +++ b/paddlenlp/utils/env.py @@ -111,6 +111,9 @@ def _get_bool_env(env_key: str, default_value: str) -> bool: LOKR_WEIGHTS_NAME = "lokr_model_state.pdparams" LOKR_CONFIG_NAME = "lokr_config.json" +DISLORA_WEIGHTS_NAME = "dislora_model_state.pdparams" +DISLORA_CONFIG_NAME = "dislora_config.json" + PAST_KEY_VALUES_FILE_NAME = "pre_caches.npy" PADDLE_WEIGHTS_NAME = "model_state.pdparams" diff --git a/tests/fixtures/llm/dislora.yaml b/tests/fixtures/llm/dislora.yaml new file mode 100644 index 000000000000..15500928b97d --- /dev/null +++ b/tests/fixtures/llm/dislora.yaml @@ -0,0 +1,78 @@ +dislora: + base: + dataset_name_or_path: "./data" + per_device_train_batch_size: 1 + gradient_accumulation_steps: 5 + per_device_eval_batch_size: 8 + eval_accumulation_steps: 16 + num_train_epochs: 1 + learning_rate: 2e-05 + lr_scheduler_type: linear + warmup_steps: 30 + logging_steps: 1 + evaluation_strategy: "no" + save_strategy: "steps" + save_steps: 500 + src_length: 256 + max_length: 256 + fp16: true + fp16_opt_level: "O2" + do_train: true + do_eval: false + disable_tqdm: false + load_best_model_at_end: false + eval_with_do_generation: false + recompute: false + save_total_limit: 5 + sharding: "stage3" + zero_padding: false + use_flash_attention: false + unified_checkpoint: false + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + dislora: true + dislora_rank: 8 + dislora_dropout: 0.05 + + s_tsd: 8 + ortho_lambda: 1.0 + prefer_small_sigma: true + + default: + llama: + model_name_or_path: __internal_testing__/tiny-random-llama + chatglm: + model_name_or_path: __internal_testing__/tiny-fused-chatglm + chatglm2: + model_name_or_path: __internal_testing__/tiny-fused-chatglm2 + bloom: + model_name_or_path: __internal_testing__/tiny-fused-bloom + qwen: + model_name_or_path: __internal_testing__/tiny-fused-qwen + qwen2: + model_name_or_path: __internal_testing__/tiny-random-qwen2 + qwen2moe: + model_name_or_path: __internal_testing__/tiny-random-qwen2moe + baichuan: + model_name_or_path: __internal_testing__/tiny-fused-baichuan + +inference-predict: + default: + mode: dynamic + max_length: 20 + batch_size: 2 + decode_strategy: greedy_search + dtype: float16 + +inference-to-static: + default: + dtype: float16 + max_length: 20 + +inference-infer: + default: + mode: static + dtype: float16 + batch_size: 2 + decode_strategy: greedy_search + max_length: 20 \ No newline at end of file diff --git a/tests/llm/test_dislora.py b/tests/llm/test_dislora.py new file mode 100644 index 000000000000..957962f71ab6 --- /dev/null +++ b/tests/llm/test_dislora.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import os +import sys +import unittest + +import paddle +from parameterized import parameterized_class + +from tests.testing_utils import argv_context_guard, load_test_config + +from .testing_utils import LLMTest + + +@parameterized_class( + ["model_dir"], + [ + ["llama"], + ["chatglm"], + ["chatglm2"], + ["bloom"], + ["qwen"], + ["baichuan"], + ], +) +class DisLoRATest(LLMTest, unittest.TestCase): + config_path: str = "./tests/fixtures/llm/dislora.yaml" + model_dir: str = None + + def setUp(self) -> None: + LLMTest.setUp(self) + + self.model_codes_dir = os.path.join(self.root_path, self.model_dir) + sys.path.insert(0, self.model_codes_dir) + + def tearDown(self) -> None: + LLMTest.tearDown(self) + sys.path.remove(self.model_codes_dir) + + def test_dislora(self): + self.disable_static() + paddle.set_default_dtype("float32") + + dislora_config = load_test_config(self.config_path, "dislora", self.model_dir) + dislora_config["output_dir"] = self.output_dir + dislora_config["dataset_name_or_path"] = self.data_dir + + with argv_context_guard(dislora_config): + from run_finetune import main + + main() + + # merge weights + merge_dislora_weights_config = { + "dislora_path": dislora_config["output_dir"], + "merge_dislora_model_path": dislora_config["output_dir"], + "device": "gpu", + "low_gpu_mem": True, + } + with argv_context_guard(merge_dislora_weights_config): + from tools.merge_dislora_params import merge + + merge() + + # # TODO(wj-Mcat): disable chatglm2 test temporarily + # if self.model_dir not in ["qwen", "baichuan", "chatglm2"]: + # self.run_predictor({"inference_model": True}) + + self.run_predictor({"inference_model": False}) diff --git a/tests/peft/test_dislora.py b/tests/peft/test_dislora.py new file mode 100644 index 000000000000..fb5e8db8396e --- /dev/null +++ b/tests/peft/test_dislora.py @@ -0,0 +1,232 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import os +import re +import unittest +from tempfile import TemporaryDirectory + +import numpy as np +import paddle +from parameterized import parameterized + +from paddlenlp.peft.dislora import DisLoRAConfig, DisLoRALinear, DisLoRAModel +from paddlenlp.transformers import AutoModel, BertModel + + +class TestDisLoRALayer(unittest.TestCase): + def test_r_raise_exception(self): + with self.assertRaises(ValueError): + DisLoRALinear(in_features=16, out_features=8, r=0, dislora_alpha=8) + + def test_forward(self): + # r=8, dislora_alpha=12 (1.5 * 8) + dislora_layer = DisLoRALinear(in_features=16, out_features=8, r=8, dislora_dropout=0.1, dislora_alpha=12) + x = paddle.randn([2, 4, 16], "float32") + output = dislora_layer(x) + + # Check the trainable DisLoRA parameters (related to W_res) + self.assertFalse(dislora_layer.Direc_Ur.weight.stop_gradient) + self.assertFalse(dislora_layer.Direc_Vhr.weight.stop_gradient) + self.assertFalse(dislora_layer.Direc_Sr.stop_gradient) + self.assertFalse(dislora_layer.Direc_Stsd.stop_gradient) + + # Check the frozen TSD parameters + self.assertTrue(dislora_layer.Direc_Utsd.weight.stop_gradient) + self.assertTrue(dislora_layer.Direc_Vhtsd.weight.stop_gradient) + + # Check the frozen main branch weights W_prin + self.assertTrue(dislora_layer.weight.stop_gradient) + + # Check the bias parameters (by default, they should be trainable, but this depends on the configuration) + if dislora_layer.bias is not None: + self.assertFalse(dislora_layer.bias.stop_gradient) + + self.assertEqual(output.shape, [2, 4, 8]) + + def test_train_eval(self): + x = paddle.randn([2, 4, 16], "float32") + + dislora_layer = DisLoRALinear(in_features=16, out_features=8, r=8, dislora_alpha=12) + dislora_layer.train() + train_result = dislora_layer(x) + train_weight = copy.deepcopy(dislora_layer.weight) + dislora_layer.eval() + eval_result = dislora_layer(x) + eval_weight = dislora_layer.weight + self.assertTrue(paddle.allclose(train_result, eval_result)) + self.assertTrue(paddle.allclose(train_weight, eval_weight)) + + def test_save_load(self): + with TemporaryDirectory() as tempdir: + + dislora_layer = DisLoRALinear(in_features=16, out_features=8, r=8, dislora_alpha=12) + weights_path = os.path.join(tempdir, "model.pdparams") + paddle.save(dislora_layer.state_dict(), weights_path) + + new_dislora_layer = DisLoRALinear(in_features=16, out_features=8, r=8, dislora_alpha=12) + state_dict = paddle.load(weights_path) + new_dislora_layer.set_dict(state_dict) + x = paddle.randn([2, 4, 16], "float32") + self.assertTrue(paddle.allclose(new_dislora_layer(x), dislora_layer(x))) + + def test_load_regular_linear(self): + with TemporaryDirectory() as tempdir: + regular_linear = paddle.nn.Linear(in_features=16, out_features=12) + weights_path = os.path.join(tempdir, "model.pdparams") + paddle.save(regular_linear.state_dict(), weights_path) + state_dict = paddle.load(weights_path) + # should be identical to regular linear + + dislora_layer_r8 = DisLoRALinear( + in_features=16, out_features=12, r=8, dislora_alpha=12, init_lora_weights=False + ) + + dislora_layer_r10 = DisLoRALinear( + in_features=16, out_features=12, r=10, dislora_alpha=15, init_lora_weights=False + ) + + # Load regular linear weights first + filtered_state_dict = {k: v for k, v in state_dict.items() if k in ["weight", "bias"]} + dislora_layer_r8.set_dict(filtered_state_dict) + dislora_layer_r10.set_dict(filtered_state_dict) + + # Then perform SVD initialization + dislora_layer_r8._init_lora_weights() + 
dislora_layer_r10._init_lora_weights()
+
+        x = paddle.randn([2, 4, 16], "float32")
+
+        diff_r8 = paddle.abs(dislora_layer_r8(x) - regular_linear(x))
+        print(f"R8 - Max diff: {paddle.max(diff_r8).item():.6e}, Mean diff: {paddle.mean(diff_r8).item():.6e}")
+        self.assertTrue(paddle.allclose(dislora_layer_r8(x), regular_linear(x), atol=2e-3))
+        # The r=10 layer should also reproduce the regular linear output
+        self.assertTrue(paddle.allclose(dislora_layer_r10(x), regular_linear(x), atol=2e-3))
+
+
+class TestDisLoRAModel(unittest.TestCase):
+    def test_dislora_model_restore(self):
+
+        dislora_config = DisLoRAConfig(
+            target_modules=[".*q_proj.*", ".*v_proj.*"],
+            r=8,
+            dislora_alpha=12,
+            base_model_name_or_path="__internal_testing__/tiny-random-bert",
+        )
+        model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert")
+        input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 20]))
+        model.eval()
+        original_results_1 = model(input_ids)
+        dislora_model = DisLoRAModel(model, dislora_config)
+        restored_model = dislora_model.restore_original_model()
+        restored_model.eval()
+        original_results_2 = restored_model(input_ids)
+        self.assertIsNotNone(original_results_1)
+        self.assertIsNotNone(original_results_2)
+        self.assertIsInstance(restored_model, BertModel)
+        self.assertTrue(paddle.allclose(original_results_1[0], original_results_2[0]))
+
+    @parameterized.expand([(None,), ("all",), ("dislora",)])
+    def test_dislora_model_constructor(self, bias):
+
+        dislora_config = DisLoRAConfig(
+            target_modules=[".*q_proj.*", ".*v_proj.*"],
+            r=8,
+            dislora_alpha=12,
+            trainable_bias=bias,
+            base_model_name_or_path="__internal_testing__/tiny-random-bert",
+        )
+        model = AutoModel.from_pretrained(
+            "__internal_testing__/tiny-random-bert", hidden_dropout_prob=0, attention_probs_dropout_prob=0
+        )
+        dislora_model = DisLoRAModel(model, dislora_config)
+        dislora_model.mark_only_dislora_as_trainable()
+        for name, weight in dislora_model.state_dict().items():
+            if any([re.fullmatch(target_module, name) for target_module in dislora_config.target_modules]):
+                if any(
+                    [dislora_param in name for dislora_param in ["Direc_Ur", "Direc_Sr", "Direc_Vhr", "Direc_Stsd"]]
+                ):
+                    self.assertFalse(weight.stop_gradient)
+                elif any([tsd_param in name for tsd_param in ["Direc_Utsd", "Direc_Vhtsd"]]):
+                    self.assertTrue(weight.stop_gradient)
+                elif "bias" in name and bias in ["dislora", "all"]:
+                    self.assertFalse(weight.stop_gradient)
+                else:
+                    self.assertTrue(weight.stop_gradient)
+            else:
+                if "bias" in name and bias == "all":
+                    self.assertFalse(weight.stop_gradient)
+                else:
+                    self.assertTrue(weight.stop_gradient)
+
+        input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 20]))
+        dislora_model.train()
+        train_forward_results = dislora_model(input_ids)
+        self.assertIsNotNone(train_forward_results)
+        dislora_model.eval()
+        eval_forward_results = dislora_model(input_ids)
+        self.assertIsNotNone(eval_forward_results)
+        self.assertTrue(paddle.allclose(train_forward_results[0], eval_forward_results[0]))
+
+    def test_dislora_model_save_load(self):
+        with TemporaryDirectory() as tempdir:
+            input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 20]))
+
+            dislora_config = DisLoRAConfig(
+                target_modules=[".*q_proj.*", ".*v_proj.*"],
+                r=8,
+                dislora_alpha=12,
+                base_model_name_or_path="__internal_testing__/tiny-random-bert",
+            )
+            model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert")
+            dislora_model = DisLoRAModel(model, dislora_config)
+            dislora_model.eval()
+            original_results = dislora_model(input_ids)
+ 
dislora_model.save_pretrained(tempdir) + + loaded_dislora_model = DisLoRAModel.from_pretrained(model, tempdir) + loaded_dislora_model.eval() + loaded_results = loaded_dislora_model(input_ids) + self.assertTrue(paddle.allclose(original_results[0], loaded_results[0])) + + config_loaded_dislora_model = DisLoRAModel.from_pretrained(model, tempdir, dislora_config=dislora_config) + config_loaded_dislora_model.eval() + config_loaded_results = config_loaded_dislora_model(input_ids) + self.assertTrue(paddle.allclose(original_results[0], config_loaded_results[0])) + + def test_dislora_module_raise_exception(self): + + dislora_config = DisLoRAConfig( + target_modules=[".*norm1.*"], + r=8, + dislora_alpha=12, + base_model_name_or_path="__internal_testing__/tiny-random-bert", + ) + model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert") + with self.assertRaises(ValueError): + DisLoRAModel(model, dislora_config) + + +class TestDisLoRAConfig(unittest.TestCase): + def test_save_load(self): + with TemporaryDirectory() as tempdir: + # Set r and dislora_alpha explicitly + dislora_config = DisLoRAConfig(target_modules=["test"], r=8, dislora_alpha=12) + dislora_config.save_pretrained(tempdir) + loaded_dislora_config = DisLoRAConfig.from_pretrained(tempdir) + self.assertEqual(dislora_config.r, loaded_dislora_config.r) + self.assertEqual(dislora_config.dislora_alpha, loaded_dislora_config.dislora_alpha) + self.assertEqual(dislora_config.dash_flag, loaded_dislora_config.dash_flag) + self.assertEqual(dislora_config.s_tsd, loaded_dislora_config.s_tsd)