diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 122e1c1049..51a42558d2 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -87,6 +87,8 @@ PromptEncoderReparameterizationType, PromptTuningConfig, PromptTuningInit, + RandLoraConfig, + RandLoraModel, TrainableTokensConfig, TrainableTokensModel, VBLoRAConfig, @@ -178,6 +180,8 @@ "PromptLearningConfig", "PromptTuningConfig", "PromptTuningInit", + "RandLoraConfig", + "RandLoraModel", "TaskType", "TrainableTokensConfig", "TrainableTokensModel", diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index 65abbd4046..bb38230bf0 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -39,6 +39,7 @@ from .poly import PolyConfig, PolyModel from .prefix_tuning import PrefixEncoder, PrefixTuningConfig from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit +from .randlora import RandLoraConfig, RandLoraModel from .trainable_tokens import TrainableTokensConfig, TrainableTokensModel from .vblora import VBLoRAConfig, VBLoRAModel from .vera import VeraConfig, VeraModel @@ -89,6 +90,8 @@ "PromptEncoderReparameterizationType", "PromptTuningConfig", "PromptTuningInit", + "RandLoraConfig", + "RandLoraModel", "TrainableTokensConfig", "TrainableTokensModel", "VBLoRAConfig", diff --git a/src/peft/tuners/randlora/__init__.py b/src/peft/tuners/randlora/__init__.py new file mode 100644 index 0000000000..fbad681aeb --- /dev/null +++ b/src/peft/tuners/randlora/__init__.py @@ -0,0 +1,40 @@ +# Copyright 2025-present the HuggingFace Inc. team. + +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.utils import register_peft_method + +from .config import RandLoraConfig +from .layer import Linear, RandLoraLayer +from .model import RandLoraModel + + +__all__ = ["Linear", "RandLoraConfig", "RandLoraLayer", "RandLoraModel"] + +register_peft_method(name="randlora", config_cls=RandLoraConfig, model_cls=RandLoraModel, prefix="randlora_") + + +def __getattr__(name): + if (name == "Linear8bitLt") and is_bnb_available(): + from .bnb import Linear8bitLt + + return Linear8bitLt + + if (name == "Linear4bit") and is_bnb_4bit_available(): + from .bnb import Linear4bit + + return Linear4bit + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/peft/tuners/randlora/bnb.py b/src/peft/tuners/randlora/bnb.py new file mode 100644 index 0000000000..84983643ef --- /dev/null +++ b/src/peft/tuners/randlora/bnb.py @@ -0,0 +1,456 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import warnings +from typing import Optional + +import bitsandbytes as bnb +import torch + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.tuners_utils import check_adapters_to_merge +from peft.utils.integrations import dequantize_bnb_weight +from peft.utils.other import transpose + +from .layer import RandLoraLayer, UniqueBaseGrad + + +if is_bnb_available(): + + class Linear8bitLt(torch.nn.Module, RandLoraLayer): + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + randlora_A, + randlora_B, + r: int = 0, + randlora_alpha: int = 0, + randlora_dropout: float = 0.0, + fan_in_fan_out: bool = False, + init_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + RandLoraLayer.__init__(self, base_layer) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + randlora_A, + randlora_B, + r, + randlora_alpha=randlora_alpha, + randlora_dropout=randlora_dropout, + init_weights=init_weights, + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. + Defaults to `None`. + """ + + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + return + + for active_adapter in adapter_names: + if active_adapter not in self.randlora_lambda.keys(): + continue + + warnings.warn( + "Merge RandLora module to 8-bit linear may get different generations due to rounding errors." + ) + randlora_data = self.get_delta_weight(active_adapter) + + weight = self.get_base_layer().weight + state = self.get_base_layer().state + if state.SCB is None: + state.SCB = weight.SCB + + output = dequantize_bnb_weight(weight, state) + w_data = output.to(randlora_data.dtype).to(randlora_data.device) + randlora_data + + if safe_merge and not torch.isfinite(w_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + self.get_base_layer().weight = bnb.nn.Int8Params( + w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights + ).to(weight.device) + state.reset_grads() + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. 
Nothing to do") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.randlora_lambda.keys(): + continue + warnings.warn( + "Unmerge randlora module to 8-bit linear may get different generations due to rounding errors." + ) + randlora_data = self.get_delta_weight(active_adapter) + + weight = self.get_base_layer().weight + state = self.get_base_layer().state + if state.SCB is None: + state.SCB = weight.SCB + output = dequantize_bnb_weight(weight, state=state) + + w_data = output.to(randlora_data.dtype).to(randlora_data.device) - randlora_data + + self.get_base_layer().weight = bnb.nn.Int8Params( + w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights + ).to(weight.device) + state.reset_grads() + + def get_scaled_bases(self, adapter, device=None) -> list[torch.Tensor, torch.Tensor]: + """ + Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the + correct order to fit the target layers' dimensions + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + + randlora_A = self.randlora_A[adapter] + randlora_B = self.randlora_B[adapter] + + if device is None: + device = randlora_B.device + dtype = randlora_B.dtype + + # In case users wants to merge the adapter weights that are in + # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # (b)float16 because some CPUs have slow bf16/fp16 matmuls. + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + randlora_lambda = self.randlora_lambda[adapter].to(device) + randlora_gamma = self.randlora_gamma[adapter].to(device) + + if cast_to_fp32: + randlora_A = randlora_A.float() + randlora_B = randlora_B.float() + randlora_lambda = randlora_lambda.float() + randlora_gamma = randlora_gamma.float() + + # The trainable paramters are always applied to randlora_A, the smallest basis. + min_dim, max_dim = min(self.out_features, self.in_features), max(self.out_features, self.in_features) + + # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, + # we initialize these matrices with the largest required size for each dimension. + # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. + sliced_A = randlora_A[:, : self.num_bases, :min_dim].to(device) + sliced_B = randlora_B[:max_dim, : self.num_bases, :].to(device) + + # Flattening the matrices over the rank and number of bases dimensions is more memory efficient + update_B = sliced_B.flatten(start_dim=1) + update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) + if min_dim == self.in_features: + return update_A, update_B + + return update_B.T, update_A.T + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + + update_B, update_A = self.get_scaled_bases(adapter) + + update = update_B @ update_A + output_tensor = transpose(update, self.fan_in_fan_out) + + scaling = self.scaling[adapter] + + return output_tensor * scaling + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Perform the forward pass using the RandLora adapter. + + Args: + x (torch.Tensor): Input tensor. 
+ + Returns: + torch.Tensor: Output tensor after applying the RandLora adaptation. + + Note: + This method implements the RandLora-specific forward pass. It applies the shared projections + (randlora_A and randlora_B) along with the per-layer trainable parameters (lambda and gamma) to compute + the adapter output. + """ + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.randlora_lambda.keys(): + continue + + update_B, update_A = self.get_scaled_bases(active_adapter, device=x.device) + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + compute_dtype = update_A.dtype + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + + dropout = self.randlora_dropout[active_adapter] + x_temp = dropout(x.to(update_A.dtype)) + + adapter_output = torch.nn.functional.linear(torch.nn.functional.linear(x_temp, update_B), update_A) + + if requires_conversion: + adapter_output = adapter_output.to(expected_dtype) + + scaling = self.scaling[active_adapter] + result = result + adapter_output * scaling + + # Ensure the output tensor has the same dtype as the input tensor + return result.to(x.dtype) + + def __repr__(self) -> str: + rep = super().__repr__() + return "randlora." + rep + + +if is_bnb_4bit_available(): + + class Linear4bit(torch.nn.Module, RandLoraLayer): + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + randlora_A, + randlora_B, + r: int = 0, + randlora_alpha: int = 0, + randlora_dropout: float = 0.0, + fan_in_fan_out: bool = False, + init_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + RandLoraLayer.__init__(self, base_layer) + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + randlora_A, + randlora_B, + r, + randlora_alpha=randlora_alpha, + randlora_dropout=randlora_dropout, + init_weights=init_weights, + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. + Defaults to `None`. + """ + + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + return + + for active_adapter in adapter_names: + if active_adapter not in self.randlora_lambda.keys(): + continue + + warnings.warn( + "Merge RandLora module to 4-bit linear may get different generations due to rounding errors." + ) + randlora_data = self.get_delta_weight(active_adapter) + + weight = self.get_base_layer().weight + kwargs = weight.__dict__ + w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) + randlora_data + + if safe_merge and not torch.isfinite(w_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken" + ) + + self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to( + weight.device + ) + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.randlora_lambda.keys(): + continue + warnings.warn( + "Unmerge RandLora module to 4-bit linear may get different generations due to rounding errors." + ) + randlora_data = self.get_delta_weight(active_adapter) + + weight = self.get_base_layer().weight + kwargs = weight.__dict__ + w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - randlora_data + + self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to( + weight.device + ) + + def get_scaled_bases(self, adapter, device=None) -> list[torch.Tensor, torch.Tensor]: + """ + Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the + correct order to fit the target layers' dimensions + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + + randlora_A = self.randlora_A[adapter] + randlora_B = self.randlora_B[adapter] + if device is None: + device = randlora_B.device + dtype = randlora_B.dtype + + # In case users wants to merge the adapter weights that are in + # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # (b)float16 because some CPUs have slow bf16/fp16 matmuls. + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + randlora_lambda = self.randlora_lambda[adapter].to(device) + randlora_gamma = self.randlora_gamma[adapter].to(device) + + if cast_to_fp32: + randlora_A = randlora_A.float() + randlora_B = randlora_B.float() + randlora_lambda = randlora_lambda.float() + randlora_gamma = randlora_gamma.float() + + # The trainable paramters are always applied to randlora_A, the smallest basis. + min_dim, max_dim = min(self.out_features, self.in_features), max(self.out_features, self.in_features) + + # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, + # we initialize these matrices with the largest required size for each dimension. + # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. + sliced_A = randlora_A[:, : self.num_bases, :min_dim].to(device) + sliced_B = randlora_B[:max_dim, : self.num_bases, :].to(device) + # Flattening the matrices over the rank and number of bases dimensions is more memory efficient + update_B = sliced_B.flatten(start_dim=1) + update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) + if min_dim == self.in_features: + return update_A, update_B + + return update_B.T, update_A.T + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. 
+ """ + update_B, update_A = self.get_scaled_bases(adapter) + + update = update_B @ update_A + output_tensor = transpose(update, self.fan_in_fan_out) + + scaling = self.scaling[adapter] + + return output_tensor * scaling + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + result = result.clone() + for active_adapter in self.active_adapters: + if active_adapter not in self.randlora_lambda.keys(): + continue + + update_B, update_A = self.get_scaled_bases(active_adapter, device=x.device) + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + compute_dtype = update_A.dtype + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + + dropout = self.randlora_dropout[active_adapter] + x_temp = dropout(x.to(update_A.dtype)) + + adapter_output = torch.nn.functional.linear(torch.nn.functional.linear(x_temp, update_B), update_A) + + if requires_conversion: + adapter_output = adapter_output.to(expected_dtype) + + scaling = self.scaling[active_adapter] + result = result + adapter_output * scaling + + # Ensure the output tensor has the same dtype as the input tensor + return result.to(x.dtype) + + def __repr__(self) -> str: + rep = super().__repr__() + return "randlora." + rep diff --git a/src/peft/tuners/randlora/config.py b/src/peft/tuners/randlora/config.py new file mode 100644 index 0000000000..2eac42bf08 --- /dev/null +++ b/src/peft/tuners/randlora/config.py @@ -0,0 +1,199 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass, field +from typing import Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class RandLoraConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`RandLoraModel`]. + + Paper: https://arxiv.org/pdf/2502.00987. + + Args: + r (`int`, *optional*, defaults to `32`): + RandLora's random basis rank dimension. Contrary to Lora, this parameter is inversely proportional to the + amount of trainable parameters as reducing it increases trainable parameters. + target_modules (`Union[list[str], str]`): + The names of the modules to apply RandLora to. Only linear layers are supported. + projection_prng_key (`int`): + RandLora PRNG init key. Used for initialising basis_A and basis_B for new models or when loading a + checkpoint that did not include these projections. Defaults to `0`. + save_projection (`bool`): + Whether to save the global basis_A / basis_B random basis in the state dict alongside per layer lambda / + gamma diagonal matrices. This will increase the size of the checkpoint, but guarantee that we can reload + the checkpoint on all system configurations. Defaults to `True`. 
+ sparse (`bool`): + Whether to use sparse random bases as described in the RandLora paper. The bases are ternary sparse bases + (only containing -1, 0 and 1) where the attribution probability is 1/6 for -1 and 1 and 2/3 for 0. These + sparse matrices aim to be used for matmul free computation in the future, see + https://arxiv.org/pdf/2406.02528v1 The current implementation is a proof of concept however where the + sparseness is not used to improve speed or memory usage. Using sparse matrices typically does not reduce + performance and can even help reduce overfitting. Defaults to `False`. + very_sparse (`bool`): + Whether to use highly sparse random bases as described in the RandLora paper. The very sparse bases are + ternary sparse bases (only containing -1, 0 and 1) given a matrix with smallest dimension d, the + attribution probability is 1/√D for -1 and 1 and 1- 2/√D for 0. Using these sparse matrices can further + reduce overfitting over the `sparse` alternatives but will most likely decrease performance as a results. + Use carefully. Defaults to `False`. + randlora_dropout (`float`): + The dropout probability for RandLora layers. + randlora_alpha (`float`): + The scaling coefficient for RandLora layers, this would typically be 20 times the rank. Because the + `randlora_alpha` coefficient is large by default, it can lead to numerical instabilities especially when + learning rates are high. If training is unstable, consider reducing the learning rate or the + `randlora_alpha` coefficient. + fan_in_fan_out (`bool`): + Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses + `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`. + bias (`str`): + Bias type. Can be 'none', 'all' or 'randlora_only'. If 'all' or 'randlora_only', the corresponding biases + will be updated during training. Be aware that this means that, even when disabling the adapters, the model + will not produce the same output as the base model would have without adaptation. + modules_to_save (`list[str]`): + list of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint. + init_weights (`bool`): + Whether to initialize the weights of the RandLora layers with their default initialization. Don't change + this setting, except if you know exactly what you're doing. + layers_to_transform (`Union[list[int],int]`): + The layer indexes to transform, if this argument is specified, it will apply the RandLora transformations + on the layer indexes that are specified in this list. If a single integer is passed, it will apply the + RandLora transformations on the layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer + pattern is not in the common layers pattern. + """ + + r: int = field(default=32, metadata={"help": "RandLora random basis rank"}) + + target_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": ( + "list of module names or regex expression of the module names to replace with RandLora." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. " + "Only linear layers are supported." + ) + }, + ) + projection_prng_key: int = field( + default=0, + metadata={ + "help": ( + "RandLora PRNG init key. Used for initialising basis_A and basis_B for new models or when loading a " + "checkpoint that did not include these projections." 
+ ) + }, + ) + save_projection: bool = field( + default=True, + metadata={ + "help": ( + "Whether to save the basis_A / basis_B projections in the state dict alongside per layer lambda / " + "gamma weights. This will increase the size of the checkpoint, but guarantee that we can reload " + "the checkpoint on all system configurations." + ) + }, + ) + sparse: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use sparse random bases as described in the RandLora paper." + "The current implementation is a proof of concept where the sparseness" + "is not used to improve speed or memory usage." + ) + }, + ) + very_sparse: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use very sparse random bases." + "The current implementation is a proof of concept where the sparseness" + "is not used to improve speed or memory usage." + ) + }, + ) + randlora_dropout: float = field(default=0.0, metadata={"help": "Dropout in the adapter layers"}) + fan_in_fan_out: bool = field( + default=False, + metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, + ) + randlora_alpha: int = field( + default=640, + metadata={ + "help": "Scaling coefficient in the adapter layers, typically 20 times the rank of the random bases." + }, + ) + bias: str = field( + default="none", metadata={"help": "Bias type for RandLora. Can be 'none', 'all' or 'randlora_only'"} + ) + modules_to_save: Optional[list[str]] = field( + default=None, + metadata={ + "help": ( + "list of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint. For" + " example, in Sequence Classification or Token Classification tasks, the final layer" + " `classifier/score` are randomly initialized and as such need to be trainable and saved." + ) + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the RandLora layers with their default initialization. Don't change " + "this setting, except if you know exactly what you're doing." + ), + }, + ) + layers_to_transform: Optional[Union[list[int], int]] = field( + default=None, + metadata={ + "help": ( + "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers" + " indexes that are specified inside this list. If a single integer is passed, PEFT will transform only" + " the layer at this index." + ) + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer" + " pattern is not in the common layers pattern." + ) + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.RANDLORA + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + + if not self.save_projection: + warnings.warn( + "Specified to not save basis_A and basis_B within the state dictionary, instead they will be restored " + "using the PRNG key store in `config.projection_prng_key`. Consider setting `config.save_projection` " + "to `True` to guarantee restoring the checkpoint correctly on all system configurations." + ) diff --git a/src/peft/tuners/randlora/layer.py b/src/peft/tuners/randlora/layer.py new file mode 100644 index 0000000000..40e5aeab6b --- /dev/null +++ b/src/peft/tuners/randlora/layer.py @@ -0,0 +1,349 @@ +# Copyright 2025-present the HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.pytorch_utils import Conv1D + +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.other import transpose + +from .._buffer_dict import BufferDict + + +class UniqueBaseGrad(torch.autograd.Function): + # Memory efficent for a unique base + @staticmethod + def forward(ctx, randlora_A, randlora_lambda, randlora_gamma): + out = randlora_lambda[:, :, None] * randlora_A * randlora_gamma[None,] + ctx.save_for_backward(randlora_A, randlora_lambda, randlora_gamma) + return out + + @staticmethod + def backward(ctx, grad_output): + randlora_A, randlora_lambda, randlora_gamma = ctx.saved_tensors + randlora_A, randlora_lambda, randlora_gamma = ( + randlora_A.to(grad_output.dtype), + randlora_lambda.to(grad_output.dtype), + randlora_gamma.to(grad_output.dtype), + ) + grad_randlora_lambda = torch.einsum("kbj,kvj,bj->kb", grad_output, randlora_A, randlora_gamma) + grad_randlora_gamma = torch.einsum("kbj,kvj,kb->bj", grad_output, randlora_A, randlora_lambda) + return None, grad_randlora_lambda, grad_randlora_gamma + + +class RandLoraLayer(BaseTunerLayer): + # List all names of layers that may contain adapter weights + adapter_layer_names = ("randlora_lambda", "randlora_gamma") + other_param_names = ("randlora_A", "randlora_B") + + def __init__(self, base_layer: nn.Module, **kwargs): + self.base_layer = base_layer + self.r = {} + self.scaling = {} + self.randlora_dropout = nn.ModuleDict({}) + + # For storing vector scale + self.randlora_lambda = nn.ParameterDict({}) + self.randlora_gamma = nn.ParameterDict({}) + + # Stores a reference to the randlora_A/B BufferDict. 
+ # Set to `None` otherwise to avoid computation with random weights + self.randlora_A: Optional[BufferDict] = None + self.randlora_B: Optional[BufferDict] = None + + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + + # flag to enable/disable casting of input to weight dtype during forward call + self.cast_input_dtype_enabled = True + + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif isinstance(base_layer, Conv1D): + in_features, out_features = ( + base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape + ) + + self.in_features = in_features + self.out_features = out_features + self.kwargs = kwargs + + @property + def merged(self) -> bool: + return bool(self.merged_adapters) + + def update_layer( + self, + adapter_name, + randlora_A: BufferDict, + randlora_B: BufferDict, + r, + randlora_alpha, + randlora_dropout, + init_weights, + ): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + self.r[adapter_name] = r + if randlora_dropout > 0.0: + randlora_dropout_layer = nn.Dropout(p=randlora_dropout) + else: + randlora_dropout_layer = nn.Identity() + + self.randlora_dropout.update(nn.ModuleDict({adapter_name: randlora_dropout_layer})) + + # Actual trainable parameters + num_bases = min(self.in_features, self.out_features) / r + self.num_bases = int(num_bases) if num_bases.is_integer() else int(num_bases) + 1 # Full rank + self.randlora_lambda[adapter_name] = nn.Parameter(torch.randn(r, self.num_bases), requires_grad=True) + self.randlora_gamma[adapter_name] = nn.Parameter( + torch.ones(self.num_bases, min(self.out_features, self.in_features)) + / max(self.out_features, self.in_features), + requires_grad=True, + ) + + self.scaling[adapter_name] = randlora_alpha / r + + # non trainable references to randlora_A/B buffers + self.randlora_A = randlora_A + self.randlora_B = randlora_B + if adapter_name not in randlora_A: + # This means that this is not the first RandLora adapter. We have to add an entry in the dict for this adapter. + if len(self.randlora_A) < 1: + raise ValueError( + "The `randlora_A` and `randlora_B` buffers are empty. This should not happen. Please report this issue." + ) + # we can take any of the existing adapter's parameters, as they should all be identical + randlora_A_param = list(self.randlora_A.values())[0] + randlora_B_param = list(self.randlora_B.values())[0] + + error_tmpl = ( + "{} has a size of {} but {} or greater is required; this probably happened because an additional RandLora " + "adapter was added after the first one with incompatible shapes." + ) + max_dim, min_dim = max(self.in_features, self.out_features), min(self.in_features, self.out_features) + # check input size + if randlora_B_param.shape[0] < max_dim: + raise ValueError(error_tmpl.format("randlora_B", randlora_B_param.shape[0], max_dim)) + # check output size + if randlora_A_param.shape[-1] < min_dim: + raise ValueError(error_tmpl.format("randlora_A", randlora_A_param.shape[1], min_dim)) + + # check r + error_tmpl = ( + "{} has a size of {} but {} or greater is required; this probably happened because an additional RandLora " + "adapter with a lower rank was added after the first one; loading the adapters " + "in reverse order may solve this." 
+ ) + if randlora_A_param.shape[0] < self.r[adapter_name]: + raise ValueError(error_tmpl.format("randlora_A", randlora_A_param.shape[0], self.r[adapter_name])) + + if randlora_B_param.shape[-1] < self.r[adapter_name]: + raise ValueError(error_tmpl.format("randlora_B", randlora_B_param.shape[-1], self.r[adapter_name])) + + self.randlora_A[adapter_name] = randlora_A_param + self.randlora_B[adapter_name] = randlora_B_param + + if init_weights: + self.reset_randlora_parameters(adapter_name) + + self._move_adapter_to_device_of_base_layer(adapter_name) + self.set_adapter(self.active_adapters) + + def reset_randlora_parameters(self, adapter_name): + if adapter_name in self.randlora_lambda.keys(): + with torch.no_grad(): + nn.init.zeros_(self.randlora_lambda[adapter_name]) + nn.init.constant_(self.randlora_gamma[adapter_name], 1 / max(self.randlora_gamma[adapter_name].shape)) + + +class Linear(nn.Linear, RandLoraLayer): + # RandLora implemented in a dense layer + def __init__( + self, + base_layer, + randlora_A: BufferDict, + randlora_B: BufferDict, + adapter_name: str, + r: int = 0, + randlora_alpha: int = 0, + randlora_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + is_target_conv_1d_layer: bool = False, + init_weights: bool = True, + **kwargs, + ) -> None: + # this gets the init from nn.Linear's super perspective, i.e. nn.Module.__init__, which should always be called + super(nn.Linear, self).__init__() + RandLoraLayer.__init__(self, base_layer, **kwargs) + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + self.update_layer(adapter_name, randlora_A, randlora_B, r, randlora_alpha, randlora_dropout, init_weights) + self.is_target_conv_1d_layer = is_target_conv_1d_layer + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.randlora_lambda.keys(): + base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype + + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weights = base_layer.weight.data.clone() + + orig_weights += self.get_delta_weight(active_adapter) + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weights.to(orig_dtype) + else: + delta_weight = self.get_delta_weight(active_adapter) + base_layer.weight.data += delta_weight.to(orig_dtype) + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. 
Nothing to do.") + return + + while len(self.merged_adapters) > 0: + base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype + active_adapter = self.merged_adapters.pop() + if active_adapter in self.randlora_lambda.keys(): + delta_weight = self.get_delta_weight(active_adapter) + base_layer.weight.data -= delta_weight.to(orig_dtype) + + def get_scaled_bases(self, adapter, device=None) -> tuple[torch.Tensor, torch.Tensor]: + """ + Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct + order to fit the target layers' dimensions + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + + randlora_A = self.randlora_A[adapter] + randlora_B = self.randlora_B[adapter] + if device is None: + device = randlora_B.device + dtype = randlora_B.dtype + + # In case users wants to merge the adapter weights that are in + # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # (b)float16 because some CPUs have slow bf16/fp16 matmuls. + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + randlora_lambda = self.randlora_lambda[adapter].to(device) + randlora_gamma = self.randlora_gamma[adapter].to(device) + + if cast_to_fp32: + randlora_A = randlora_A.float() + randlora_B = randlora_B.float() + randlora_lambda = randlora_lambda.float() + randlora_gamma = randlora_gamma.float() + + # The trainable paramters are always applied to randlora_A, the smallest basis. + min_dim, max_dim = min(self.out_features, self.in_features), max(self.out_features, self.in_features) + + # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, + # we initialize these matrices with the largest required size for each dimension. + # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. + sliced_A = randlora_A[:, : self.num_bases, :min_dim].to(device) + sliced_B = randlora_B[:max_dim, : self.num_bases, :].to(device) + + # Flattening the matrices over the rank and number of bases dimensions is more memory efficient + update_B = sliced_B.flatten(start_dim=1) + update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) + + # Since update_A is applied on the smallest dimension, test whether update_A or update_B should be applied first. This is done to reduce trainable parameters. + if min_dim == self.in_features: + return update_A, update_B + return update_B.T, update_A.T + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. 
+ """ + + update_B, update_A = self.get_scaled_bases(adapter) + + update = (update_B.T @ update_A.T).T + output_tensor = transpose(update, self.fan_in_fan_out) + + scaling = self.scaling[adapter] + return output_tensor * scaling + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.randlora_lambda.keys(): + continue + dropout = self.randlora_dropout[active_adapter] + update_B, update_A = self.get_scaled_bases(active_adapter, device=x.device) + x = x.to(update_A.dtype) + scaling = self.scaling[active_adapter] + result = result + F.linear(F.linear(dropout(x), update_B), update_A) * scaling + result = result.to(previous_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "randlora." + rep diff --git a/src/peft/tuners/randlora/model.py b/src/peft/tuners/randlora/model.py new file mode 100644 index 0000000000..f0d7498d79 --- /dev/null +++ b/src/peft/tuners/randlora/model.py @@ -0,0 +1,566 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +import warnings +from dataclasses import asdict +from enum import Enum +from typing import Optional, Union + +import torch +import torch.nn as nn +from tqdm import tqdm +from transformers.pytorch_utils import Conv1D + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists +from peft.utils import ( + TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + _get_submodules, +) + +from .._buffer_dict import BufferDict +from ..tuners_utils import _maybe_include_all_linear_layers +from .config import RandLoraConfig +from .layer import Linear, RandLoraLayer + + +def _kaiming_init( + tensor_or_shape: Union[torch.Tensor, tuple[int, ...]], + generator: torch.Generator, +) -> torch.Tensor: + """ + Kaiming Uniform Initialisation adapted to accept a `torch.Generator` object for PRNG. + + Args: + tensor_or_shape (`Union[torch.Tensor, tuple[int, ...]]`): + Tensor to initialise, or shape of new tensor to create and then initialise. + generator: (`torch.Generator`): + Generator object that manages the state of the PRNG algorithm in use. + + Returns: + `torch.Tensor`: The initialised tensor. 
+ """ + if isinstance(tensor_or_shape, tuple): + tensor = torch.empty(tensor_or_shape, dtype=torch.float32) + else: + tensor = tensor_or_shape + + with torch.no_grad(): + basis = torch.nn.init.kaiming_uniform_(tensor, a=math.sqrt(5), generator=generator) + return basis + + +class RandLoraModel(BaseTuner): + """ + Creates a RandLoRA model from a pretrained transformers model. + + Args: + model ([`~transformers.PreTrainedModel`]): The model to be adapted. + config ([`RandLoraConfig`]): The configuration of the RandLora model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. + + Returns: + `torch.nn.Module`: The RandLora model. + + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import RandLoraConfig, get_peft_model + + >>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") + >>> config = RandLoraConfig(r=128) + >>> model = get_peft_model(base_model, config) + ``` + + **Attributes**: + - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`RandLoraConfig`]): The configuration of the RandLora model. + """ + + prefix: str = "randlora_" + + def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None: + super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) + + def _find_dim(self, config) -> tuple[int, int]: + """ + Finds the largest input and output dimensions across linear layers that have been wrapped with RandLora. + + This will be used for determining the size of the shared randlora_A and randlora_B matrices. + """ + model_config = self.get_model_config(self.model) + + peft_config = self._prepare_adapter_config(config, model_config) + peft_config = _maybe_include_all_linear_layers(peft_config, self.model) + + largest_shape = None + for key, module in self.model.named_modules(): + if not self._check_target_module_exists(peft_config, key): + continue + + if isinstance(module, nn.Linear): + module_shape = module.out_features, module.in_features + elif isinstance(module, Conv1D): + module_shape = module.weight.ds_shape if hasattr(module.weight, "ds_shape") else module.weight.shape + module_shape = module_shape[::-1] + else: + continue + + if largest_shape is None: + largest_shape = module_shape + continue + + if module_shape != largest_shape: + largest_shape = tuple(max(a, b) for a, b in zip(largest_shape, module_shape)) + + if largest_shape is None: + msg = "No layers types compatible with RandLora were found. Please check `peft_config.target_modules`." + raise ValueError(msg) + + return largest_shape + + def _init_randlora_A_randlora_B_sparse(self, config: RandLoraConfig, adapter_name: str, sparsity: int = 3) -> None: + """ + Sparse random projections as described in https://cs-people.bu.edu/evimaria/cs565/kdd-rp.pdf + """ + + linear_out_dim, linear_in_dim = self._find_dim(config) + max_dim, min_dim = max(linear_out_dim, linear_in_dim), min(linear_out_dim, linear_in_dim) + + # use of persistent to exclude randlora_A and randlora_B from the state dict if we choose not to save them. 
+ self.randlora_A = BufferDict({}, persistent=config.save_projection) + self.randlora_B = BufferDict({}, persistent=config.save_projection) + + # deterministic init of randlora_A and randlora_B if we know the key + generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key) + + # The gamma matrix is applied on A meaning it can be unique (shared) accross the n scaling matrices. + # We also set randlora_A as the smallest matrix to reduce trainable parameters. + randlora_A = torch.rand((config.r, 1, min_dim), generator=generator) + + # Number of bases to ensure full rank + num_bases = min_dim / config.r + num_bases = int(num_bases) if num_bases.is_integer() else int(num_bases) + 1 # Ensure full rank + randlora_B = torch.rand((max_dim, num_bases, config.r), generator=generator) + + # The current implementation is a proof of concept and does take into consideration + # the sparsity to reduce memory usage or speed up compute + randlora_B_sparse = torch.zeros(randlora_B.shape) + randlora_A_sparse = torch.zeros(randlora_A.shape) + randlora_B_sparse[randlora_B < 1 / (2 * sparsity)] = -1 + randlora_B_sparse[randlora_B > 1 - 1 / (2 * sparsity)] = 1 + randlora_A_sparse[randlora_A < 1 / (2 * sparsity)] = -1 + randlora_A_sparse[randlora_A > 1 - 1 / (2 * sparsity)] = 1 + + # Std normalization is empirically found to be the best + randlora_A, randlora_B = ( + randlora_A_sparse / randlora_A_sparse.std(), + randlora_B_sparse / randlora_B_sparse.std(), + ) + self.randlora_A[adapter_name] = randlora_A + self.randlora_B[adapter_name] = randlora_B + + def _init_randlora_A_randlora_B(self, config: RandLoraConfig, adapter_name: str) -> None: + linear_out_dim, linear_in_dim = self._find_dim(config) + max_dim, min_dim = max(linear_out_dim, linear_in_dim), min(linear_out_dim, linear_in_dim) + + # use of persistent to exclude randlora_A and randlora_B from the state dict if we choose not to save them. + self.randlora_A = BufferDict({}, persistent=config.save_projection) + self.randlora_B = BufferDict({}, persistent=config.save_projection) + + # deterministic init of randlora_A and randlora_B if we know the key + generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key) + + # The gamma matrix is applied on A meaning it can be unique (shared) accross the n scaling matrices. + # We also set randlora_A as the smallest matrix to reduce trainable parameters. 
+ randlora_A = _kaiming_init((config.r, 1, min_dim), generator=generator) + + # Ensure full rank + num_bases = min(linear_out_dim, linear_in_dim) / config.r + num_bases = int(num_bases) if num_bases.is_integer() else int(num_bases) + 1 + randlora_B = torch.cat( + [_kaiming_init((max_dim, 1, config.r), generator=generator) for _ in range(num_bases)], dim=1 + ) + + # Std normalization is empirically found to be the best + randlora_A, randlora_B = randlora_A / randlora_A.std(), randlora_B / randlora_B.std() + self.randlora_A[adapter_name] = randlora_A + self.randlora_B[adapter_name] = randlora_B + + def _pre_injection_hook(self, model: nn.Module, config: RandLoraConfig, adapter_name: str) -> None: + if config.very_sparse: + linear_out_dim, linear_in_dim = self._find_dim(config) + self._init_randlora_A_randlora_B_sparse( + config, adapter_name, sparsity=math.sqrt(min(linear_out_dim, linear_in_dim)) + ) + elif config.sparse: + self._init_randlora_A_randlora_B_sparse(config, adapter_name, sparsity=3) + else: + self._init_randlora_A_randlora_B(config, adapter_name) + + def _check_new_adapter_config(self, config: RandLoraConfig) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + # the below todo is copied from LoRA + # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check + # does not fully correspond to the error message. + if (len(self.peft_config) > 1) and (config.bias != "none"): + raise ValueError( + f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, " + "set bias to 'none' for all adapters." + ) + + for existing_config in self.peft_config.values(): + if existing_config is config: + # skip the current config + continue + + if existing_config.projection_prng_key != config.projection_prng_key: + raise ValueError( + f"RandLora PRNG initialisation key must be the same for all adapters. Got {config.projection_prng_key=} but " + f"previous config had {existing_config.projection_prng_key}." 
+ ) + + save_project_unique_values = sorted({config.save_projection for config in self.peft_config.values()}) + if len(save_project_unique_values) > 1: + raise ValueError( + "RandLora projection weights must be saved for all adapters or none, but got multiple different values: " + f"{save_project_unique_values}" + ) + + @staticmethod + def _check_target_module_exists(randlora_config, key): + return check_target_module_exists(randlora_config, key) + + def _create_and_replace( + self, + randlora_config, + adapter_name, + target, + target_name, + parent, + current_key, + **optional_kwargs, + ): + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + + r = randlora_config.r + bias = hasattr(target, "bias") and target.bias is not None + kwargs = { + "r": r, + "randlora_alpha": randlora_config.randlora_alpha, + "randlora_dropout": randlora_config.randlora_dropout, + "fan_in_fan_out": randlora_config.fan_in_fan_out, + "init_weights": randlora_config.init_weights, + "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), + "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), + } + kwargs["bias"] = bias + if isinstance(target, Linear): + target.update_layer( + adapter_name, + self.randlora_A, + self.randlora_B, + r, + randlora_config.randlora_alpha, + randlora_config.randlora_dropout, + randlora_config.init_weights, + ) + else: + new_module = self._create_new_module( + randlora_config, self.randlora_A, self.randlora_B, adapter_name, target, **kwargs + ) + if adapter_name not in self.active_adapter: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + @staticmethod + def _replace_module(parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + meta = torch.device("meta") + # dispatch to correct device + for name, module in new_module.named_modules(): + if "randlora_" in name: + if not any(p.device == meta for p in module.parameters()): + module.to(child.weight.device) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + + for active_adapter in self.active_adapters: + bias = self.peft_config[active_adapter].bias + if bias == "none": + continue + + if bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "randlora_only": + for m in model.modules(): + if isinstance(m, RandLoraLayer) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") + + @staticmethod + def _create_new_module(randlora_config, randlora_A, randlora_B, adapter_name, target, **kwargs): + # avoid eager bnb import + if is_bnb_available(): + import bitsandbytes as bnb + + from .bnb import Linear8bitLt + 
+ if is_bnb_4bit_available(): + from .bnb import Linear4bit + + bias = kwargs.pop("bias", False) + loaded_in_8bit = kwargs.get("loaded_in_8bit", False) + loaded_in_4bit = kwargs.get("loaded_in_4bit", False) + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): + eightbit_kwargs = kwargs.copy() + eightbit_kwargs.update( + { + "has_fp16_weights": target_base_layer.state.has_fp16_weights, + "threshold": target_base_layer.state.threshold, + "index": target_base_layer.index, + } + ) + return Linear8bitLt(target, adapter_name, randlora_A, randlora_B, **eightbit_kwargs) + elif loaded_in_4bit and isinstance(target_base_layer, bnb.nn.Linear4bit): + fourbit_kwargs = kwargs.copy() + fourbit_kwargs.update( + { + "compute_dtype": target_base_layer.compute_dtype, + "compress_statistics": target_base_layer.weight.compress_statistics, + "quant_type": target_base_layer.weight.quant_type, + } + ) + return Linear4bit(target, adapter_name, randlora_A, randlora_B, **fourbit_kwargs) + elif isinstance(target_base_layer, torch.nn.Linear): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." + ) + kwargs["fan_in_fan_out"] = randlora_config.fan_in_fan_out = False + elif isinstance(target_base_layer, Conv1D): + kwargs["is_target_conv_1d_layer"] = True + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True." + ) + kwargs["fan_in_fan_out"] = randlora_config.fan_in_fan_out = True + else: + raise ValueError( + f"Target module {target} is not supported. Currently, only the following modules are supported: " + "`torch.nn.Linear`, `transformers.pytorch_utils.Conv1D`." + ) + new_module = Linear( + target, + randlora_A, + randlora_B, + adapter_name, + bias=bias, + **kwargs, + ) + + return new_module + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + if name == "model": # see #1892: prevent infinite recursion if class is not initialized + raise + return getattr(self.model, name) + + def get_peft_config_as_dict(self, inference: bool = False): + config_dict = {} + for key, value in self.peft_config.items(): + config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} + if inference: + config["inference_mode"] = True + config_dict[key] = config + return config + + def _set_adapter_layers(self, enabled=True): + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self): + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self): + for active_adapter in self.active_adapters: + val = self.peft_config[active_adapter].bias + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." + ) + warnings.warn(msg) + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name): + for module in self.model.modules(): + if isinstance(module, RandLoraLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. 
Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, + ): + # we cannot use self.prefix as we want to include non-trainable randlora parameters + key_list = [key for key, _ in self.model.named_modules() if "randlora" not in key] + desc = "Unloading " + ("and merging " if merge else "") + "model" + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + setattr(parent, target_name, target.modules_to_save[target.active_adapter]) + + return self.model + + def delete_adapter(self, adapter_name: str): + """ + Deletes an existing adapter. + + Args: + adapter_name (str): Name of the adapter to be deleted. + """ + if adapter_name not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + # we cannot use self.prefix as we want to include non-trainable randlora parameters + key_list = [key for key, _ in self.model.named_modules() if "randlora" not in key] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, RandLoraLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapter[:] + + self.active_adapter = new_adapter or [] + self._delete_auxiliary_adapter(adapter_name, new_active_adapters=new_adapter) + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ): + r""" + This method merges the RandLora layers into the base model. This is needed if someone wants to use the base + model as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. 
+ + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import PeftModel + + >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b") + >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample" + >>> model = PeftModel.from_pretrained(base_model, peft_model_id) + >>> merged_model = model.merge_and_unload() + ``` + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self): + """ + Gets back the base model by removing all the RandLora modules without merging. This gives back the original + base model. + """ + return self._unload_and_optionally_merge(merge=False) diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 97474caad6..93343f0575 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -25,6 +25,7 @@ TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, WEIGHTS_NAME, @@ -64,6 +65,7 @@ "TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING", + "TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING", "WEIGHTS_NAME", diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index e04b1ea1a2..0936d95110 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -291,6 +291,10 @@ def starcoder_model_postprocess_past_key_value(past_key_values): "qwen2": ["q_proj", "v_proj"], } +TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING = ( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING # Leaving this for now but RandLoRA is flexible +) + WEIGHTS_NAME = "adapter_model.bin" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" CONFIG_NAME = "adapter_config.json" diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index bc352b6006..86669ffcaa 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -44,6 +44,7 @@ TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, WEIGHTS_NAME, @@ -71,6 +72,7 @@ "TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING", + "TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING", "WEIGHTS_NAME", diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index a205bd4550..023fbaed78 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -40,6 +40,7 @@ class PeftType(str, enum.Enum): - FOURIERFT - HRA - BONE + - RANDLORA """ PROMPT_TUNING = "PROMPT_TUNING" @@ -63,6 +64,7 @@ class PeftType(str, enum.Enum): VBLORA = "VBLORA" CPT = "CPT" BONE = "BONE" + RANDLORA = "RANDLORA" 
TRAINABLE_TOKENS = "TRAINABLE_TOKENS" diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py index 6d10f3bd27..b68510f51c 100644 --- a/tests/test_common_gpu.py +++ b/tests/test_common_gpu.py @@ -46,6 +46,7 @@ LoraConfig, OFTConfig, PeftModel, + RandLoraConfig, TaskType, VBLoRAConfig, VeraConfig, @@ -70,11 +71,13 @@ from peft.tuners.ia3 import Linear8bitLt as IA3Linear8bitLt from peft.tuners.lora import Linear8bitLt as LoraLinear8bitLt + from peft.tuners.randlora import Linear8bitLt as RandLoraLinear8bitLt from peft.tuners.vera import Linear8bitLt as VeraLinear8bitLt if is_bnb_4bit_available(): from peft.tuners.ia3 import Linear4bit as IA3Linear4bit from peft.tuners.lora import Linear4bit as LoraLinear4bit + from peft.tuners.randlora import Linear4bit as RandLoraLinear4bit from peft.tuners.vera import Linear4bit as VeraLinear4bit @@ -199,6 +202,54 @@ def test_vera_bnb_8bit_quantization(self): whisper_8bit = get_peft_model(whisper_8bit, config) assert isinstance(whisper_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, VeraLinear8bitLt) + @require_bitsandbytes + @pytest.mark.multi_gpu_tests + @pytest.mark.single_gpu_tests + def test_randlora_bnb_8bit_quantization(self): + r""" + Test that tests if the 8bit quantization using RandLora works as expected + """ + whisper_8bit = WhisperForConditionalGeneration.from_pretrained( + self.audio_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + ) + + opt_8bit = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + ) + + flan_8bit = AutoModelForSeq2SeqLM.from_pretrained( + self.seq2seq_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + ) + + flan_randlora_config = RandLoraConfig( + r=16, target_modules=["q", "v"], randlora_dropout=0.05, bias="none", task_type="SEQ_2_SEQ_LM" + ) + + opt_randlora_config = RandLoraConfig( + r=10, + target_modules=["q_proj", "v_proj"], + randlora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + config = RandLoraConfig(r=5, target_modules=["q_proj", "v_proj"], randlora_dropout=0.05, bias="none") + + flan_8bit = get_peft_model(flan_8bit, flan_randlora_config) + assert isinstance(flan_8bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, RandLoraLinear8bitLt) + + opt_8bit = get_peft_model(opt_8bit, opt_randlora_config) + assert isinstance(opt_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RandLoraLinear8bitLt) + + whisper_8bit = get_peft_model(whisper_8bit, config) + assert isinstance(whisper_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RandLoraLinear8bitLt) + @require_bitsandbytes @pytest.mark.multi_gpu_tests @pytest.mark.single_gpu_tests @@ -347,6 +398,43 @@ def test_vera_bnb_quantization_from_pretrained_safetensors(self, quantization): assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.vera_A assert "adapter2" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.vera_A + @require_bitsandbytes + @pytest.mark.multi_gpu_tests + @pytest.mark.single_gpu_tests + @parameterized.expand(["4bit", "8bit"]) + def test_randlora_bnb_quantization_from_pretrained_safetensors(self, quantization): + r""" + Tests that the bnb quantization using RandLora works as expected with safetensors weights. 
+ """ + model_id = "facebook/opt-350m" + kwargs = {"device_map": "auto"} + if quantization == "4bit": + kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True) + else: + kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + + model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs) + config = RandLoraConfig(task_type=TaskType.CAUSAL_LM) + peft_model = get_peft_model(model, config) + peft_model = prepare_model_for_kbit_training(peft_model) + peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0)) + + with tempfile.TemporaryDirectory() as tmp_dir: + peft_model.save_pretrained(tmp_dir) + model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs) + model = PeftModel.from_pretrained(model, tmp_dir) + model = prepare_model_for_kbit_training(model) + model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0)) + + # loading a 2nd adapter works, #1239 + model.load_adapter(tmp_dir, "adapter2") + model.set_adapter("adapter2") + model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0)) + + # check that both adapters are in the same layer + assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.randlora_A + assert "adapter2" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.randlora_A + @require_bitsandbytes @pytest.mark.multi_gpu_tests @pytest.mark.single_gpu_tests @@ -519,6 +607,54 @@ def test_vera_bnb_4bit_quantization(self): whisper_4bit = get_peft_model(whisper_4bit, config) assert isinstance(whisper_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, VeraLinear4bit) + @require_bitsandbytes + @pytest.mark.multi_gpu_tests + @pytest.mark.single_gpu_tests + def test_randlora_bnb_4bit_quantization(self): + r""" + Test that tests if the 4bit quantization using RandLoRA works as expected + """ + whisper_4bit = WhisperForConditionalGeneration.from_pretrained( + self.audio_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + ) + + opt_4bit = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + ) + + flan_4bit = AutoModelForSeq2SeqLM.from_pretrained( + self.seq2seq_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + ) + + flan_randlora_config = RandLoraConfig( + r=16, target_modules=["q", "v"], randlora_dropout=0.05, bias="none", task_type="SEQ_2_SEQ_LM" + ) + + opt_randlora_config = RandLoraConfig( + r=16, + target_modules=["q_proj", "v_proj"], + randlora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + config = RandLoraConfig(r=32, target_modules=["q_proj", "v_proj"], randlora_dropout=0.05, bias="none") + + flan_4bit = get_peft_model(flan_4bit, flan_randlora_config) + assert isinstance(flan_4bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, RandLoraLinear4bit) + + opt_4bit = get_peft_model(opt_4bit, opt_randlora_config) + assert isinstance(opt_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RandLoraLinear4bit) + + whisper_4bit = get_peft_model(whisper_4bit, config) + assert isinstance(whisper_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RandLoraLinear4bit) + @require_bitsandbytes @pytest.mark.multi_gpu_tests @pytest.mark.single_gpu_tests @@ -1787,6 +1923,28 @@ def test_vera_add_new_adapter_does_not_change_device(self, mlp): assert model.lin0.vera_A.other.device.type == self.device assert model.lin0.vera_lambda_d.other.device.type == 
self.device + def test_randlora_add_new_adapter_does_not_change_device(self, mlp): + # same as first test, but using RandLora + config = RandLoraConfig(target_modules=["lin0"]) + model = get_peft_model(mlp, config) + model = model.to(self.device) + model.lin0.randlora_A.cpu() + model.lin0.randlora_lambda.cpu() + + # check that the adapter is indeed on CPU and the base model on GPU + assert model.lin0.randlora_A.default.device.type == "cpu" + assert model.lin0.randlora_lambda.default.device.type == "cpu" + assert model.lin0.base_layer.weight.device.type == self.device + + model.add_adapter("other", config) + # check that after adding a new adapter, the old adapter is still on CPU + assert model.lin0.randlora_A.default.device.type == "cpu" + assert model.lin0.randlora_lambda.default.device.type == "cpu" + # the rest should be on GPU + assert model.lin0.base_layer.weight.device.type == self.device + assert model.lin0.randlora_A.other.device.type == self.device + assert model.lin0.randlora_lambda.other.device.type == self.device + def test_vblora_add_new_adapter_does_not_change_device(self, mlp): # same as first test, but using VBLoRA config = VBLoRAConfig(target_modules=["lin0"], vector_length=2) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 392a4ef708..ebe3f09847 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -46,6 +46,7 @@ LoraConfig, OFTConfig, PeftModel, + RandLoraConfig, TaskType, TrainableTokensConfig, VBLoRAConfig, @@ -511,6 +512,32 @@ TrainableTokensConfig, {"target_modules": ["emb"], "token_indices": [0, 1, 3], "init_weights": False}, ), + ######## + # RandLora # + ######## + # We have to reduce the default scaling parameter to avoid nans when using large learning rates + ("Vanilla MLP 1 RandLora", "MLP", RandLoraConfig, {"target_modules": "lin0", "randlora_alpha": 1}), + ("Vanilla MLP 2 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0"], "randlora_alpha": 1}), + ("Vanilla MLP 3 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin1"], "randlora_alpha": 1}), + ("Vanilla MLP 4 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"], "randlora_alpha": 1}), + ( + "Vanilla MLP 5 RandLora", + "MLP", + RandLoraConfig, + {"target_modules": ["lin0", "lin1"], "sparse": True, "randlora_alpha": 1}, + ), + ( + "Vanilla MLP 6 RandLora", + "MLP", + RandLoraConfig, + {"target_modules": ["lin0", "lin1"], "very_sparse": True, "randlora_alpha": 1}, + ), + ( + "Vanilla MLP 7 RandLora", + "MLP", + RandLoraConfig, + {"target_modules": ["lin0"], "modules_to_save": ["lin1"], "randlora_alpha": 1}, + ), ] # For this test matrix, each tuple consists of: @@ -617,6 +644,14 @@ {"target_modules": ["lin0"], "init_weights": False}, {"target_modules": ["lin0"], "init_weights": False}, ), + # Note: RandLora may present the same problem mentioned above for Vera. 
+ ( + "RandLora Same", + "randlora", + RandLoraConfig, + {"target_modules": ["lin0"], "init_weights": False}, + {"target_modules": ["lin0"], "init_weights": False}, + ), ( "HRA Same", "hra", @@ -684,6 +719,7 @@ BOFTConfig: "boft_", LNTuningConfig: "ln_tuning_", VeraConfig: "vera_lambda_", + RandLoraConfig: "randlora_", FourierFTConfig: "fourierft_", HRAConfig: "hra_", VBLoRAConfig: "vblora_", @@ -1440,7 +1476,7 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c lr = 0.1 # otherwise we get nan elif "mha" in model_id.lower(): lr = 1e-3 # we get exploding gradients with MHA when learning rate is too high - elif issubclass(config_cls, VBLoRAConfig): + elif issubclass(config_cls, VBLoRAConfig) or issubclass(config_cls, RandLoraConfig): lr = 0.01 # otherwise we get nan optimizer = torch.optim.SGD(model.parameters(), lr=lr) @@ -3886,6 +3922,148 @@ def forward(self, X): "base_model.model.lin2.vera_lambda_d.adapter1", ) + def test_requires_grad_randlora_different_targets(self): + # Test two different RandLora adapters that target different modules. Most notably, ensure that randbasis_A and randbasis_B + # don't require grads. + + # requires a model with at least 2 layers with the same shapes + class MLP2(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.relu = nn.ReLU() + self.lin0 = nn.Linear(10, 20, bias=bias) + self.lin1 = nn.Linear(20, 20, bias=bias) # lin1 and lin2 have same shape + self.lin2 = nn.Linear(20, 20, bias=bias) + self.lin3 = nn.Linear(20, 2, bias=bias) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = X.float() + X = self.lin0(X) + X = self.relu(X) + X = self.lin1(X) + X = self.relu(X) + X = self.lin2(X) + X = self.relu(X) + X = self.lin3(X) + X = self.sm(X) + return X + + config0 = RandLoraConfig(target_modules=["lin1"]) + peft_model = get_peft_model(MLP2(), config0) + + config1 = RandLoraConfig(target_modules=["lin2"]) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin1.randlora_lambda.default", + "base_model.model.lin1.randlora_gamma.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.randlora_lambda.default", + "base_model.model.lin1.randlora_gamma.default", + ) + + # change activate adapter to adapter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin2.randlora_lambda.adapter1", + "base_model.model.lin2.randlora_gamma.adapter1", + ) + + # disable all adapters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + self.check_requires_grad( + peft_model, + "base_model.model.lin2.randlora_lambda.adapter1", + "base_model.model.lin2.randlora_gamma.adapter1", + ) + + def test_requires_grad_randlora_same_targets(self): + # Test two different RandLora adapters that target the same module. Most notably, ensure that randbasis_A and randbasis_B + # don't require grads. 
+ + # requires a model with at least 2 layers with the same shapes + class MLP2(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.relu = nn.ReLU() + self.lin0 = nn.Linear(10, 20, bias=bias) + self.lin1 = nn.Linear(20, 20, bias=bias) # lin1 and lin2 have same shape + self.lin2 = nn.Linear(20, 20, bias=bias) + self.lin3 = nn.Linear(20, 2, bias=bias) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = X.float() + X = self.lin0(X) + X = self.relu(X) + X = self.lin1(X) + X = self.relu(X) + X = self.lin2(X) + X = self.relu(X) + X = self.lin3(X) + X = self.sm(X) + return X + + config0 = RandLoraConfig(target_modules=["lin1", "lin2"]) + peft_model = get_peft_model(MLP2(), config0) + + config1 = RandLoraConfig(target_modules=["lin1", "lin2"]) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin1.randlora_lambda.default", + "base_model.model.lin1.randlora_gamma.default", + "base_model.model.lin2.randlora_lambda.default", + "base_model.model.lin2.randlora_gamma.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.randlora_lambda.default", + "base_model.model.lin1.randlora_gamma.default", + "base_model.model.lin2.randlora_lambda.default", + "base_model.model.lin2.randlora_gamma.default", + ) + + # change activate adapter to adapter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.randlora_lambda.adapter1", + "base_model.model.lin1.randlora_gamma.adapter1", + "base_model.model.lin2.randlora_lambda.adapter1", + "base_model.model.lin2.randlora_gamma.adapter1", + ) + + # disable all adapters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + self.check_requires_grad( + peft_model, + "base_model.model.lin1.randlora_lambda.adapter1", + "base_model.model.lin1.randlora_gamma.adapter1", + "base_model.model.lin2.randlora_lambda.adapter1", + "base_model.model.lin2.randlora_gamma.adapter1", + ) + def test_requires_grad_vblora_different_targets(self): # test two different VBLoRA adapters that target different modules config0 = VBLoRAConfig(target_modules=["lin0"], vector_length=1, num_vectors=2) diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py index 2bffb935ef..2a4c5b88bd 100644 --- a/tests/test_feature_extraction_models.py +++ b/tests/test_feature_extraction_models.py @@ -44,10 +44,10 @@ def skip_non_prompt_tuning(test_list): def skip_deberta_lora_tests(test_list): r""" - Skip tests that are checkpointing with lora/ia3/boft/vera/fourierft for Deberta models (couldn't find much info on - the error) + Skip tests that are checkpointing with lora/ia3/boft/vera/randlora/fourierft for Deberta models (couldn't find much + info on the error) """ - to_skip = ["lora", "ia3", "boft", "vera", "fourierft", "hra", "bone"] + to_skip = ["lora", "ia3", "boft", "vera", "fourierft", "hra", "bone", "randlora"] return [test for test in test_list if not (any(k in test[0] for k in to_skip) and "Deberta" in test[0])] @@ -116,6 +116,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c "boft_kwargs": {"init_weights": [False]}, "oft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, + "randlora_kwargs": {"init_weights": [False]}, "hra_kwargs": 
{"init_weights": [False]}, "bone_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", @@ -171,6 +172,7 @@ def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_k "boft_kwargs": {"init_weights": [False]}, "oft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, + "randlora_kwargs": {"init_weights": [False]}, "hra_kwargs": {"init_weights": [False]}, "bone_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index f591d52826..4e81e2c6a8 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -60,6 +60,7 @@ PeftModel, PrefixTuningConfig, PromptEncoderConfig, + RandLoraConfig, TaskType, VeraConfig, get_peft_model, @@ -1391,6 +1392,232 @@ def test_causal_lm_training_multi_gpu_4bit_vera(self): # assert loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None + @pytest.mark.single_gpu_tests + def test_causal_lm_training_8bit_randlora(self): + r""" + Same as test_causal_lm_training but with RandLora + """ + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + device_map="auto", + ) + + tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) + model = prepare_model_for_kbit_training(model) + + config = RandLoraConfig( + r=16, + target_modules=["q_proj", "v_proj"], + randlora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("ybelkada/english_quotes_copy") + data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.single_gpu_tests + def test_causal_lm_training_4bit_randlora(self): + r""" + Same as test_causal_lm_training_4bit but with RandLora + """ + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + device_map="auto", + ) + + tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) + model = prepare_model_for_kbit_training(model) + + config = RandLoraConfig( + r=16, + target_modules=["q_proj", "v_proj"], + randlora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("ybelkada/english_quotes_copy") + data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(tokenizer, 
mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.multi_gpu_tests + def test_causal_lm_training_multi_gpu_8bit_randlora(self): + r""" + Same as test_causal_lm_training_multi_gpu but with RandLoRA + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + ) + + assert set(model.hf_device_map.values()) == set(range(device_count)) + + model = prepare_model_for_kbit_training(model) + + setattr(model, "model_parallel", True) + setattr(model, "is_parallelizable", True) + + config = RandLoraConfig( + r=16, + target_modules=["q_proj", "v_proj"], + randlora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("Abirate/english_quotes") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.multi_gpu_tests + def test_causal_lm_training_multi_gpu_4bit_randlora(self): + r""" + Same as test_causal_lm_training_multi_gpu_4bit but with RandLora + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + ) + + assert set(model.hf_device_map.values()) == set(range(device_count)) + + model = prepare_model_for_kbit_training(model) + + setattr(model, "model_parallel", True) + setattr(model, "is_parallelizable", True) + + config = RandLoraConfig( + r=16, + target_modules=["q_proj", "v_proj"], + randlora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("Abirate/english_quotes") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + @pytest.mark.single_gpu_tests def 
test_causal_lm_training_lora_resize_embeddings_trainable_tokens(self): r""" diff --git a/tests/test_randlora.py b/tests/test_randlora.py new file mode 100644 index 0000000000..b553177ea9 --- /dev/null +++ b/tests/test_randlora.py @@ -0,0 +1,303 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This test file is for tests specific to RandLora, since Randlora has some specific challenges due to the shared weights. +# These tests are copied from the test_vera.py file + +import os + +import pytest +import torch +from safetensors import safe_open +from torch import nn + +from peft import PeftModel, RandLoraConfig, get_peft_model +from peft.utils import infer_device + + +class MLP(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.relu = nn.ReLU() + self.lin0 = nn.Linear(10, 20, bias=bias) + self.lin1 = nn.Linear(20, 20, bias=bias) # lin1 and lin2 have same shape + self.lin2 = nn.Linear(20, 20, bias=bias) + self.lin3 = nn.Linear(20, 2, bias=bias) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = self.lin0(X) + X = self.relu(X) + X = self.lin1(X) + X = self.relu(X) + X = self.lin2(X) + X = self.relu(X) + X = self.lin3(X) + X = self.sm(X) + return X + + +# Tests copied from the TestVera class in test_vera.py. +# Changes to the code file should be reflected here. 
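+# The tests below exercise RandLora's weight sharing: the random bases randlora_A /
+# randlora_B are expected to be shared across layers and across adapters created with
+# the same projection_prng_key, while randlora_lambda / randlora_gamma remain
+# per-adapter trainable parameters.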
+class TestRandLora: + @pytest.fixture + def mlp(self): + torch.manual_seed(0) + model = MLP() + return model + + @pytest.fixture + def mlp_same_prng(self, mlp): + torch.manual_seed(0) + + config = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False) + # creates a default RandLora adapter + peft_model = get_peft_model(mlp, config) + config2 = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False) + peft_model.add_adapter("other", config2) + return peft_model + + def test_multiple_adapters_same_prng_weights(self, mlp_same_prng): + # we can have multiple adapters with the same prng key, in which case the weights should be shared + assert ( + mlp_same_prng.base_model.model.lin1.randlora_A["default"] + is mlp_same_prng.base_model.model.lin1.randlora_A["other"] + ) + assert ( + mlp_same_prng.base_model.model.lin1.randlora_B["default"] + is mlp_same_prng.base_model.model.lin1.randlora_B["other"] + ) + assert ( + mlp_same_prng.base_model.model.lin2.randlora_A["default"] + is mlp_same_prng.base_model.model.lin2.randlora_A["other"] + ) + assert ( + mlp_same_prng.base_model.model.lin2.randlora_B["default"] + is mlp_same_prng.base_model.model.lin2.randlora_B["other"] + ) + + input = torch.randn(5, 10) + mlp_same_prng.set_adapter("default") + output_default = mlp_same_prng(input) + mlp_same_prng.set_adapter("other") + output_other = mlp_same_prng(input) + assert not torch.allclose(output_default, output_other, atol=1e-3, rtol=1e-3) + + def test_multiple_adapters_different_prng_raises(self): + # we cannot have multiple adapters with different prng keys + model = MLP() + config = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False) + # creates a default RandLora adapter + peft_model = get_peft_model(model, config) + config2 = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False, projection_prng_key=123) + + msg = ( + r"RandLora PRNG initialisation key must be the same for all adapters. 
Got config.projection_prng_key=123 but "
+            r"previous config had 0"
+        )
+        with pytest.raises(ValueError, match=msg):
+            peft_model.add_adapter("other", config2)
+
+    def test_multiple_adapters_save_load_save_projection_true(self, mlp_same_prng, tmp_path):
+        # check saving and loading works with multiple adapters and saved projection weights
+        torch.manual_seed(0)
+        input = torch.randn(5, 10)
+        mlp_same_prng.set_adapter("default")
+        output_default = mlp_same_prng(input)
+        mlp_same_prng.set_adapter("other")
+        output_other = mlp_same_prng(input)
+
+        # sanity check
+        assert not torch.allclose(output_default, output_other, atol=1e-3, rtol=1e-3)
+
+        save_path = tmp_path / "randlora"
+        mlp_same_prng.save_pretrained(save_path)
+        assert os.path.exists(save_path / "adapter_config.json")
+        assert os.path.exists(save_path / "other" / "adapter_config.json")
+
+        torch.manual_seed(0)
+        mlp = MLP()
+        peft_model = PeftModel.from_pretrained(mlp, save_path)
+        peft_model.load_adapter(save_path / "other", "other")
+
+        peft_model.set_adapter("default")
+        output_default_loaded = peft_model(input)
+        peft_model.set_adapter("other")
+        output_other_loaded = peft_model(input)
+
+        assert torch.allclose(output_default, output_default_loaded, atol=1e-3, rtol=1e-3)
+        assert torch.allclose(output_other, output_other_loaded, atol=1e-3, rtol=1e-3)
+
+    def test_multiple_adapters_save_load_save_projection_false(self, mlp, tmp_path):
+        # check saving and loading works with multiple adapters without saved projection weights
+        torch.manual_seed(1)
+        config = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False, save_projection=False)
+        # creates a default RandLora adapter
+        peft_model = get_peft_model(mlp, config, adapter_name="first")
+        config2 = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False, save_projection=False)
+        peft_model.add_adapter("second", config2)
+
+        input = torch.randn(5, 10)
+        peft_model.set_adapter("first")
+        output_first = peft_model(input)
+        peft_model.set_adapter("second")
+        output_second = peft_model(input)
+
+        # sanity check
+        assert not torch.allclose(output_first, output_second, atol=1e-3, rtol=1e-3)
+
+        save_path = tmp_path / "randlora"
+        peft_model.save_pretrained(save_path)
+        assert os.path.exists(save_path / "first" / "adapter_config.json")
+        assert os.path.exists(save_path / "second" / "adapter_config.json")
+
+        torch.manual_seed(0)
+        mlp = MLP()
+        peft_model = PeftModel.from_pretrained(mlp, save_path / "first", adapter_name="first")
+        peft_model.load_adapter(save_path / "second", "second")
+
+        peft_model.set_adapter("first")
+        output_first_loaded = peft_model(input)
+        peft_model.set_adapter("second")
+        output_second_loaded = peft_model(input)
+
+        assert torch.allclose(output_first, output_first_loaded, atol=1e-3, rtol=1e-3)
+        assert torch.allclose(output_second, output_second_loaded, atol=1e-3, rtol=1e-3)
+
+    def test_multiple_adapters_save_projection_true_contains_randlora_A_randlora_B(self, mlp_same_prng, tmp_path):
+        # check that the state_dicts do contain the projection weights
+        save_path = tmp_path / "randlora"
+        mlp_same_prng.save_pretrained(save_path)
+
+        sd_default = {}
+        with safe_open(save_path / "adapter_model.safetensors", framework="pt", device="cpu") as f:
+            for key in f.keys():
+                sd_default[key] = f.get_tensor(key)
+
+        assert any("randlora_A" in key for key in sd_default)
+        assert any("randlora_B" in key for key in sd_default)
+        # default rank for RandLora is 32
+        assert sd_default["base_model.randlora_A"].shape == (32, 1, 20)
+        assert 
sd_default["base_model.randlora_B"].shape == (20, 1, 32) + + sd_other = {} + with safe_open(save_path / "other" / "adapter_model.safetensors", framework="pt", device="cpu") as f: + for key in f.keys(): + sd_other[key] = f.get_tensor(key) + + assert any("randlora_A" in key for key in sd_other) + assert any("randlora_B" in key for key in sd_other) + assert sd_other["base_model.randlora_A"].shape == (32, 1, 20) + assert sd_other["base_model.randlora_B"].shape == (20, 1, 32) + + def test_multiple_adapters_save_projection_false_contains_no_randlora_A_randlora_B(self, mlp, tmp_path): + torch.manual_seed(1) + config = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False, save_projection=False) + # creates a default RandLora adapter + peft_model = get_peft_model(mlp, config, adapter_name="first") + config2 = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False, save_projection=False) + peft_model.add_adapter("second", config2) + + save_path = tmp_path / "randlora" + peft_model.save_pretrained(save_path) + + sd_default = {} + with safe_open(save_path / "first" / "adapter_model.safetensors", framework="pt", device="cpu") as f: + for key in f.keys(): + sd_default[key] = f.get_tensor(key) + + assert not any("randlora_A" in key for key in sd_default) + assert not any("randlora_B" in key for key in sd_default) + + sd_other = {} + with safe_open(save_path / "second" / "adapter_model.safetensors", framework="pt", device="cpu") as f: + for key in f.keys(): + sd_other[key] = f.get_tensor(key) + + assert not any("randlora_A" in key for key in sd_other) + assert not any("randlora_B" in key for key in sd_other) + + def test_randlora_A_randlora_B_share_memory(self, mlp_same_prng): + randlora_A = mlp_same_prng.randlora_A["default"] + randlora_B = mlp_same_prng.randlora_B["default"] + + # these tensors should share the same data + assert randlora_A.data_ptr() == mlp_same_prng.base_model.model.lin1.randlora_A["default"].data_ptr() + assert randlora_B.data_ptr() == mlp_same_prng.base_model.model.lin1.randlora_B["default"].data_ptr() + assert randlora_A.data_ptr() == mlp_same_prng.base_model.model.lin2.randlora_A["default"].data_ptr() + assert randlora_B.data_ptr() == mlp_same_prng.base_model.model.lin2.randlora_B["default"].data_ptr() + # sanity check: these tensors shouldn't share the same data + assert randlora_A.data_ptr() != randlora_B.data_ptr() + + def test_randlora_lambda_dont_share_memory(self, mlp_same_prng): + # sanity check: these tensors shouldn't share the same data + assert ( + mlp_same_prng.base_model.model.lin1.randlora_lambda["default"].data_ptr() + != mlp_same_prng.base_model.model.lin1.randlora_lambda["other"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.randlora_lambda["default"].data_ptr() + != mlp_same_prng.base_model.model.lin2.randlora_lambda["default"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.randlora_lambda["other"].data_ptr() + != mlp_same_prng.base_model.model.lin2.randlora_lambda["other"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.randlora_gamma["default"].data_ptr() + != mlp_same_prng.base_model.model.lin1.randlora_gamma["other"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.randlora_gamma["default"].data_ptr() + != mlp_same_prng.base_model.model.lin2.randlora_gamma["default"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.randlora_gamma["other"].data_ptr() + != mlp_same_prng.base_model.model.lin2.randlora_gamma["other"].data_ptr() + ) + + def 
test_randlora_different_shapes(self, mlp): + config = RandLoraConfig(target_modules=["lin0", "lin3"], init_weights=False) + mlp_different_shapes = get_peft_model(mlp, config) + + randlora_A = mlp_different_shapes.randlora_A["default"] + randlora_B = mlp_different_shapes.randlora_B["default"] + + # sanity check + assert mlp.lin0.base_layer.weight.shape != mlp.lin3.base_layer.weight.shape + + # lin0 has the largest output dimension, lin3 has the largest input dimension + # randlora_A should have the shape of (rank, largest_in), randlora_B should have the shape of (largest_out, rank) + assert randlora_A.shape == (config.r, 1, mlp.lin3.in_features) + assert randlora_B.shape == (mlp.lin0.out_features, 1, config.r) + + # should not raise + input = torch.randn(5, 10) + mlp_different_shapes(input) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_randlora_dtypes(self, dtype): + if dtype == torch.bfloat16: + # skip if bf16 is not supported on hardware, see #1872 + is_xpu = infer_device() == "xpu" + is_cuda_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported() + if not (is_xpu or is_cuda_bf16): + pytest.skip("bfloat16 not supported on this system, skipping the test") + + model = MLP().to(dtype) + config = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False) + peft_model = get_peft_model(model, config) + inputs = torch.randn(5, 10).to(dtype) + output = peft_model(inputs) # should not raise + assert output.dtype == dtype diff --git a/tests/testing_common.py b/tests/testing_common.py index dd21426f5d..8a1acc6daf 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -50,6 +50,7 @@ PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, + RandLoraConfig, VBLoRAConfig, VeraConfig, get_peft_model, @@ -142,6 +143,16 @@ "bias": "none", "trainable_token_indices": [0, 1, 3], }, + # RandLoRA + { + "r": 32, + "randlora_alpha": 64, + "target_modules": None, + "randlora_dropout": 0.05, + "projection_prng_key": 0xFF, + "save_projection": True, + "bias": "none", + }, # CPT tuninig { "cpt_token_ids": [0, 1, 2, 3, 4, 5, 6, 7], # Example token IDs for testing @@ -165,9 +176,10 @@ "oft": (OFTConfig, CONFIG_TESTING_KWARGS[11]), "bone": (BoneConfig, CONFIG_TESTING_KWARGS[12]), "lora+trainable_tokens": (LoraConfig, CONFIG_TESTING_KWARGS[13]), + "randlora": (RandLoraConfig, CONFIG_TESTING_KWARGS[14]), } -DECODER_MODELS_EXTRA = {"cpt": (CPTConfig, CONFIG_TESTING_KWARGS[14])} +DECODER_MODELS_EXTRA = {"cpt": (CPTConfig, CONFIG_TESTING_KWARGS[15])} # Adapted from https://github.com/huggingface/transformers/blob/48327c57182fdade7f7797d1eaad2d166de5c55b/src/transformers/activations.py#LL166C7-L166C22 @@ -482,7 +494,7 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs, safe_serial if issubclass(config_cls, IA3Config): config_kwargs = config_kwargs.copy() config_kwargs["init_ia3_weights"] = False - if issubclass(config_cls, VeraConfig): + if hasattr(config_cls, "init_weights"): config_kwargs = config_kwargs.copy() config_kwargs["init_weights"] = False @@ -1642,6 +1654,7 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs): "FOURIERFT", "HRA", "VBLORA", + "RANDLORA", "BONE", ): with pytest.raises(AttributeError):
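For reference, a minimal usage sketch of the RandLora API added in this diff; the model id and hyperparameter values below are illustrative assumptions rather than values taken from the changes above.

```python
# Minimal sketch: wrap a small causal LM with RandLora through the standard PEFT API.
# "facebook/opt-125m" and the hyperparameters are illustrative, not prescribed by this diff.
from transformers import AutoModelForCausalLM

from peft import RandLoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = RandLoraConfig(
    r=32,                                 # rank of the shared random bases (RandLora default)
    randlora_alpha=64,                    # scaling factor, analogous to lora_alpha
    target_modules=["q_proj", "v_proj"],  # attention projections, as in the tests above
    randlora_dropout=0.05,
    save_projection=True,                 # store randlora_A / randlora_B in the checkpoint
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()

# After training, the adapter can be merged back into the base weights:
merged_model = peft_model.merge_and_unload()
```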