From 649a35bf177c3ad9fd6b60bca43ffb46baf46de3 Mon Sep 17 00:00:00 2001 From: Paul Albert Date: Fri, 28 Mar 2025 05:17:45 +0000 Subject: [PATCH 01/13] randlora integration - more work to do to conform to quantization practices --- src/peft/__init__.py | 2 + src/peft/tuners/__init__.py | 4 +- src/peft/tuners/randlora/__init__.py | 40 ++ src/peft/tuners/randlora/bnb.py | 449 ++++++++++++++++++++++ src/peft/tuners/randlora/config.py | 179 +++++++++ src/peft/tuners/randlora/layer.py | 351 +++++++++++++++++ src/peft/tuners/randlora/model.py | 552 +++++++++++++++++++++++++++ src/peft/utils/__init__.py | 2 + src/peft/utils/constants.py | 2 + src/peft/utils/other.py | 1 + src/peft/utils/peft_types.py | 2 + 11 files changed, 1583 insertions(+), 1 deletion(-) create mode 100644 src/peft/tuners/randlora/__init__.py create mode 100644 src/peft/tuners/randlora/bnb.py create mode 100644 src/peft/tuners/randlora/config.py create mode 100644 src/peft/tuners/randlora/layer.py create mode 100644 src/peft/tuners/randlora/model.py diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 122e1c1049..d74fdaccd4 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -95,6 +95,8 @@ VeraModel, XLoraConfig, XLoraModel, + RandLoraConfig, + RandLoraModel, get_eva_state_dict, initialize_lora_eva_weights, ) diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index 65abbd4046..1d00b6fde4 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -43,7 +43,7 @@ from .vblora import VBLoRAConfig, VBLoRAModel from .vera import VeraConfig, VeraModel from .xlora import XLoraConfig, XLoraModel - +from .randlora import RandLoraConfig, RandLoraModel __all__ = [ "AdaLoraConfig", @@ -97,6 +97,8 @@ "VeraModel", "XLoraConfig", "XLoraModel", + "RandLoraConfig", + "RandLoraModel", "get_eva_state_dict", "initialize_lora_eva_weights", ] diff --git a/src/peft/tuners/randlora/__init__.py b/src/peft/tuners/randlora/__init__.py new file mode 100644 index 0000000000..f6d41dcc38 --- /dev/null +++ b/src/peft/tuners/randlora/__init__.py @@ -0,0 +1,40 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
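For orientation, the end-user API that these new modules expose is expected to be used roughly as in the sketch below, mirroring the example in the `RandLoraModel` docstring later in this patch (the model name and target modules are illustrative, not taken from this diff):

```py
from transformers import AutoModelForCausalLM

from peft import RandLoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
# r is the rank of the shared random bases; randlora_alpha is typically about 2 * r (see config.py below)
config = RandLoraConfig(r=32, randlora_alpha=64, target_modules=["q_proj", "v_proj"])
model = get_peft_model(base_model, config)
model.print_trainable_parameters()
```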
+ +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.utils import register_peft_method + +from .config import RandLoraConfig +from .layer import Linear, RandLoraLayer +from .model import RandLoraModel + + +__all__ = ["RandLoraConfig", "RandLoraLayer", "Linear", "RandLoraModel"] + +register_peft_method( + name="randlora", config_cls=RandLoraConfig, model_cls=RandLoraModel, prefix="randlora_" +) + +def __getattr__(name): + if (name == "Linear8bitLt") and is_bnb_available(): + from .bnb import Linear8bitLt + + return Linear8bitLt + + if (name == "Linear4bit") and is_bnb_4bit_available(): + from .bnb import Linear4bit + + return Linear4bit + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/peft/tuners/randlora/bnb.py b/src/peft/tuners/randlora/bnb.py new file mode 100644 index 0000000000..0e0ec7dd96 --- /dev/null +++ b/src/peft/tuners/randlora/bnb.py @@ -0,0 +1,449 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import warnings +from typing import Optional + +import bitsandbytes as bnb +import torch + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.tuners_utils import check_adapters_to_merge +from peft.utils.integrations import dequantize_bnb_weight +from peft.utils.other import transpose + +from .layer import RandLoraLayer, UniqueBaseGrad + + +if is_bnb_available(): + + class Linear8bitLt(torch.nn.Module, RandLoraLayer): + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + randlora_A, + randlora_B, + r: int = 0, + randlora_alpha: int = 0, + randlora_dropout: float = 0.0, + fan_in_fan_out: bool = False, + init_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + RandLoraLayer.__init__(self, base_layer) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + randlora_A, + randlora_B, + r, + randlora_alpha=randlora_alpha, + randlora_dropout=randlora_dropout, + init_weights=init_weights, + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + if self.merged: + warnings.warn( + f"Already following adapters were merged {','.join(self.merged_adapters)}. " + f"You are now additionally merging {','.join(self.active_adapters)}." + ) + + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + return + + for active_adapter in adapter_names: + if active_adapter not in self.randlora_lambda.keys(): + continue + + warnings.warn( + "Merge RandLora module to 8-bit linear may get different generations due to rounding errors." 
+ ) + randlora_data = self.get_delta_weight(active_adapter) + + weight = self.get_base_layer().weight + state = self.get_base_layer().state + if state.SCB is None: + state.SCB = weight.SCB + + output = dequantize_bnb_weight(weight, state) + w_data = output.to(randlora_data.dtype).to(randlora_data.device) + randlora_data + + if safe_merge and not torch.isfinite(w_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + self.get_base_layer().weight = bnb.nn.Int8Params( + w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights + ).to(weight.device) + state.reset_grads() + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + if not self.merged: + warnings.warn("Already unmerged. Nothing to do") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.randlora_lambda.keys(): + continue + warnings.warn( + "Unmerge randlora module to 8-bit linear may get different generations due to rounding errors." + ) + randlora_data = self.get_delta_weight(active_adapter) + + weight = self.get_base_layer().weight + state = self.get_base_layer().state + if state.SCB is None: + state.SCB = weight.SCB + output = dequantize_bnb_weight(weight, state=state) + + w_data = output.to(randlora_data.dtype).to(randlora_data.device) - randlora_data + + self.get_base_layer().weight = bnb.nn.Int8Params( + w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights + ).to(weight.device) + state.reset_grads() + + def get_scaled_bases(self, adapter) -> List[torch.Tensor, torch.Tensor, torch.dtype]: + """ + Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct order + to fit the target layers' dimensions + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + + randlora_A = self.randlora_A[adapter] + randlora_B = self.randlora_B[adapter] + + device = randlora_B.device + dtype = randlora_B.dtype + + # In case users wants to merge the adapter weights that are in + # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # (b)float16 because some CPUs have slow bf16/fp16 matmuls. + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + randlora_lambda = self.randlora_lambda[adapter] + randlora_gamma = self.randlora_gamma[adapter] + + if cast_to_fp32: + randlora_A = randlora_A.float() + randlora_B = randlora_B.float() + randlora_lambda = randlora_lambda.float() + randlora_gamma = randlora_gamma.float() + + #The trainable paramters are always applied to randlora_A, the smallest basis. + min_dim, max_dim = min(self.out_features, self.in_features), max(self.out_features, self.in_features) + + # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, + # we initialize these matrices with the largest required size for each dimension. + # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. 
+ sliced_A = randlora_A[:, : self.n, : min_dim] + sliced_B = randlora_B[: max_dim, : self.n, :] + #Flattening the matrices over the rank and number of bases dimensions is more memory efficient + update_B = sliced_B.flatten(start_dim=1) + update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) + if min_dim == self.in_features: + return update_A, update_B, dtype + + return update_B.T, update_A.T, dtype + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + + update_B, update_A, dtype = self.get_scaled_bases(adapter) + + update = update_B @ update_A + output_tensor = transpose(update, self.fan_in_fan_out) + + if dtype != self.randlora_B[adapter].dtype: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + # TODO: why?, taken from the VeRA implementation + self.randlora_lambda[adapter].data = randlora_lambda.to(dtype) + self.randlora_gamma[adapter].data = randlora_gamma.to(dtype) + + scaling = self.scaling[adapter] + + return output_tensor * scaling + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Perform the forward pass using the RandLora adapter. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the RandLora adaptation. + + Note: + This method implements the RandLora-specific forward pass. It applies the shared projections (randlora_A and + randlora_B) along with the per-layer trainable parameters (lambda and gamma) to compute the adapter + output. + """ + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.randlora_lambda.keys(): + continue + + update_B, update_A, dtype = self.get_scaled_bases(active_adapter) + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + compute_dtype = update_A.dtype + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + + dropout = self.randlora_dropout[active_adapter] + x_temp = dropout(x.to(update_A.dtype)) + + adapter_output = torch.nn.functional.linear( + torch.nn.functional.linear(x_temp, update_B), update_A + ) + + if requires_conversion: + adapter_output = adapter_output.to(expected_dtype) + + scaling = self.scaling[active_adapter] + result = result + adapter_output * scaling + + # Ensure the output tensor has the same dtype as the input tensor + return result.to(x.dtype) + + def __repr__(self) -> str: + rep = super().__repr__() + return "randlora." 
+ rep + + +if is_bnb_4bit_available(): + + class Linear4bit(torch.nn.Module, RandLoraLayer): + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + randlora_A, + randlora_B, + r: int = 0, + randlora_alpha: int = 0, + randlora_dropout: float = 0.0, + fan_in_fan_out: bool = False, + init_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + RandLoraLayer.__init__(self, base_layer) + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + randlora_A, + randlora_B, + r, + randlora_alpha=randlora_alpha, + randlora_dropout=randlora_dropout, + init_weights=init_weights, + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + if self.merged: + warnings.warn( + f"Already following adapters were merged {','.join(self.merged_adapters)}. " + f"You are now additionally merging {','.join(self.active_adapters)}." + ) + + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + return + + for active_adapter in adapter_names: + if active_adapter not in self.randlora_lambda.keys(): + continue + + warnings.warn( + "Merge RandLora module to 4-bit linear may get different generations due to rounding errors." + ) + randlora_data = self.get_delta_weight(active_adapter) + + weight = self.get_base_layer().weight + kwargs = weight.__dict__ + w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) + randlora_data + + if safe_merge and not torch.isfinite(w_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to( + weight.device + ) + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + if not self.merged: + warnings.warn("Already unmerged. Nothing to do") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.randlora_lambda.keys(): + continue + warnings.warn( + "Unmerge RandLora module to 4-bit linear may get different generations due to rounding errors." + ) + randlora_data = self.get_delta_weight(active_adapter) + + weight = self.get_base_layer().weight + kwargs = weight.__dict__ + w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - randlora_data + + self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to( + weight.device + ) + + def get_scaled_bases(self, adapter) -> List[torch.Tensor, torch.Tensor, torch.dtype]: + """ + Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct order + to fit the target layers' dimensions + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + + randlora_A = self.randlora_A[adapter] + randlora_B = self.randlora_B[adapter] + + device = randlora_B.device + dtype = randlora_B.dtype + + # In case users wants to merge the adapter weights that are in + # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # (b)float16 because some CPUs have slow bf16/fp16 matmuls. 
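            # For reference: this method returns two weight matrices plus a dtype. The first returned tensor is
            # applied to the layer input (projecting it into the n * r dimensional space) and the second maps the
            # result back to the output dimension; the transposed return below keeps this ordering whichever of
            # in_features / out_features is smaller.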
+ cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + randlora_lambda = self.randlora_lambda[adapter] + randlora_gamma = self.randlora_gamma[adapter] + + if cast_to_fp32: + randlora_A = randlora_A.float() + randlora_B = randlora_B.float() + randlora_lambda = randlora_lambda.float() + randlora_gamma = randlora_gamma.float() + + #The trainable paramters are always applied to randlora_A, the smallest basis. + min_dim, max_dim = min(self.out_features, self.in_features), max(self.out_features, self.in_features) + + # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, + # we initialize these matrices with the largest required size for each dimension. + # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. + sliced_A = randlora_A[:, : self.n, : min_dim] + sliced_B = randlora_B[: max_dim, : self.n, :] + #Flattening the matrices over the rank and number of bases dimensions is more memory efficient + update_B = sliced_B.flatten(start_dim=1) + update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) + if min_dim == self.in_features: + return update_A, update_B, dtype + + return update_B.T, update_A.T, dtype + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + + update_B, update_A, dtype = self.get_scaled_bases(adapter) + + update = update_B @ update_A + output_tensor = transpose(update, self.fan_in_fan_out) + + if dtype != self.randlora_B[adapter].dtype: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + # TODO: why?, taken from the VeRA implementation + self.randlora_lambda[adapter].data = randlora_lambda.to(dtype) + self.randlora_gamma[adapter].data = randlora_gamma.to(dtype) + + scaling = self.scaling[adapter] + + return output_tensor * scaling + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + result = result.clone() + for active_adapter in self.active_adapters: + if active_adapter not in self.randlora_lambda.keys(): + continue + update_B, update_A, dtype = self.get_scaled_bases(active_adapter) + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + compute_dtype = update_A.dtype + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + + dropout = self.randlora_dropout[active_adapter] + x_temp = dropout(x.to(update_A.dtype)) + + adapter_output = torch.nn.functional.linear( + torch.nn.functional.linear(x_temp, update_B), update_A + ) + + if requires_conversion: + adapter_output = adapter_output.to(expected_dtype) + + scaling = self.scaling[active_adapter] + result = result + adapter_output * scaling + + # Ensure the output tensor has the same dtype as the input tensor + return result.to(x.dtype) + + def __repr__(self) -> str: + rep = super().__repr__() + return "randlora." 
+ rep diff --git a/src/peft/tuners/randlora/config.py b/src/peft/tuners/randlora/config.py new file mode 100644 index 0000000000..b69b067abe --- /dev/null +++ b/src/peft/tuners/randlora/config.py @@ -0,0 +1,179 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +import math +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + +@dataclass +class RandLoraConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`RandLoraModel`]. + + Paper: {}. + + Args: + r (`int`, *optional*, defaults to `32`): + RandLora's random basis rank dimension. This parameter is inversely proportional to the amount of trainable parameters. + target_modules (`Union[List[str], str]`): + The names of the modules to apply RandLora to. Only linear layers are supported. + projection_prng_key (`int`): + RandLora PRNG init key. Used for initialising basis_A and basis_B for new models or when loading a checkpoint + that did not include these projections. Defaults to `int(math.exp(1)*3.1415*1000)`. + save_projection (`bool`): + Whether to save the global basis_A / basis_B random basis in the state dict alongside per layer lambda / gamma diagonal matrices. + weights. This will increase the size of the checkpoint, but guarantee that we can reload the checkpoint on + all system configurations. Defaults to `True`. + sparse (`bool`): + Whether to use sparse random bases as described in the RandLora paper. The current implementation is a proof of concept where the sparseness is not used to improve speed or memory usage. Defaults to `False`. + very_sparse (`bool`): + Whether to use very sparse random bases. The current implementation is a proof of concept where the sparseness is not used to improve speed or memory usage. Defaults to `False`. + randlora_dropout (`float`): + The dropout probability for RandLora layers. + randlora_alpha (`float`): + The scaling coefficient for RandLora layers, this would be typically be the same as LoRA, e.g. 2 times the rank. + fan_in_fan_out (`bool`): + Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses + `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`. + bias (`str`): + Bias type. Can be 'none', 'all' or 'randlora_only'. If 'all' or 'randlora_only', the corresponding biases + will be updated during training. Be aware that this means that, even when disabling the adapters, the model + will not produce the same output as the base model would have without adaptation. + modules_to_save (`List[str]`): + List of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint. + init_weights (`bool`): + Whether to initialize the weights of the RandLora layers with their default initialization. Don't change this + setting, except if you know exactly what you're doing. 
+ layers_to_transform (`Union[List[int],int]`): + The layer indexes to transform, if this argument is specified, it will apply the RandLora transformations on + the layer indexes that are specified in this list. If a single integer is passed, it will apply the RandLora + transformations on the layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer + pattern is not in the common layers pattern. + """ + + r: int = field(default=32, metadata={"help": "RandLora random basis rank"}) + + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": ( + "List of module names or regex expression of the module names to replace with RandLora." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. " + "Only linear layers are supported." + ) + }, + ) + projection_prng_key: int = field( + default=int(math.exp(1)*3.1415*1000), + metadata={ + "help": ( + "RandLora PRNG init key. Used for initialising basis_A and basis_B for new models or when loading a " + "checkpoint that did not include these projections." + ) + }, + ) + save_projection: bool = field( + default=True, + metadata={ + "help": ( + "Whether to save the basis_A / basis_B projections in the state dict alongside per layer lambda / " + "gamma weights. This will increase the size of the checkpoint, but guarantee that we can reload " + "the checkpoint on all system configurations." + ) + }, + ) + sparse: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use sparse random bases as described in the RandLora paper." + "The current implementation is a proof of concept where the sparseness" + "is not used to improve speed or memory usage." + ) + }, + ) + very_sparse: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use very sparse random bases." + "The current implementation is a proof of concept where the sparseness" + "is not used to improve speed or memory usage." + ) + }, + ) + randlora_dropout: float = field(default=0.0, metadata={"help": "Dropout in the adapter layers"}) + fan_in_fan_out: bool = field( + default=False, + metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, + ) + randlora_alpha: int = field(default=64, metadata={"help": "Scaling coefficient in the adapter layers, typically 2 times the rank of the random bases."}) + bias: str = field(default="none", metadata={"help": "Bias type for RandLora. Can be 'none', 'all' or 'randlora_only'"}) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": ( + "List of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint. For" + " example, in Sequence Classification or Token Classification tasks, the final layer" + " `classifier/score` are randomly initialized and as such need to be trainable and saved." + ) + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the RandLora layers with their default initialization. Don't change " + "this setting, except if you know exactly what you're doing." + ), + }, + ) + layers_to_transform: Optional[Union[List[int], int]] = field( + default=None, + metadata={ + "help": ( + "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers" + " indexes that are specified inside this list. 
If a single integer is passed, PEFT will transform only" + " the layer at this index." + ) + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer" + " pattern is not in the common layers pattern." + ) + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.RANDLORA + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + + if not self.save_projection: + warnings.warn( + "Specified to not save basis_A and basis_B within the state dictionary, instead they will be restored " + "using the PRNG key store in `config.projection_prng_key`. Consider setting `config.save_projection` " + "to `True` to guarantee restoring the checkpoint correctly on all system configurations." + ) diff --git a/src/peft/tuners/randlora/layer.py b/src/peft/tuners/randlora/layer.py new file mode 100644 index 0000000000..2f9a7c776d --- /dev/null +++ b/src/peft/tuners/randlora/layer.py @@ -0,0 +1,351 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from transformers.pytorch_utils import Conv1D + +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.other import transpose + +from .._buffer_dict import BufferDict + +class UniqueBaseGrad(torch.autograd.Function): + #Memory efficent for a unique base + @staticmethod + def forward(ctx, randlora_A, randlora_lambda, randlora_gamma): + Out = randlora_lambda[:, :, None] * randlora_A * randlora_gamma[None, ] + ctx.save_for_backward(randlora_A, randlora_lambda, randlora_gamma) + return Out + + @staticmethod + def backward(ctx, grad_output): + randlora_A, randlora_lambda, randlora_gamma = ctx.saved_tensors + randlora_A, randlora_lambda, randlora_gamma = randlora_A.to(grad_output.dtype), randlora_lambda.to(grad_output.dtype), randlora_gamma.to(grad_output.dtype) + grad_randlora_lambda = torch.einsum('kbj,kvj,bj->kb', grad_output, randlora_A, randlora_gamma) + grad_randlora_gamma = torch.einsum('kbj,kvj,kb->bj', grad_output, randlora_A, randlora_lambda) + return None, grad_randlora_lambda, grad_randlora_gamma + +class RandLoraLayer(BaseTunerLayer): + # List all names of layers that may contain adapter weights + adapter_layer_names = ("randlora_lambda", "randlora_gamma") + other_param_names = ("randlora_A", "randlora_B") + + def __init__(self, base_layer: nn.Module, **kwargs): + self.base_layer = base_layer + self.r = {} + self.scaling = {} + self.randlora_dropout = nn.ModuleDict({}) + + # For storing vector scale + self.randlora_lambda = nn.ParameterDict({}) + self.randlora_gamma = nn.ParameterDict({}) + self.randlora_m = nn.ParameterDict({}) + self.randlora_cache_norm_dora = nn.ParameterDict({}) + + # Stores a reference to the randlora_A/B 
BufferDict. + # Set to `None` otherwise to avoid computation with random weights + self.randlora_A: Optional[BufferDict] = None + self.randlora_B: Optional[BufferDict] = None + + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif isinstance(base_layer, Conv1D): + in_features, out_features = ( + base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape + ) + + self.in_features = in_features + self.out_features = out_features + self.kwargs = kwargs + + @property + def merged(self) -> bool: + return bool(self.merged_adapters) + + def update_layer( + self, + adapter_name, + randlora_A: BufferDict, + randlora_B: BufferDict, + r, + randlora_alpha, + randlora_dropout, + init_weights, + ): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + self.r[adapter_name] = r + if randlora_dropout > 0.0: + randlora_dropout_layer = nn.Dropout(p=randlora_dropout) + else: + randlora_dropout_layer = nn.Identity() + + self.randlora_dropout.update(nn.ModuleDict({adapter_name: randlora_dropout_layer})) + + # Actual trainable parameters + n = min(self.in_features, self.out_features) / r + self.n = int(n) if n.is_integer() else int(n) + 1 #Full rank + self.randlora_lambda[adapter_name] = nn.Parameter(torch.zeros(r, self.n), requires_grad=True) + self.randlora_gamma[adapter_name] = nn.Parameter(torch.ones(self.n, min(self.out_features, self.in_features))/max(self.out_features, self.in_features), requires_grad=True) + + self.scaling[adapter_name] = randlora_alpha / r / math.sqrt(self.n)#* 10#/ math.sqrt(self.n) + + # non trainable references to randlora_A/B buffers + self.randlora_A = randlora_A + self.randlora_B = randlora_B + if adapter_name not in randlora_A: + # This means that this is not the first RandLora adapter. We have to add an entry in the dict for this adapter. + if len(self.randlora_A) < 1: + raise ValueError( + "The `randlora_A` and `randlora_B` buffers are empty. This should not happen. Please report this issue." + ) + # we can take any of the existing adapter's parameters, as they should all be identical + randlora_A_param = list(self.randlora_A.values())[0] + randlora_B_param = list(self.randlora_B.values())[0] + + error_tmpl = ( + "{} has a size of {} but {} or greater is required; this probably happened because an additional RandLora " + "adapter was added after the first one with incompatible shapes." + ) + # check input size + if randlora_A_param.shape[-1] < self.in_features: + raise ValueError(error_tmpl.format("randlora_A", randlora_A_param.shape[1], self.in_features)) + # check output size + if randlora_B_param.shape[0] < self.out_features: + raise ValueError(error_tmpl.format("randlora_B", randlora_B_param.shape[0], self.out_features)) + # check r + error_tmpl = ( + "{} has a size of {} but {} or greater is required; this probably happened because an additional RandLora " + "adapter with a lower rank was added after the first one; loading the adapters " + "in reverse order may solve this." 
+ ) + if randlora_A_param.shape[0] < self.r[adapter_name]: + raise ValueError(error_tmpl.format("randlora_A", randlora_A_param.shape[0], self.r[adapter_name])) + if randlora_B_param.shape[1] < self.r[adapter_name]: + raise ValueError(error_tmpl.format("randlora_B", randlora_B_param.shape[1], self.r[adapter_name])) + + self.randlora_A[adapter_name] = randlora_A_param + self.randlora_B[adapter_name] = randlora_B_param + + if init_weights: + self.reset_randlora_parameters(adapter_name) + + self._move_adapter_to_device_of_base_layer(adapter_name) + self.set_adapter(self.active_adapters) + + def reset_randlora_parameters(self, adapter_name): + if adapter_name in self.randlora_lambda.keys(): + with torch.no_grad(): + nn.init.zeros_(self.randlora_lambda[adapter_name]) + nn.init.zeros_(self.randlora_gamma[adapter_name]).fill_(1/max(self.randlora_gamma[adapter_name].shape)) + + +class Linear(nn.Linear, RandLoraLayer): + # RandLora implemented in a dense layer + def __init__( + self, + base_layer, + randlora_A: BufferDict, + randlora_B: BufferDict, + adapter_name: str, + r: int = 0, + randlora_alpha: int = 0, + randlora_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + is_target_conv_1d_layer: bool = False, + init_weights: bool = True, + **kwargs, + ) -> None: + # this gets the init from nn.Linear's super perspective, i.e. nn.Module.__init__, which should always be called + super(nn.Linear, self).__init__() + RandLoraLayer.__init__(self, base_layer, **kwargs) + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + self.update_layer(adapter_name, randlora_A, randlora_B, r, randlora_alpha, randlora_dropout, init_weights) + self.is_target_conv_1d_layer = is_target_conv_1d_layer + + def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.randlora_lambda.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weights = base_layer.weight.data.clone() + if active_adapter in self.randlora_m.keys(): + norm = torch.linalg.norm(orig_weight, dim=1, keepdim=True) + self.randlora_cache_norm_dora[active_adapter] = norm + orig_weight *= (self.randlora_m[active_adapter] / norm).view(1, -1) + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weights + else: + base_layer.weight.data += self.get_delta_weight(active_adapter) + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + if not self.merged: + warnings.warn("Already unmerged. 
Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.randlora_lambda.keys(): + if active_adapter in self.randlora_m.keys(): + ori_weight = self.get_base_layer().weight.data + delta = self.get_delta_weight(active_adapter) + if active_adapter in self.randlora_m.keys(): + norm = self.randlora_cache_norm_dora[active_adapter] + else: + norm = self.randlora_m[active_adapter].data + + self.get_base_layer().weight.data = \ + ori_weight * (norm / self.randlora_m[active_adapter]).view(1, -1) - delta + else: + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + + def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor, torch.dtype]: + """ + Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct order + to fit the target layers' dimensions + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + + randlora_A = self.randlora_A[adapter] + randlora_B = self.randlora_B[adapter] + + device = randlora_B.device + dtype = randlora_B.dtype + + # In case users wants to merge the adapter weights that are in + # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # (b)float16 because some CPUs have slow bf16/fp16 matmuls. + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + randlora_lambda = self.randlora_lambda[adapter] + randlora_gamma = self.randlora_gamma[adapter] + + if cast_to_fp32: + randlora_A = randlora_A.float() + randlora_B = randlora_B.float() + randlora_lambda = randlora_lambda.float() + randlora_gamma = randlora_gamma.float() + + #The trainable paramters are always applied to randlora_A, the smallest basis. + min_dim, max_dim = min(self.out_features, self.in_features), max(self.out_features, self.in_features) + + # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, + # we initialize these matrices with the largest required size for each dimension. + # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. + sliced_A = randlora_A[:, : self.n, : min_dim] + sliced_B = randlora_B[: max_dim, : self.n, :] + + #Flattening the matrices over the rank and number of bases dimensions is more memory efficient + update_B = sliced_B.flatten(start_dim=1) + update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) + if min_dim == self.in_features: + return update_A, update_B, dtype + return update_B.T, update_A.T, dtype + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. 
+ """ + + update_B, update_A, dtype = self.get_scaled_bases(adapter) + + update = (update_B.T @ update_A.T).T + output_tensor = transpose(update, self.fan_in_fan_out) + + if dtype != self.randlora_B[adapter].dtype: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + # TODO: why?, taken from the VeRA implementation + self.randlora_lambda[adapter].data = randlora_lambda.to(dtype) + self.randlora_gamma[adapter].data = randlora_gamma.to(dtype) + + scaling = self.scaling[adapter] + return output_tensor * scaling + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.randlora_lambda.keys(): + continue + dropout = self.randlora_dropout[active_adapter] + update_B, update_A, _ = self.get_scaled_bases(active_adapter) + x = x.to(update_A.dtype) + scaling = self.scaling[active_adapter] + if active_adapter in self.randlora_m.keys(): + update = update_A @ update_B * scaling + weight_update_norm = torch.linalg.norm(update + self.weight, dim=1, keepdim=True).detach() + lora_result = F.linear(F.linear(dropout(x), update_B), update_A) * scaling + result = (result + lora_result) * (self.randlora_m[active_adapter] / weight_update_norm).view(1, -1) + else: + result = result + F.linear(F.linear(dropout(x), update_B), update_A) * scaling + + result = result.to(previous_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "randlora." + rep diff --git a/src/peft/tuners/randlora/model.py b/src/peft/tuners/randlora/model.py new file mode 100644 index 0000000000..73305d02b5 --- /dev/null +++ b/src/peft/tuners/randlora/model.py @@ -0,0 +1,552 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import math +import warnings +from dataclasses import asdict +from enum import Enum +from typing import Optional, Union + +import torch +import torch.nn as nn +from torch.nn.init import _calculate_correct_fan +from tqdm import tqdm +from transformers.pytorch_utils import Conv1D + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists +from peft.utils import ( + TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + _get_submodules, +) + +from .._buffer_dict import BufferDict +from ..tuners_utils import _maybe_include_all_linear_layers +from .config import RandLoraConfig +from .layer import Linear, RandLoraLayer + + +def _kaiming_init( + tensor_or_shape: Union[torch.Tensor, tuple[int, ...]], + generator: torch.Generator, +) -> torch.Tensor: + """ + Kaiming Uniform Initialisation adapted to accept a `torch.Generator` object for PRNG. + + Args: + tensor_or_shape (`Union[torch.Tensor, tuple[int, ...]]`): + Tensor to initialise, or shape of new tensor to create and then initialise. + generator: (`torch.Generator`): + Generator object that manages the state of the PRNG algorithm in use. + + Returns: + `torch.Tensor`: The initialised tensor. + """ + if isinstance(tensor_or_shape, tuple): + tensor = torch.empty(tensor_or_shape, dtype=torch.float32) + else: + tensor = tensor_or_shape + + with torch.no_grad(): + basis = torch.nn.init.kaiming_uniform_(tensor, a=math.sqrt(5), generator=generator) + return basis + + + +class RandLoraModel(BaseTuner): + """ + Creates a RandLoRA model from a pretrained transformers model. + + Args: + model ([`~transformers.PreTrainedModel`]): The model to be adapted. + config ([`RandLoraConfig`]): The configuration of the RandLora model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. + + Returns: + `torch.nn.Module`: The RandLora model. + + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import RandLoraConfig, get_peft_model + + >>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") + >>> config = RandLoraConfig(r=128) + >>> model = get_peft_model(base_model, config) + ``` + + **Attributes**: + - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`RandLoraConfig`]): The configuration of the RandLora model. + """ + + prefix: str = "randlora_" + + def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None: + super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) + + def _find_dim(self, config) -> tuple[int, int]: + """ + Finds the largest input and output dimensions across linear layers that have been wrapped with RandLora. + + This will be used for determining the size of the shared randlora_A and randlora_B matrices. 
+ """ + model_config = self.get_model_config(self.model) + + peft_config = self._prepare_adapter_config(config, model_config) + peft_config = _maybe_include_all_linear_layers(peft_config, self.model) + + largest_shape = None + for key, module in self.model.named_modules(): + if not self._check_target_module_exists(peft_config, key): + continue + + if isinstance(module, nn.Linear): + module_shape = module.out_features, module.in_features + elif isinstance(module, Conv1D): + module_shape = module.weight.ds_shape if hasattr(module.weight, "ds_shape") else module.weight.shape + module_shape = module_shape[::-1] + else: + continue + + if largest_shape is None: + largest_shape = module_shape + continue + + if module_shape != largest_shape: + largest_shape = tuple(max(a, b) for a, b in zip(largest_shape, module_shape)) + + if largest_shape is None: + msg = "No layers types compatible with RandLora were found. Please check `peft_config.target_modules`." + raise ValueError(msg) + + return largest_shape + + def _init_randlora_A_randlora_B_sparse(self, config: RandLoraConfig, adapter_name: str, s: int = 3) -> None: + """ + Sparse random projections as described in https://cs-people.bu.edu/evimaria/cs565/kdd-rp.pdf + """ + + linear_out_dim, linear_in_dim = self._find_dim(config) + + # use of persistent to exclude randlora_A and randlora_B from the state dict if we choose not to save them. + self.randlora_A = BufferDict({}, persistent=config.save_projection) + self.randlora_B = BufferDict({}, persistent=config.save_projection) + + # deterministic init of randlora_A and randlora_B if we know the key + generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key) + randlora_A = torch.rand((config.r, 1, linear_out_dim), generator=generator) #The gamma matrix is applied on A meaning it can be unique (shared) accross the n scaling matrices + + #Ensure full rank + n = min(linear_out_dim, linear_in_dim) / config.r + n = int(n) if n.is_integer() else int(n) + 1 #Ensure full rank + randlora_B = torch.rand((linear_in_dim, n, config.r), generator=generator) + + #The current implementation is a proof of concept and does take into consideration + #the sparsity to reduce memory usage or speed up compute + randlora_B_sparse = torch.zeros(randlora_B.shape) + randlora_A_sparse = torch.zeros(randlora_A.shape) + randlora_B_sparse[randlora_B<1/(2*s)] = -1 + randlora_B_sparse[randlora_B>1-1/(2*s)] = 1 + randlora_A_sparse[randlora_A<1/(2*s)] = -1 + randlora_A_sparse[randlora_A>1-1/(2*s)] = 1 + + #Std normalization is empirically found to be the best + randlora_A, randlora_B = randlora_A_sparse / randlora_A_sparse.std(), randlora_B_sparse / randlora_B_sparse.std() + self.randlora_A[adapter_name] = randlora_A + self.randlora_B[adapter_name] = randlora_B + + def _init_randlora_A_randlora_B(self, config: RandLoraConfig, adapter_name: str) -> None: + linear_out_dim, linear_in_dim = self._find_dim(config) + # use of persistent to exclude randlora_A and randlora_B from the state dict if we choose not to save them. 
        self.randlora_A = BufferDict({}, persistent=config.save_projection)
        self.randlora_B = BufferDict({}, persistent=config.save_projection)

        # deterministic init of randlora_A and randlora_B if we know the key
        generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key)
        # The gamma matrix is applied on A, meaning it can be unique (shared) across the n scaling matrices
        randlora_A = _kaiming_init((config.r, 1, linear_out_dim), generator=generator)

        # Ensure full rank
        n = min(linear_out_dim, linear_in_dim) / config.r
        n = int(n) if n.is_integer() else int(n) + 1
        randlora_B = torch.cat(
            [_kaiming_init((linear_in_dim, 1, config.r), generator=generator) for _ in range(n)], dim=1
        )

        # Std normalization is empirically found to be the best
        randlora_A, randlora_B = randlora_A / randlora_A.std(), randlora_B / randlora_B.std()
        self.randlora_A[adapter_name] = randlora_A
        self.randlora_B[adapter_name] = randlora_B

    def _pre_injection_hook(self, model: nn.Module, config: RandLoraConfig, adapter_name: str) -> None:
        if config.very_sparse:
            linear_out_dim, linear_in_dim = self._find_dim(config)
            self._init_randlora_A_randlora_B_sparse(
                config, adapter_name, s=math.sqrt(min(linear_out_dim, linear_in_dim))
            )
        elif config.sparse:
            self._init_randlora_A_randlora_B_sparse(config, adapter_name, s=3)
        else:
            self._init_randlora_A_randlora_B(config, adapter_name)

    def _check_new_adapter_config(self, config: RandLoraConfig) -> None:
        """
        A helper method to check the config when a new adapter is being added.

        Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.

        """
        # the below todo is copied from LoRA
        # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check
        # does not fully correspond to the error message.
        if (len(self.peft_config) > 1) and (config.bias != "none"):
            raise ValueError(
                f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, "
                "set bias to 'none' for all adapters."
            )

        for existing_config in self.peft_config.values():
            if existing_config is config:
                # skip the current config
                continue

            if existing_config.projection_prng_key != config.projection_prng_key:
                raise ValueError(
                    f"RandLora PRNG initialisation key must be the same for all adapters. Got {config.projection_prng_key=} but "
                    f"previous config had {existing_config.projection_prng_key}."
+ ) + + save_project_unique_values = sorted({config.save_projection for config in self.peft_config.values()}) + if len(save_project_unique_values) > 1: + raise ValueError( + "RandLora projection weights must be saved for all adapters or none, but got multiple different values: " + f"{save_project_unique_values}" + ) + + @staticmethod + def _check_target_module_exists(randlora_config, key): + return check_target_module_exists(randlora_config, key) + + def _create_and_replace( + self, + randlora_config, + adapter_name, + target, + target_name, + parent, + current_key, + **optional_kwargs, + ): + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + + r = randlora_config.r + bias = hasattr(target, "bias") and target.bias is not None + kwargs = { + "r": r, + "randlora_alpha": randlora_config.randlora_alpha, + "randlora_dropout": randlora_config.randlora_dropout, + "fan_in_fan_out": randlora_config.fan_in_fan_out, + "init_weights": randlora_config.init_weights, + "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), + "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), + } + kwargs["bias"] = bias + if isinstance(target, Linear): + target.update_layer( + adapter_name, + self.randlora_A, + self.randlora_B, + r, + randlora_config.randlora_alpha, + randlora_config.randlora_dropout, + randlora_config.init_weights, + ) + else: + new_module = self._create_new_module(randlora_config, self.randlora_A, self.randlora_B, adapter_name, target, **kwargs) + if adapter_name not in self.active_adapter: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + @staticmethod + def _replace_module(parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + meta = torch.device("meta") + # dispatch to correct device + for name, module in new_module.named_modules(): + if "randlora_" in name: + if not any(p.device == meta for p in module.parameters()): + module.to(child.weight.device) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + + for active_adapter in self.active_adapters: + bias = self.peft_config[active_adapter].bias + if bias == "none": + continue + + if bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "randlora_only": + for m in model.modules(): + if isinstance(m, RandLoraLayer) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") + + @staticmethod + def _create_new_module(randlora_config, randlora_A, randlora_B, adapter_name, target, **kwargs): + # avoid eager bnb import + if is_bnb_available(): + import bitsandbytes as bnb + + from .bnb import Linear8bitLt + + if 
is_bnb_4bit_available(): + from .bnb import Linear4bit + + bias = kwargs.pop("bias", False) + loaded_in_8bit = kwargs.get("loaded_in_8bit", False) + loaded_in_4bit = kwargs.get("loaded_in_4bit", False) + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): + eightbit_kwargs = kwargs.copy() + eightbit_kwargs.update( + { + "has_fp16_weights": target_base_layer.state.has_fp16_weights, + "memory_efficient_backward": target_base_layer.state.memory_efficient_backward, + "threshold": target_base_layer.state.threshold, + "index": target_base_layer.index, + } + ) + return Linear8bitLt(target, adapter_name, randlora_A, randlora_B, **eightbit_kwargs) + elif loaded_in_4bit and isinstance(target_base_layer, bnb.nn.Linear4bit): + fourbit_kwargs = kwargs.copy() + fourbit_kwargs.update( + { + "compute_dtype": target_base_layer.compute_dtype, + "compress_statistics": target_base_layer.weight.compress_statistics, + "quant_type": target_base_layer.weight.quant_type, + } + ) + return Linear4bit(target, adapter_name, randlora_A, randlora_B, **fourbit_kwargs) + elif isinstance(target_base_layer, torch.nn.Linear): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." + ) + kwargs["fan_in_fan_out"] = randlora_config.fan_in_fan_out = False + elif isinstance(target_base_layer, Conv1D): + kwargs["is_target_conv_1d_layer"] = True + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. " + "Setting fan_in_fan_out to True." + ) + kwargs["fan_in_fan_out"] = randlora_config.fan_in_fan_out = True + else: + raise ValueError( + f"Target module {target} is not supported. Currently, only the following modules are supported: " + "`torch.nn.Linear`, `transformers.pytorch_utils.Conv1D`." + ) + new_module = Linear( + target, + randlora_A, + randlora_B, + adapter_name, + bias=bias, + **kwargs, + ) + + return new_module + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + if name == "model": # see #1892: prevent infinite recursion if class is not initialized + raise + return getattr(self.model, name) + + def get_peft_config_as_dict(self, inference: bool = False): + config_dict = {} + for key, value in self.peft_config.items(): + config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} + if inference: + config["inference_mode"] = True + config_dict[key] = config + return config + + def _set_adapter_layers(self, enabled=True): + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self): + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self): + for active_adapter in self.active_adapters: + val = self.peft_config[active_adapter].bias + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." 
+ ) + warnings.warn(msg) + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name): + for module in self.model.modules(): + if isinstance(module, RandLoraLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, + ): + # we cannot use self.prefix as we want to include non-trainable randlora parameters + key_list = [key for key, _ in self.model.named_modules() if "randlora" not in key] + desc = "Unloading " + ("and merging " if merge else "") + "model" + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + setattr(parent, target_name, target.modules_to_save[target.active_adapter]) + + return self.model + + def delete_adapter(self, adapter_name: str): + """ + Deletes an existing adapter. + + Args: + adapter_name (str): Name of the adapter to be deleted. + """ + if adapter_name not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + # we cannot use self.prefix as we want to include non-trainable randlora parameters + key_list = [key for key, _ in self.model.named_modules() if "randlora" not in key] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, RandLoraLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapter[:] + + self.active_adapter = new_adapter or [] + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ): + r""" + This method merges the RandLora layers into the base model. This is needed if someone wants to use the base model + as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. 
+ + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import PeftModel + + >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b") + >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample" + >>> model = PeftModel.from_pretrained(base_model, peft_model_id) + >>> merged_model = model.merge_and_unload() + ``` + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self): + """ + Gets back the base model by removing all the RandLora modules without merging. This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 97474caad6..c6e710e58b 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -27,6 +27,7 @@ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING, WEIGHTS_NAME, AuxiliaryTrainingWrapper, ModulesToSaveWrapper, @@ -66,6 +67,7 @@ "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING", "TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING", "WEIGHTS_NAME", "AuxiliaryTrainingWrapper", "ModulesToSaveWrapper", diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index e04b1ea1a2..a1359f273d 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -291,6 +291,8 @@ def starcoder_model_postprocess_past_key_value(past_key_values): "qwen2": ["q_proj", "v_proj"], } +TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING # Leaving this for now but RandLoRA is flexible + WEIGHTS_NAME = "adapter_model.bin" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" CONFIG_NAME = "adapter_config.json" diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index bc352b6006..be3d565dcd 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -46,6 +46,7 @@ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING, WEIGHTS_NAME, bloom_model_postprocess_past_key_value, starcoder_model_postprocess_past_key_value, diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index a205bd4550..023fbaed78 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -40,6 +40,7 @@ class PeftType(str, enum.Enum): - FOURIERFT - HRA - BONE + - RANDLORA """ PROMPT_TUNING = "PROMPT_TUNING" @@ -63,6 +64,7 @@ class PeftType(str, enum.Enum): VBLORA = "VBLORA" CPT = "CPT" BONE = "BONE" + RANDLORA = "RANDLORA" TRAINABLE_TOKENS = "TRAINABLE_TOKENS" From b1aabdebc5957ba2aacdbb0f262d41654dc3a6ef Mon Sep 17 00:00:00 2001 From: Paul Albert Date: Mon, 31 Mar 2025 16:07:03 +1100 Subject: [PATCH 02/13] added randlora tests to test_custom_models --- tests/test_custom_models.py | 172 ++++++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 392a4ef708..cc1fdd5e03 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -50,6 +50,7 @@ 
TrainableTokensConfig, VBLoRAConfig, VeraConfig, + RandLoraConfig, get_peft_model, ) from peft.tuners.tuners_utils import BaseTunerLayer @@ -511,6 +512,25 @@ TrainableTokensConfig, {"target_modules": ["emb"], "token_indices": [0, 1, 3], "init_weights": False}, ), + ######## + # RandLora # + ######## + ("Vanilla MLP 1 RandLora", "MLP", RandLoraConfig, {"target_modules": "lin0"}), + ("Vanilla MLP 2 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0"]}), + ("Vanilla MLP 3 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin1"]}), + ("Vanilla MLP 4 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"]}), + ( + "Vanilla MLP 5 RandLora", + "MLP", + RandLoraConfig, + {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}, + ), + ( + "Embedding + transformers Conv1D 1 RandLora", + "EmbConv1D", + RandLoraConfig, + {"target_modules": ["conv1d"]}, + ), ] # For this test matrix, each tuple consists of: @@ -617,6 +637,14 @@ {"target_modules": ["lin0"], "init_weights": False}, {"target_modules": ["lin0"], "init_weights": False}, ), + # Note: RandLora may present the same problem mentioned above for Vera. + ( + "RandLora Different", + "randlora", + RandLoraConfig, + {"target_modules": ["lin0"], "init_weights": False}, + {"target_modules": ["lin1"], "init_weights": False}, + ), ( "HRA Same", "hra", @@ -684,6 +712,7 @@ BOFTConfig: "boft_", LNTuningConfig: "ln_tuning_", VeraConfig: "vera_lambda_", + RandLoraConfig: "randlora_", FourierFTConfig: "fourierft_", HRAConfig: "hra_", VBLoRAConfig: "vblora_", @@ -3780,6 +3809,7 @@ def forward(self, X): # active adapter is still "default" self.check_requires_grad( peft_model, + "no" "base_model.model.lin1.vera_lambda_b.default", "base_model.model.lin1.vera_lambda_d.default", ) @@ -3886,6 +3916,148 @@ def forward(self, X): "base_model.model.lin2.vera_lambda_d.adapter1", ) + def test_requires_grad_randlora_different_targets(self): + # Test two different RandLora adapters that target different modules. Most notably, ensure that randbasis_A and randbasis_B + # don't require grads. 
+ + # requires a model with at least 2 layers with the same shapes + class MLP2(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.relu = nn.ReLU() + self.lin0 = nn.Linear(10, 20, bias=bias) + self.lin1 = nn.Linear(20, 20, bias=bias) # lin1 and lin2 have same shape + self.lin2 = nn.Linear(20, 20, bias=bias) + self.lin3 = nn.Linear(20, 2, bias=bias) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = X.float() + X = self.lin0(X) + X = self.relu(X) + X = self.lin1(X) + X = self.relu(X) + X = self.lin2(X) + X = self.relu(X) + X = self.lin3(X) + X = self.sm(X) + return X + + config0 = VeraConfig(target_modules=["lin1"]) + peft_model = get_peft_model(MLP2(), config0) + + config1 = VeraConfig(target_modules=["lin2"]) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin1.randlora_lambda.default", + "base_model.model.lin1.randlora_gamma.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.randlora_lambda.default", + "base_model.model.lin1.randlora_gamma.default", + ) + + # change activate adapter to adapter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin2.randlora_lambda.adapter1", + "base_model.model.lin2.randlora_gamma.adapter1", + ) + + # disable all adapters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + self.check_requires_grad( + peft_model, + "base_model.model.lin2.randlora_lambda.adapter1", + "base_model.model.lin2.randlora_gamma.adapter1", + ) + + def test_requires_grad_randlora_same_targets(self): + # Test two different RandLora adapters that target the same module. Most notably, ensure that randbasis_A and randbasis_B + # don't require grads. 
+ + # requires a model with at least 2 layers with the same shapes + class MLP2(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.relu = nn.ReLU() + self.lin0 = nn.Linear(10, 20, bias=bias) + self.lin1 = nn.Linear(20, 20, bias=bias) # lin1 and lin2 have same shape + self.lin2 = nn.Linear(20, 20, bias=bias) + self.lin3 = nn.Linear(20, 2, bias=bias) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = X.float() + X = self.lin0(X) + X = self.relu(X) + X = self.lin1(X) + X = self.relu(X) + X = self.lin2(X) + X = self.relu(X) + X = self.lin3(X) + X = self.sm(X) + return X + + config0 = RandLoraConfig(target_modules=["lin1", "lin2"]) + peft_model = get_peft_model(MLP2(), config0) + + config1 = RandLoraConfig(target_modules=["lin1", "lin2"]) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin1.randlora_lambda.default", + "base_model.model.lin1.randlora_gamma.default", + "base_model.model.lin2.randlora_lambda.default", + "base_model.model.lin2.randlora_gamma.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.randlora_lambda.default", + "base_model.model.lin1.randlora_gamma.default", + "base_model.model.lin2.randlora_lambda.default", + "base_model.model.lin2.randlora_gamma.default", + ) + + # change activate adapter to adapter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.randlora_lambda.adapter1", + "base_model.model.lin1.randlora_gamma.adapter1", + "base_model.model.lin2.randlora_lambda.adapter1", + "base_model.model.lin2.randlora_gamma.adapter1", + ) + + # disable all adapters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + self.check_requires_grad( + peft_model, + "base_model.model.lin1.randlora_lambda.adapter1", + "base_model.model.lin1.randlora_gamma.adapter1", + "base_model.model.lin2.randlora_lambda.adapter1", + "base_model.model.lin2.randlora_gamma.adapter1", + ) + def test_requires_grad_vblora_different_targets(self): # test two different VBLoRA adapters that target different modules config0 = VBLoRAConfig(target_modules=["lin0"], vector_length=1, num_vectors=2) From c25744c43ba1788aa69e4bf32f955e04085abc93 Mon Sep 17 00:00:00 2001 From: PaulAlbert31 Date: Tue, 8 Apr 2025 05:29:29 +0000 Subject: [PATCH 03/13] tests passing and fixes following feedback rebase --- src/peft/__init__.py | 6 +- src/peft/tuners/__init__.py | 7 +- src/peft/tuners/randlora/__init__.py | 9 +- src/peft/tuners/randlora/bnb.py | 73 ++++++++------- src/peft/tuners/randlora/config.py | 51 ++++++----- src/peft/tuners/randlora/layer.py | 127 ++++++++++++--------------- src/peft/tuners/randlora/model.py | 82 +++++++++-------- src/peft/utils/__init__.py | 4 +- src/peft/utils/constants.py | 4 +- src/peft/utils/other.py | 3 +- tests/test_custom_models.py | 28 +++--- 11 files changed, 202 insertions(+), 192 deletions(-) diff --git a/src/peft/__init__.py b/src/peft/__init__.py index d74fdaccd4..51a42558d2 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -87,6 +87,8 @@ PromptEncoderReparameterizationType, PromptTuningConfig, PromptTuningInit, + RandLoraConfig, + RandLoraModel, TrainableTokensConfig, TrainableTokensModel, VBLoRAConfig, @@ -95,8 +97,6 @@ VeraModel, XLoraConfig, XLoraModel, - RandLoraConfig, - 
RandLoraModel, get_eva_state_dict, initialize_lora_eva_weights, ) @@ -180,6 +180,8 @@ "PromptLearningConfig", "PromptTuningConfig", "PromptTuningInit", + "RandLoraConfig", + "RandLoraModel", "TaskType", "TrainableTokensConfig", "TrainableTokensModel", diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index 1d00b6fde4..bb38230bf0 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -39,11 +39,12 @@ from .poly import PolyConfig, PolyModel from .prefix_tuning import PrefixEncoder, PrefixTuningConfig from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit +from .randlora import RandLoraConfig, RandLoraModel from .trainable_tokens import TrainableTokensConfig, TrainableTokensModel from .vblora import VBLoRAConfig, VBLoRAModel from .vera import VeraConfig, VeraModel from .xlora import XLoraConfig, XLoraModel -from .randlora import RandLoraConfig, RandLoraModel + __all__ = [ "AdaLoraConfig", @@ -89,6 +90,8 @@ "PromptEncoderReparameterizationType", "PromptTuningConfig", "PromptTuningInit", + "RandLoraConfig", + "RandLoraModel", "TrainableTokensConfig", "TrainableTokensModel", "VBLoRAConfig", @@ -97,8 +100,6 @@ "VeraModel", "XLoraConfig", "XLoraModel", - "RandLoraConfig", - "RandLoraModel", "get_eva_state_dict", "initialize_lora_eva_weights", ] diff --git a/src/peft/tuners/randlora/__init__.py b/src/peft/tuners/randlora/__init__.py index f6d41dcc38..92b8ffe5e0 100644 --- a/src/peft/tuners/randlora/__init__.py +++ b/src/peft/tuners/randlora/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023-present the HuggingFace Inc. team. +# Copyright 2025-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,11 +20,10 @@ from .model import RandLoraModel -__all__ = ["RandLoraConfig", "RandLoraLayer", "Linear", "RandLoraModel"] +__all__ = ["Linear", "RandLoraConfig", "RandLoraLayer", "RandLoraModel"] + +register_peft_method(name="randlora", config_cls=RandLoraConfig, model_cls=RandLoraModel, prefix="randlora_") -register_peft_method( - name="randlora", config_cls=RandLoraConfig, model_cls=RandLoraModel, prefix="randlora_" -) def __getattr__(name): if (name == "Linear8bitLt") and is_bnb_available(): diff --git a/src/peft/tuners/randlora/bnb.py b/src/peft/tuners/randlora/bnb.py index 0e0ec7dd96..f965dae75f 100644 --- a/src/peft/tuners/randlora/bnb.py +++ b/src/peft/tuners/randlora/bnb.py @@ -1,4 +1,4 @@ -# Copyright 2024-present the HuggingFace Inc. team. +# Copyright 2025-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -124,10 +124,11 @@ def unmerge(self) -> None: ).to(weight.device) state.reset_grads() - def get_scaled_bases(self, adapter) -> List[torch.Tensor, torch.Tensor, torch.dtype]: + def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor, torch.dtype]: """ - Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct order - to fit the target layers' dimensions + Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the + correct order to fit the target layers' dimensions + Args: adapter (str): The name of the adapter for which the delta weight should be computed. 
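The slicing and scaling that `get_scaled_bases` describes can be pictured with a small standalone sketch. Everything below is illustrative only (hypothetical layer sizes, simplified names, no PEFT imports) and is not part of the patch; it mirrors the shape bookkeeping, not the exact internal tensor layout or the custom autograd function.

```py
import torch

# Hypothetical layer sizes, chosen only for illustration
r, in_features, out_features = 32, 768, 3072
min_dim, max_dim = min(in_features, out_features), max(in_features, out_features)
num_bases = -(-min_dim // r)  # ceil(min_dim / r): enough rank-r blocks to reach full rank

# Frozen random bases, shared by every adapted layer and sliced per layer as needed
basis_A = torch.randn(r, 1, min_dim)      # smallest basis; the trainable scalings act on this one
basis_B = torch.randn(max_dim, num_bases, r)

# Per-layer trainable scalings (stand-ins for randlora_lambda / randlora_gamma)
lam = torch.zeros(r, num_bases)
gamma = torch.ones(num_bases, min_dim) / max_dim

# Scale A per (rank, base) pair, then flatten the rank/base axes so only two matmuls remain
update_A = (lam[:, :, None] * basis_A * gamma[None]).flatten(end_dim=1)  # (r * num_bases, min_dim)
update_B = basis_B.flatten(start_dim=1)                                  # (max_dim, num_bases * r)

delta = update_B @ update_A  # (max_dim, min_dim); zero here because lam is zero
print(delta.shape)
```

Because the per-layer scalings fully control the update, a zero `lam` yields a zero delta even though the shared random bases themselves are dense.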
@@ -153,15 +154,15 @@ def get_scaled_bases(self, adapter) -> List[torch.Tensor, torch.Tensor, torch.dt randlora_lambda = randlora_lambda.float() randlora_gamma = randlora_gamma.float() - #The trainable paramters are always applied to randlora_A, the smallest basis. + # The trainable paramters are always applied to randlora_A, the smallest basis. min_dim, max_dim = min(self.out_features, self.in_features), max(self.out_features, self.in_features) # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, # we initialize these matrices with the largest required size for each dimension. # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. - sliced_A = randlora_A[:, : self.n, : min_dim] - sliced_B = randlora_B[: max_dim, : self.n, :] - #Flattening the matrices over the rank and number of bases dimensions is more memory efficient + sliced_A = randlora_A[:, : self.n, :min_dim] + sliced_B = randlora_B[:max_dim, : self.n, :] + # Flattening the matrices over the rank and number of bases dimensions is more memory efficient update_B = sliced_B.flatten(start_dim=1) update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) if min_dim == self.in_features: @@ -188,11 +189,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor: # cast back the weights # TODO: why?, taken from the VeRA implementation - self.randlora_lambda[adapter].data = randlora_lambda.to(dtype) - self.randlora_gamma[adapter].data = randlora_gamma.to(dtype) + self.randlora_lambda[adapter].data = self.randlora_lambda[adapter].data.to(dtype) + self.randlora_gamma[adapter].data = self.randlora_gamma[adapter].data.to(dtype) scaling = self.scaling[adapter] - + return output_tensor * scaling def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: @@ -206,9 +207,9 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: torch.Tensor: Output tensor after applying the RandLora adaptation. Note: - This method implements the RandLora-specific forward pass. It applies the shared projections (randlora_A and - randlora_B) along with the per-layer trainable parameters (lambda and gamma) to compute the adapter - output. + This method implements the RandLora-specific forward pass. It applies the shared projections + (randlora_A and randlora_B) along with the per-layer trainable parameters (lambda and gamma) to compute + the adapter output. 
""" if self.disable_adapters: if self.merged: @@ -221,7 +222,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: for active_adapter in self.active_adapters: if active_adapter not in self.randlora_lambda.keys(): continue - + update_B, update_A, dtype = self.get_scaled_bases(active_adapter) requires_conversion = not torch.is_autocast_enabled() if requires_conversion: @@ -232,14 +233,12 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: dropout = self.randlora_dropout[active_adapter] x_temp = dropout(x.to(update_A.dtype)) - - adapter_output = torch.nn.functional.linear( - torch.nn.functional.linear(x_temp, update_B), update_A - ) + + adapter_output = torch.nn.functional.linear(torch.nn.functional.linear(x_temp, update_B), update_A) if requires_conversion: adapter_output = adapter_output.to(expected_dtype) - + scaling = self.scaling[active_adapter] result = result + adapter_output * scaling @@ -337,10 +336,11 @@ def unmerge(self) -> None: weight.device ) - def get_scaled_bases(self, adapter) -> List[torch.Tensor, torch.Tensor, torch.dtype]: + def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor, torch.dtype]: """ - Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct order - to fit the target layers' dimensions + Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the + correct order to fit the target layers' dimensions + Args: adapter (str): The name of the adapter for which the delta weight should be computed. @@ -366,15 +366,15 @@ def get_scaled_bases(self, adapter) -> List[torch.Tensor, torch.Tensor, torch.dt randlora_lambda = randlora_lambda.float() randlora_gamma = randlora_gamma.float() - #The trainable paramters are always applied to randlora_A, the smallest basis. + # The trainable paramters are always applied to randlora_A, the smallest basis. min_dim, max_dim = min(self.out_features, self.in_features), max(self.out_features, self.in_features) # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, # we initialize these matrices with the largest required size for each dimension. - # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. - sliced_A = randlora_A[:, : self.n, : min_dim] - sliced_B = randlora_B[: max_dim, : self.n, :] - #Flattening the matrices over the rank and number of bases dimensions is more memory efficient + # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. + sliced_A = randlora_A[:, : self.n, :min_dim] + sliced_B = randlora_B[:max_dim, : self.n, :] + # Flattening the matrices over the rank and number of bases dimensions is more memory efficient update_B = sliced_B.flatten(start_dim=1) update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) if min_dim == self.in_features: @@ -385,6 +385,7 @@ def get_scaled_bases(self, adapter) -> List[torch.Tensor, torch.Tensor, torch.dt def get_delta_weight(self, adapter) -> torch.Tensor: """ Compute the delta weight for the given adapter. + Args: adapter (str): The name of the adapter for which the delta weight should be computed. 
@@ -400,13 +401,13 @@ def get_delta_weight(self, adapter) -> torch.Tensor: # cast back the weights # TODO: why?, taken from the VeRA implementation - self.randlora_lambda[adapter].data = randlora_lambda.to(dtype) - self.randlora_gamma[adapter].data = randlora_gamma.to(dtype) + self.randlora_lambda[adapter].data = self.randlora_lambda[adapter].to(dtype) + self.randlora_gamma[adapter].data = self.randlora_gamma[adapter].to(dtype) scaling = self.scaling[adapter] - + return output_tensor * scaling - + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: if self.disable_adapters: if self.merged: @@ -419,7 +420,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: result = result.clone() for active_adapter in self.active_adapters: if active_adapter not in self.randlora_lambda.keys(): - continue + continue update_B, update_A, dtype = self.get_scaled_bases(active_adapter) requires_conversion = not torch.is_autocast_enabled() if requires_conversion: @@ -431,16 +432,14 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: dropout = self.randlora_dropout[active_adapter] x_temp = dropout(x.to(update_A.dtype)) - adapter_output = torch.nn.functional.linear( - torch.nn.functional.linear(x_temp, update_B), update_A - ) + adapter_output = torch.nn.functional.linear(torch.nn.functional.linear(x_temp, update_B), update_A) if requires_conversion: adapter_output = adapter_output.to(expected_dtype) scaling = self.scaling[active_adapter] result = result + adapter_output * scaling - + # Ensure the output tensor has the same dtype as the input tensor return result.to(x.dtype) diff --git a/src/peft/tuners/randlora/config.py b/src/peft/tuners/randlora/config.py index b69b067abe..ea730055e2 100644 --- a/src/peft/tuners/randlora/config.py +++ b/src/peft/tuners/randlora/config.py @@ -1,4 +1,4 @@ -# Copyright 2023-present the HuggingFace Inc. team. +# Copyright 2025-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,40 +13,44 @@ # limitations under the License. import warnings -import math from dataclasses import dataclass, field from typing import List, Optional, Union from peft.config import PeftConfig from peft.utils import PeftType + @dataclass class RandLoraConfig(PeftConfig): """ This is the configuration class to store the configuration of a [`RandLoraModel`]. - Paper: {}. + Paper: https://arxiv.org/pdf/2502.00987. Args: r (`int`, *optional*, defaults to `32`): - RandLora's random basis rank dimension. This parameter is inversely proportional to the amount of trainable parameters. + RandLora's random basis rank dimension. This parameter is inversely proportional to the amount of trainable + parameters. target_modules (`Union[List[str], str]`): The names of the modules to apply RandLora to. Only linear layers are supported. projection_prng_key (`int`): - RandLora PRNG init key. Used for initialising basis_A and basis_B for new models or when loading a checkpoint - that did not include these projections. Defaults to `int(math.exp(1)*3.1415*1000)`. + RandLora PRNG init key. Used for initialising basis_A and basis_B for new models or when loading a + checkpoint that did not include these projections. Defaults to `0`. save_projection (`bool`): - Whether to save the global basis_A / basis_B random basis in the state dict alongside per layer lambda / gamma diagonal matrices. - weights. 
This will increase the size of the checkpoint, but guarantee that we can reload the checkpoint on - all system configurations. Defaults to `True`. + Whether to save the global basis_A / basis_B random basis in the state dict alongside per layer lambda / + gamma diagonal matrices. This will increase the size of the checkpoint, but guarantee that we can + reload the checkpoint on all system configurations. Defaults to `True`. sparse (`bool`): - Whether to use sparse random bases as described in the RandLora paper. The current implementation is a proof of concept where the sparseness is not used to improve speed or memory usage. Defaults to `False`. + Whether to use sparse random bases as described in the RandLora paper. The current implementation is a + proof of concept where the sparseness is not used to improve speed or memory usage. Defaults to `False`. very_sparse (`bool`): - Whether to use very sparse random bases. The current implementation is a proof of concept where the sparseness is not used to improve speed or memory usage. Defaults to `False`. + Whether to use very sparse random bases. The current implementation is a proof of concept where the + sparseness is not used to improve speed or memory usage. Defaults to `False`. randlora_dropout (`float`): The dropout probability for RandLora layers. randlora_alpha (`float`): - The scaling coefficient for RandLora layers, this would be typically be the same as LoRA, e.g. 2 times the rank. + The scaling coefficient for RandLora layers, this would be typically be the same as LoRA, e.g. 2 times the + rank. fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`. @@ -57,12 +61,12 @@ class RandLoraConfig(PeftConfig): modules_to_save (`List[str]`): List of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint. init_weights (`bool`): - Whether to initialize the weights of the RandLora layers with their default initialization. Don't change this - setting, except if you know exactly what you're doing. + Whether to initialize the weights of the RandLora layers with their default initialization. Don't change + this setting, except if you know exactly what you're doing. layers_to_transform (`Union[List[int],int]`): - The layer indexes to transform, if this argument is specified, it will apply the RandLora transformations on - the layer indexes that are specified in this list. If a single integer is passed, it will apply the RandLora - transformations on the layer at this index. + The layer indexes to transform, if this argument is specified, it will apply the RandLora transformations + on the layer indexes that are specified in this list. If a single integer is passed, it will apply the + RandLora transformations on the layer at this index. layers_pattern (`str`): The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer pattern is not in the common layers pattern. @@ -81,7 +85,7 @@ class RandLoraConfig(PeftConfig): }, ) projection_prng_key: int = field( - default=int(math.exp(1)*3.1415*1000), + default=0, metadata={ "help": ( "RandLora PRNG init key. 
Used for initialising basis_A and basis_B for new models or when loading a " @@ -124,8 +128,15 @@ class RandLoraConfig(PeftConfig): default=False, metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, ) - randlora_alpha: int = field(default=64, metadata={"help": "Scaling coefficient in the adapter layers, typically 2 times the rank of the random bases."}) - bias: str = field(default="none", metadata={"help": "Bias type for RandLora. Can be 'none', 'all' or 'randlora_only'"}) + randlora_alpha: int = field( + default=64, + metadata={ + "help": "Scaling coefficient in the adapter layers, typically 2 times the rank of the random bases." + }, + ) + bias: str = field( + default="none", metadata={"help": "Bias type for RandLora. Can be 'none', 'all' or 'randlora_only'"} + ) modules_to_save: Optional[List[str]] = field( default=None, metadata={ diff --git a/src/peft/tuners/randlora/layer.py b/src/peft/tuners/randlora/layer.py index 2f9a7c776d..0aa1c2fc11 100644 --- a/src/peft/tuners/randlora/layer.py +++ b/src/peft/tuners/randlora/layer.py @@ -1,4 +1,4 @@ -# Copyright 2023-present the HuggingFace Inc. team. +# Copyright 2025-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import warnings from typing import List, Optional import torch import torch.nn as nn import torch.nn.functional as F -import math from transformers.pytorch_utils import Conv1D from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge @@ -26,22 +26,28 @@ from .._buffer_dict import BufferDict + class UniqueBaseGrad(torch.autograd.Function): - #Memory efficent for a unique base - @staticmethod + # Memory efficent for a unique base + @staticmethod def forward(ctx, randlora_A, randlora_lambda, randlora_gamma): - Out = randlora_lambda[:, :, None] * randlora_A * randlora_gamma[None, ] + Out = randlora_lambda[:, :, None] * randlora_A * randlora_gamma[None,] ctx.save_for_backward(randlora_A, randlora_lambda, randlora_gamma) - return Out - + return Out + @staticmethod def backward(ctx, grad_output): randlora_A, randlora_lambda, randlora_gamma = ctx.saved_tensors - randlora_A, randlora_lambda, randlora_gamma = randlora_A.to(grad_output.dtype), randlora_lambda.to(grad_output.dtype), randlora_gamma.to(grad_output.dtype) - grad_randlora_lambda = torch.einsum('kbj,kvj,bj->kb', grad_output, randlora_A, randlora_gamma) - grad_randlora_gamma = torch.einsum('kbj,kvj,kb->bj', grad_output, randlora_A, randlora_lambda) + randlora_A, randlora_lambda, randlora_gamma = ( + randlora_A.to(grad_output.dtype), + randlora_lambda.to(grad_output.dtype), + randlora_gamma.to(grad_output.dtype), + ) + grad_randlora_lambda = torch.einsum("kbj,kvj,bj->kb", grad_output, randlora_A, randlora_gamma) + grad_randlora_gamma = torch.einsum("kbj,kvj,kb->bj", grad_output, randlora_A, randlora_lambda) return None, grad_randlora_lambda, grad_randlora_gamma + class RandLoraLayer(BaseTunerLayer): # List all names of layers that may contain adapter weights adapter_layer_names = ("randlora_lambda", "randlora_gamma") @@ -56,8 +62,6 @@ def __init__(self, base_layer: nn.Module, **kwargs): # For storing vector scale self.randlora_lambda = nn.ParameterDict({}) self.randlora_gamma = nn.ParameterDict({}) - self.randlora_m = nn.ParameterDict({}) - 
self.randlora_cache_norm_dora = nn.ParameterDict({}) # Stores a reference to the randlora_A/B BufferDict. # Set to `None` otherwise to avoid computation with random weights @@ -105,12 +109,16 @@ def update_layer( self.randlora_dropout.update(nn.ModuleDict({adapter_name: randlora_dropout_layer})) # Actual trainable parameters - n = min(self.in_features, self.out_features) / r - self.n = int(n) if n.is_integer() else int(n) + 1 #Full rank - self.randlora_lambda[adapter_name] = nn.Parameter(torch.zeros(r, self.n), requires_grad=True) - self.randlora_gamma[adapter_name] = nn.Parameter(torch.ones(self.n, min(self.out_features, self.in_features))/max(self.out_features, self.in_features), requires_grad=True) + num_bases = min(self.in_features, self.out_features) / r + self.num_bases = int(num_bases) if num_bases.is_integer() else int(num_bases) + 1 # Full rank + self.randlora_lambda[adapter_name] = nn.Parameter(torch.randn(r, self.num_bases), requires_grad=True) + self.randlora_gamma[adapter_name] = nn.Parameter( + torch.ones(self.num_bases, min(self.out_features, self.in_features)) + / max(self.out_features, self.in_features), + requires_grad=True, + ) - self.scaling[adapter_name] = randlora_alpha / r / math.sqrt(self.n)#* 10#/ math.sqrt(self.n) + self.scaling[adapter_name] = randlora_alpha / r / math.sqrt(self.num_bases) # non trainable references to randlora_A/B buffers self.randlora_A = randlora_A @@ -129,12 +137,14 @@ def update_layer( "{} has a size of {} but {} or greater is required; this probably happened because an additional RandLora " "adapter was added after the first one with incompatible shapes." ) + max_dim, min_dim = max(self.in_features, self.out_features), min(self.in_features, self.out_features) # check input size - if randlora_A_param.shape[-1] < self.in_features: - raise ValueError(error_tmpl.format("randlora_A", randlora_A_param.shape[1], self.in_features)) + if randlora_B_param.shape[0] < max_dim: + raise ValueError(error_tmpl.format("randlora_B", randlora_B_param.shape[0], max_dim)) # check output size - if randlora_B_param.shape[0] < self.out_features: - raise ValueError(error_tmpl.format("randlora_B", randlora_B_param.shape[0], self.out_features)) + if randlora_A_param.shape[-1] < min_dim: + raise ValueError(error_tmpl.format("randlora_A", randlora_A_param.shape[1], min_dim)) + # check r error_tmpl = ( "{} has a size of {} but {} or greater is required; this probably happened because an additional RandLora " @@ -143,8 +153,8 @@ def update_layer( ) if randlora_A_param.shape[0] < self.r[adapter_name]: raise ValueError(error_tmpl.format("randlora_A", randlora_A_param.shape[0], self.r[adapter_name])) - if randlora_B_param.shape[1] < self.r[adapter_name]: - raise ValueError(error_tmpl.format("randlora_B", randlora_B_param.shape[1], self.r[adapter_name])) + if randlora_B_param.shape[-1] < self.r[adapter_name]: + raise ValueError(error_tmpl.format("randlora_B", randlora_B_param.shape[-1], self.r[adapter_name])) self.randlora_A[adapter_name] = randlora_A_param self.randlora_B[adapter_name] = randlora_B_param @@ -159,7 +169,9 @@ def reset_randlora_parameters(self, adapter_name): if adapter_name in self.randlora_lambda.keys(): with torch.no_grad(): nn.init.zeros_(self.randlora_lambda[adapter_name]) - nn.init.zeros_(self.randlora_gamma[adapter_name]).fill_(1/max(self.randlora_gamma[adapter_name].shape)) + nn.init.ones_(self.randlora_gamma[adapter_name]).fill_( + 1 / max(self.randlora_gamma[adapter_name].shape) + ) class Linear(nn.Linear, RandLoraLayer): @@ -211,10 +223,8 @@ def 
merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N # Note that safe_merge will be slower than the normal merge # because of the copy operation. orig_weights = base_layer.weight.data.clone() - if active_adapter in self.randlora_m.keys(): - norm = torch.linalg.norm(orig_weight, dim=1, keepdim=True) - self.randlora_cache_norm_dora[active_adapter] = norm - orig_weight *= (self.randlora_m[active_adapter] / norm).view(1, -1) + + orig_weights += self.get_delta_weight(active_adapter) if not torch.isfinite(orig_weights).all(): raise ValueError( @@ -234,28 +244,18 @@ def unmerge(self) -> None: while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.randlora_lambda.keys(): - if active_adapter in self.randlora_m.keys(): - ori_weight = self.get_base_layer().weight.data - delta = self.get_delta_weight(active_adapter) - if active_adapter in self.randlora_m.keys(): - norm = self.randlora_cache_norm_dora[active_adapter] - else: - norm = self.randlora_m[active_adapter].data - - self.get_base_layer().weight.data = \ - ori_weight * (norm / self.randlora_m[active_adapter]).view(1, -1) - delta - else: - self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor, torch.dtype]: """ - Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct order - to fit the target layers' dimensions + Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct + order to fit the target layers' dimensions + Args: adapter (str): The name of the adapter for which the delta weight should be computed. """ - + randlora_A = self.randlora_A[adapter] randlora_B = self.randlora_B[adapter] @@ -276,22 +276,24 @@ def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor, torch.d randlora_lambda = randlora_lambda.float() randlora_gamma = randlora_gamma.float() - #The trainable paramters are always applied to randlora_A, the smallest basis. + # The trainable paramters are always applied to randlora_A, the smallest basis. min_dim, max_dim = min(self.out_features, self.in_features), max(self.out_features, self.in_features) - + # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, # we initialize these matrices with the largest required size for each dimension. - # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. - sliced_A = randlora_A[:, : self.n, : min_dim] - sliced_B = randlora_B[: max_dim, : self.n, :] - - #Flattening the matrices over the rank and number of bases dimensions is more memory efficient + # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. + sliced_A = randlora_A[:, : self.num_bases, :min_dim] + sliced_B = randlora_B[:max_dim, : self.num_bases, :] + + # Flattening the matrices over the rank and number of bases dimensions is more memory efficient update_B = sliced_B.flatten(start_dim=1) update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) + + # Since update_A is applied on the smallest dimension, test whether update_A or update_B should applied first. This is done to reduce trainable parameters. 
if min_dim == self.in_features: - return update_A, update_B, dtype + return update_A, update_B, dtype return update_B.T, update_A.T, dtype - + def get_delta_weight(self, adapter) -> torch.Tensor: """ Compute the delta weight for the given adapter. @@ -302,18 +304,10 @@ def get_delta_weight(self, adapter) -> torch.Tensor: """ update_B, update_A, dtype = self.get_scaled_bases(adapter) - + update = (update_B.T @ update_A.T).T output_tensor = transpose(update, self.fan_in_fan_out) - if dtype != self.randlora_B[adapter].dtype: - output_tensor = output_tensor.to(dtype=dtype) - - # cast back the weights - # TODO: why?, taken from the VeRA implementation - self.randlora_lambda[adapter].data = randlora_lambda.to(dtype) - self.randlora_gamma[adapter].data = randlora_gamma.to(dtype) - scaling = self.scaling[adapter] return output_tensor * scaling @@ -330,19 +324,12 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: result = self.base_layer(x, *args, **kwargs) for active_adapter in self.active_adapters: if active_adapter not in self.randlora_lambda.keys(): - continue + continue dropout = self.randlora_dropout[active_adapter] update_B, update_A, _ = self.get_scaled_bases(active_adapter) x = x.to(update_A.dtype) scaling = self.scaling[active_adapter] - if active_adapter in self.randlora_m.keys(): - update = update_A @ update_B * scaling - weight_update_norm = torch.linalg.norm(update + self.weight, dim=1, keepdim=True).detach() - lora_result = F.linear(F.linear(dropout(x), update_B), update_A) * scaling - result = (result + lora_result) * (self.randlora_m[active_adapter] / weight_update_norm).view(1, -1) - else: - result = result + F.linear(F.linear(dropout(x), update_B), update_A) * scaling - + result = result + F.linear(F.linear(dropout(x), update_B), update_A) * scaling result = result.to(previous_dtype) return result diff --git a/src/peft/tuners/randlora/model.py b/src/peft/tuners/randlora/model.py index 73305d02b5..ff1fb9f43d 100644 --- a/src/peft/tuners/randlora/model.py +++ b/src/peft/tuners/randlora/model.py @@ -1,4 +1,4 @@ -# Copyright 2023-present the HuggingFace Inc. team. +# Copyright 2025-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +22,6 @@ import torch import torch.nn as nn -from torch.nn.init import _calculate_correct_fan from tqdm import tqdm from transformers.pytorch_utils import Conv1D @@ -64,7 +63,6 @@ def _kaiming_init( with torch.no_grad(): basis = torch.nn.init.kaiming_uniform_(tensor, a=math.sqrt(5), generator=generator) return basis - class RandLoraModel(BaseTuner): @@ -145,6 +143,7 @@ def _init_randlora_A_randlora_B_sparse(self, config: RandLoraConfig, adapter_nam """ linear_out_dim, linear_in_dim = self._find_dim(config) + max_dim, min_dim = max(linear_out_dim, linear_in_dim), min(linear_out_dim, linear_in_dim) # use of persistent to exclude randlora_A and randlora_B from the state dict if we choose not to save them. 
self.randlora_A = BufferDict({}, persistent=config.save_projection) @@ -152,58 +151,68 @@ def _init_randlora_A_randlora_B_sparse(self, config: RandLoraConfig, adapter_nam # deterministic init of randlora_A and randlora_B if we know the key generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key) - randlora_A = torch.rand((config.r, 1, linear_out_dim), generator=generator) #The gamma matrix is applied on A meaning it can be unique (shared) accross the n scaling matrices - - #Ensure full rank - n = min(linear_out_dim, linear_in_dim) / config.r - n = int(n) if n.is_integer() else int(n) + 1 #Ensure full rank - randlora_B = torch.rand((linear_in_dim, n, config.r), generator=generator) + # The gamma matrix is applied on A meaning it can be unique (shared) accross the n scaling matrices. + # We also set randlora_A as the smallest matrix to reduce trainable parameters. + randlora_A = torch.rand((config.r, 1, min_dim), generator=generator) + + # Ensure full rank + n = min_dim / config.r + n = int(n) if n.is_integer() else int(n) + 1 # Ensure full rank + randlora_B = torch.rand((max_dim, n, config.r), generator=generator) - #The current implementation is a proof of concept and does take into consideration - #the sparsity to reduce memory usage or speed up compute + # The current implementation is a proof of concept and does take into consideration + # the sparsity to reduce memory usage or speed up compute randlora_B_sparse = torch.zeros(randlora_B.shape) randlora_A_sparse = torch.zeros(randlora_A.shape) - randlora_B_sparse[randlora_B<1/(2*s)] = -1 - randlora_B_sparse[randlora_B>1-1/(2*s)] = 1 - randlora_A_sparse[randlora_A<1/(2*s)] = -1 - randlora_A_sparse[randlora_A>1-1/(2*s)] = 1 - - #Std normalization is empirically found to be the best - randlora_A, randlora_B = randlora_A_sparse / randlora_A_sparse.std(), randlora_B_sparse / randlora_B_sparse.std() + randlora_B_sparse[randlora_B < 1 / (2 * s)] = -1 + randlora_B_sparse[randlora_B > 1 - 1 / (2 * s)] = 1 + randlora_A_sparse[randlora_A < 1 / (2 * s)] = -1 + randlora_A_sparse[randlora_A > 1 - 1 / (2 * s)] = 1 + + # Std normalization is empirically found to be the best + randlora_A, randlora_B = ( + randlora_A_sparse / randlora_A_sparse.std(), + randlora_B_sparse / randlora_B_sparse.std(), + ) self.randlora_A[adapter_name] = randlora_A self.randlora_B[adapter_name] = randlora_B - + def _init_randlora_A_randlora_B(self, config: RandLoraConfig, adapter_name: str) -> None: linear_out_dim, linear_in_dim = self._find_dim(config) + max_dim, min_dim = max(linear_out_dim, linear_in_dim), min(linear_out_dim, linear_in_dim) + # use of persistent to exclude randlora_A and randlora_B from the state dict if we choose not to save them. self.randlora_A = BufferDict({}, persistent=config.save_projection) self.randlora_B = BufferDict({}, persistent=config.save_projection) # deterministic init of randlora_A and randlora_B if we know the key generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key) - #The gamma matrix is applied on A meaning it can be unique (shared) accross the n scaling matrices - randlora_A = _kaiming_init((config.r, 1, linear_out_dim), generator=generator) - - #Ensure full rank + # The gamma matrix is applied on A meaning it can be unique (shared) accross the n scaling matrices. + # We also set randlora_A as the smallest matrix to reduce trainable parameters. 
+ randlora_A = _kaiming_init((config.r, 1, min_dim), generator=generator) + + # Ensure full rank n = min(linear_out_dim, linear_in_dim) / config.r n = int(n) if n.is_integer() else int(n) + 1 - randlora_B = torch.cat([_kaiming_init((linear_in_dim, 1, config.r), generator=generator) for _ in range(n)], dim=1) - - #Std normalization is empirically found to be the best + randlora_B = torch.cat([_kaiming_init((max_dim, 1, config.r), generator=generator) for _ in range(n)], dim=1) + + # Std normalization is empirically found to be the best randlora_A, randlora_B = randlora_A / randlora_A.std(), randlora_B / randlora_B.std() self.randlora_A[adapter_name] = randlora_A self.randlora_B[adapter_name] = randlora_B - def _pre_injection_hook(self, model: nn.Module, config: RandLoraxConfig, adapter_name: str) -> None: + def _pre_injection_hook(self, model: nn.Module, config: RandLoraConfig, adapter_name: str) -> None: if config.very_sparse: linear_out_dim, linear_in_dim = self._find_dim(config) - self._init_randlora_A_randlora_B_sparse(config, adapter_name, s=math.sqrt(min(linear_out_dim, linear_in_dim))) + self._init_randlora_A_randlora_B_sparse( + config, adapter_name, s=math.sqrt(min(linear_out_dim, linear_in_dim)) + ) elif config.sparse: self._init_randlora_A_randlora_B_sparse(config, adapter_name, s=3) else: self._init_randlora_A_randlora_B(config, adapter_name) - def _check_new_adapter_config(self, config: RandLora) -> None: + def _check_new_adapter_config(self, config: RandLoraConfig) -> None: """ A helper method to check the config when a new adapter is being added. @@ -277,7 +286,9 @@ def _create_and_replace( randlora_config.init_weights, ) else: - new_module = self._create_new_module(randlora_config, self.randlora_A, self.randlora_B, adapter_name, target, **kwargs) + new_module = self._create_new_module( + randlora_config, self.randlora_A, self.randlora_B, adapter_name, target, **kwargs + ) if adapter_name not in self.active_adapter: # adding an additional adapter: it is not automatically trainable new_module.requires_grad_(False) @@ -385,8 +396,7 @@ def _create_new_module(randlora_config, randlora_A, randlora_B, adapter_name, ta kwargs["is_target_conv_1d_layer"] = True if not kwargs["fan_in_fan_out"]: warnings.warn( - "fan_in_fan_out is set to False but the target module is `Conv1D`. " - "Setting fan_in_fan_out to True." + "fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True." ) kwargs["fan_in_fan_out"] = randlora_config.fan_in_fan_out = True else: @@ -515,8 +525,8 @@ def merge_and_unload( self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None ): r""" - This method merges the RandLora layers into the base model. This is needed if someone wants to use the base model - as a standalone model. + This method merges the RandLora layers into the base model. This is needed if someone wants to use the base + model as a standalone model. Args: progressbar (`bool`): @@ -546,7 +556,7 @@ def merge_and_unload( def unload(self): """ - Gets back the base model by removing all the RandLora modules without merging. This gives back the original base - model. + Gets back the base model by removing all the RandLora modules without merging. This gives back the original + base model. 
""" return self._unload_and_optionally_merge(merge=False) diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index c6e710e58b..93343f0575 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -25,9 +25,9 @@ TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, - TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING, WEIGHTS_NAME, AuxiliaryTrainingWrapper, ModulesToSaveWrapper, @@ -65,9 +65,9 @@ "TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING", + "TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING", - "TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING", "WEIGHTS_NAME", "AuxiliaryTrainingWrapper", "ModulesToSaveWrapper", diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index a1359f273d..0936d95110 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -291,7 +291,9 @@ def starcoder_model_postprocess_past_key_value(past_key_values): "qwen2": ["q_proj", "v_proj"], } -TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING # Leaving this for now but RandLoRA is flexible +TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING = ( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING # Leaving this for now but RandLoRA is flexible +) WEIGHTS_NAME = "adapter_model.bin" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index be3d565dcd..86669ffcaa 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -44,9 +44,9 @@ TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, - TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING, WEIGHTS_NAME, bloom_model_postprocess_past_key_value, starcoder_model_postprocess_past_key_value, @@ -72,6 +72,7 @@ "TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING", + "TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING", "WEIGHTS_NAME", diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index cc1fdd5e03..f8eed7892a 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # coding=utf-8 -# Copyright 2023-present the HuggingFace Inc. team. +# Copyright 2025-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,6 +19,7 @@ import platform import re import shutil +import sys import tempfile import time import unittest @@ -33,6 +34,8 @@ from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification from transformers.pytorch_utils import Conv1D + +sys.path.append(os.path.join(os.getcwd(), "src")) from peft import ( AdaLoraConfig, BOFTConfig, @@ -46,11 +49,11 @@ LoraConfig, OFTConfig, PeftModel, + RandLoraConfig, TaskType, TrainableTokensConfig, VBLoRAConfig, VeraConfig, - RandLoraConfig, get_peft_model, ) from peft.tuners.tuners_utils import BaseTunerLayer @@ -519,18 +522,14 @@ ("Vanilla MLP 2 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0"]}), ("Vanilla MLP 3 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin1"]}), ("Vanilla MLP 4 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"]}), + ("Vanilla MLP 5 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"], "sparse": True}), + ("Vanilla MLP 6 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"], "very_sparse": True}), ( - "Vanilla MLP 5 RandLora", + "Vanilla MLP 7 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}, ), - ( - "Embedding + transformers Conv1D 1 RandLora", - "EmbConv1D", - RandLoraConfig, - {"target_modules": ["conv1d"]}, - ), ] # For this test matrix, each tuple consists of: @@ -639,11 +638,11 @@ ), # Note: RandLora may present the same problem mentioned above for Vera. ( - "RandLora Different", + "RandLora Same", "randlora", RandLoraConfig, {"target_modules": ["lin0"], "init_weights": False}, - {"target_modules": ["lin1"], "init_weights": False}, + {"target_modules": ["lin0"], "init_weights": False}, ), ( "HRA Same", @@ -3809,8 +3808,7 @@ def forward(self, X): # active adapter is still "default" self.check_requires_grad( peft_model, - "no" - "base_model.model.lin1.vera_lambda_b.default", + "nobase_model.model.lin1.vera_lambda_b.default", "base_model.model.lin1.vera_lambda_d.default", ) @@ -3943,10 +3941,10 @@ def forward(self, X): X = self.sm(X) return X - config0 = VeraConfig(target_modules=["lin1"]) + config0 = RandLoraConfig(target_modules=["lin1"]) peft_model = get_peft_model(MLP2(), config0) - config1 = VeraConfig(target_modules=["lin2"]) + config1 = RandLoraConfig(target_modules=["lin2"]) peft_model.add_adapter("adapter1", config1) # active adapter is still "default" From e90b52af4ac4b991314dfbb5e7c1dda4a86636dc Mon Sep 17 00:00:00 2001 From: Paul Albert Date: Tue, 1 Apr 2025 10:06:15 +1100 Subject: [PATCH 04/13] style changes rebase --- examples/boft_controlnet/eval.sh | 0 examples/boft_controlnet/test_controlnet.sh | 0 examples/boft_controlnet/train_controlnet.sh | 0 examples/boft_dreambooth/train_dreambooth.sh | 0 tests/testing_common.py | 2 +- 5 files changed, 1 insertion(+), 1 deletion(-) mode change 100755 => 100644 examples/boft_controlnet/eval.sh mode change 100755 => 100644 examples/boft_controlnet/test_controlnet.sh mode change 100755 => 100644 examples/boft_controlnet/train_controlnet.sh mode change 100755 => 100644 examples/boft_dreambooth/train_dreambooth.sh diff --git a/examples/boft_controlnet/eval.sh b/examples/boft_controlnet/eval.sh old mode 100755 new mode 100644 diff --git a/examples/boft_controlnet/test_controlnet.sh b/examples/boft_controlnet/test_controlnet.sh old mode 100755 new mode 100644 diff --git a/examples/boft_controlnet/train_controlnet.sh b/examples/boft_controlnet/train_controlnet.sh old mode 100755 new mode 100644 diff --git 
a/examples/boft_dreambooth/train_dreambooth.sh b/examples/boft_dreambooth/train_dreambooth.sh old mode 100755 new mode 100644 diff --git a/tests/testing_common.py b/tests/testing_common.py index dd21426f5d..f0f0003e51 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -1,4 +1,4 @@ -# Copyright 2023-present the HuggingFace Inc. team. +# Copyright 202-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 7f14c50fe60203ad02a934cb818d8a70beb5de08 Mon Sep 17 00:00:00 2001 From: Paul Albert Date: Tue, 8 Apr 2025 15:59:50 +1000 Subject: [PATCH 05/13] reverting licence change rebase --- src/peft/tuners/randlora/bnb.py | 44 ++--- src/peft/tuners/randlora/config.py | 21 ++- src/peft/tuners/randlora/layer.py | 10 +- src/peft/tuners/randlora/model.py | 2 +- tests/test_common_gpu.py | 158 +++++++++++++++++ tests/test_custom_models.py | 5 +- tests/test_feature_extraction_models.py | 6 +- tests/test_gpu_examples.py | 227 ++++++++++++++++++++++++ tests/test_initialization.py | 2 +- tests/testing_common.py | 19 +- 10 files changed, 439 insertions(+), 55 deletions(-) diff --git a/src/peft/tuners/randlora/bnb.py b/src/peft/tuners/randlora/bnb.py index f965dae75f..8f2c3eb465 100644 --- a/src/peft/tuners/randlora/bnb.py +++ b/src/peft/tuners/randlora/bnb.py @@ -124,7 +124,7 @@ def unmerge(self) -> None: ).to(weight.device) state.reset_grads() - def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor, torch.dtype]: + def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]: """ Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct order to fit the target layers' dimensions @@ -160,15 +160,15 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor, torch.dt # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, # we initialize these matrices with the largest required size for each dimension. # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. - sliced_A = randlora_A[:, : self.n, :min_dim] - sliced_B = randlora_B[:max_dim, : self.n, :] + sliced_A = randlora_A[:, : self.num_bases, :min_dim] + sliced_B = randlora_B[:max_dim, : self.num_bases, :] # Flattening the matrices over the rank and number of bases dimensions is more memory efficient update_B = sliced_B.flatten(start_dim=1) update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) if min_dim == self.in_features: - return update_A, update_B, dtype + return update_A, update_B - return update_B.T, update_A.T, dtype + return update_B.T, update_A.T def get_delta_weight(self, adapter) -> torch.Tensor: """ @@ -179,19 +179,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor: The name of the adapter for which the delta weight should be computed. 
""" - update_B, update_A, dtype = self.get_scaled_bases(adapter) + update_B, update_A = self.get_scaled_bases(adapter) update = update_B @ update_A output_tensor = transpose(update, self.fan_in_fan_out) - if dtype != self.randlora_B[adapter].dtype: - output_tensor = output_tensor.to(dtype=dtype) - - # cast back the weights - # TODO: why?, taken from the VeRA implementation - self.randlora_lambda[adapter].data = self.randlora_lambda[adapter].data.to(dtype) - self.randlora_gamma[adapter].data = self.randlora_gamma[adapter].data.to(dtype) - scaling = self.scaling[adapter] return output_tensor * scaling @@ -223,7 +215,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: if active_adapter not in self.randlora_lambda.keys(): continue - update_B, update_A, dtype = self.get_scaled_bases(active_adapter) + update_B, update_A = self.get_scaled_bases(active_adapter) requires_conversion = not torch.is_autocast_enabled() if requires_conversion: expected_dtype = result.dtype @@ -336,7 +328,7 @@ def unmerge(self) -> None: weight.device ) - def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor, torch.dtype]: + def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]: """ Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct order to fit the target layers' dimensions @@ -372,15 +364,15 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor, torch.dt # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, # we initialize these matrices with the largest required size for each dimension. # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. - sliced_A = randlora_A[:, : self.n, :min_dim] - sliced_B = randlora_B[:max_dim, : self.n, :] + sliced_A = randlora_A[:, : self.num_bases, :min_dim] + sliced_B = randlora_B[:max_dim, : self.num_bases, :] # Flattening the matrices over the rank and number of bases dimensions is more memory efficient update_B = sliced_B.flatten(start_dim=1) update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) if min_dim == self.in_features: - return update_A, update_B, dtype + return update_A, update_B - return update_B.T, update_A.T, dtype + return update_B.T, update_A.T def get_delta_weight(self, adapter) -> torch.Tensor: """ @@ -391,19 +383,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor: The name of the adapter for which the delta weight should be computed. 
""" - update_B, update_A, dtype = self.get_scaled_bases(adapter) + update_B, update_A = self.get_scaled_bases(adapter) update = update_B @ update_A output_tensor = transpose(update, self.fan_in_fan_out) - if dtype != self.randlora_B[adapter].dtype: - output_tensor = output_tensor.to(dtype=dtype) - - # cast back the weights - # TODO: why?, taken from the VeRA implementation - self.randlora_lambda[adapter].data = self.randlora_lambda[adapter].to(dtype) - self.randlora_gamma[adapter].data = self.randlora_gamma[adapter].to(dtype) - scaling = self.scaling[adapter] return output_tensor * scaling @@ -421,7 +405,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: for active_adapter in self.active_adapters: if active_adapter not in self.randlora_lambda.keys(): continue - update_B, update_A, dtype = self.get_scaled_bases(active_adapter) + update_B, update_A = self.get_scaled_bases(active_adapter) requires_conversion = not torch.is_autocast_enabled() if requires_conversion: expected_dtype = result.dtype diff --git a/src/peft/tuners/randlora/config.py b/src/peft/tuners/randlora/config.py index ea730055e2..d0ff134978 100644 --- a/src/peft/tuners/randlora/config.py +++ b/src/peft/tuners/randlora/config.py @@ -28,9 +28,9 @@ class RandLoraConfig(PeftConfig): Paper: https://arxiv.org/pdf/2502.00987. Args: - r (`int`, *optional*, defaults to `32`): - RandLora's random basis rank dimension. This parameter is inversely proportional to the amount of trainable - parameters. + r (`int`, *optional*, defaults to `10`): + RandLora's random basis rank dimension. Contrary to Lora, this parameter is inversely proportional to the amount of trainable + parameters as reducing it increases trainable parameters. target_modules (`Union[List[str], str]`): The names of the modules to apply RandLora to. Only linear layers are supported. projection_prng_key (`int`): @@ -41,11 +41,14 @@ class RandLoraConfig(PeftConfig): gamma diagonal matrices. This will increase the size of the checkpoint, but guarantee that we can reload the checkpoint on all system configurations. Defaults to `True`. sparse (`bool`): - Whether to use sparse random bases as described in the RandLora paper. The current implementation is a - proof of concept where the sparseness is not used to improve speed or memory usage. Defaults to `False`. + Whether to use sparse random bases as described in the RandLora paper. The bases are ternary sparse bases (only containing -1, 0 and 1) where the attribution probability is 1/6 for -1 and 1 and 2/3 for 0. + These sparse matrices aim to be used for matmul free computation in the future, see https://arxiv.org/pdf/2406.02528v1 + The current implementation is a proof of concept however where the sparseness is not used to improve speed or memory usage. Using sparse matrices typically does not reduce performance and can even help reduce overfitting. + Defaults to `False`. very_sparse (`bool`): - Whether to use very sparse random bases. The current implementation is a proof of concept where the - sparseness is not used to improve speed or memory usage. Defaults to `False`. + Whether to use highly sparse random bases as described in the RandLora paper. The very sparse bases are ternary sparse bases (only containing -1, 0 and 1) given a matrix with smallest dimension d, the attribution probability is 1/√D for -1 and 1 and 1- 2/√D for 0. + Using these sparse matrices can further reduce overfitting over the `sparse` alternatives but will most likely decrease performance as a results. Use carefully. 
+ Defaults to `False`. randlora_dropout (`float`): The dropout probability for RandLora layers. randlora_alpha (`float`): @@ -72,7 +75,7 @@ class RandLoraConfig(PeftConfig): pattern is not in the common layers pattern. """ - r: int = field(default=32, metadata={"help": "RandLora random basis rank"}) + r: int = field(default=10, metadata={"help": "RandLora random basis rank"}) target_modules: Optional[Union[List[str], str]] = field( default=None, @@ -129,7 +132,7 @@ class RandLoraConfig(PeftConfig): metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, ) randlora_alpha: int = field( - default=64, + default=20, metadata={ "help": "Scaling coefficient in the adapter layers, typically 2 times the rank of the random bases." }, diff --git a/src/peft/tuners/randlora/layer.py b/src/peft/tuners/randlora/layer.py index 0aa1c2fc11..0e7d76c867 100644 --- a/src/peft/tuners/randlora/layer.py +++ b/src/peft/tuners/randlora/layer.py @@ -246,7 +246,7 @@ def unmerge(self) -> None: if active_adapter in self.randlora_lambda.keys(): self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) - def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor, torch.dtype]: + def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: """ Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct order to fit the target layers' dimensions @@ -291,8 +291,8 @@ def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor, torch.d # Since update_A is applied on the smallest dimension, test whether update_A or update_B should applied first. This is done to reduce trainable parameters. if min_dim == self.in_features: - return update_A, update_B, dtype - return update_B.T, update_A.T, dtype + return update_A, update_B + return update_B.T, update_A.T def get_delta_weight(self, adapter) -> torch.Tensor: """ @@ -303,7 +303,7 @@ def get_delta_weight(self, adapter) -> torch.Tensor: The name of the adapter for which the delta weight should be computed. """ - update_B, update_A, dtype = self.get_scaled_bases(adapter) + update_B, update_A = self.get_scaled_bases(adapter) update = (update_B.T @ update_A.T).T output_tensor = transpose(update, self.fan_in_fan_out) @@ -326,7 +326,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: if active_adapter not in self.randlora_lambda.keys(): continue dropout = self.randlora_dropout[active_adapter] - update_B, update_A, _ = self.get_scaled_bases(active_adapter) + update_B, update_A = self.get_scaled_bases(active_adapter) x = x.to(update_A.dtype) scaling = self.scaling[active_adapter] result = result + F.linear(F.linear(dropout(x), update_B), update_A) * scaling diff --git a/src/peft/tuners/randlora/model.py b/src/peft/tuners/randlora/model.py index ff1fb9f43d..5d8e2dfc2c 100644 --- a/src/peft/tuners/randlora/model.py +++ b/src/peft/tuners/randlora/model.py @@ -151,6 +151,7 @@ def _init_randlora_A_randlora_B_sparse(self, config: RandLoraConfig, adapter_nam # deterministic init of randlora_A and randlora_B if we know the key generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key) + # The gamma matrix is applied on A meaning it can be unique (shared) accross the n scaling matrices. # We also set randlora_A as the smallest matrix to reduce trainable parameters. 
randlora_A = torch.rand((config.r, 1, min_dim), generator=generator) @@ -369,7 +370,6 @@ def _create_new_module(randlora_config, randlora_A, randlora_B, adapter_name, ta eightbit_kwargs.update( { "has_fp16_weights": target_base_layer.state.has_fp16_weights, - "memory_efficient_backward": target_base_layer.state.memory_efficient_backward, "threshold": target_base_layer.state.threshold, "index": target_base_layer.index, } diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py index 6d10f3bd27..b68510f51c 100644 --- a/tests/test_common_gpu.py +++ b/tests/test_common_gpu.py @@ -46,6 +46,7 @@ LoraConfig, OFTConfig, PeftModel, + RandLoraConfig, TaskType, VBLoRAConfig, VeraConfig, @@ -70,11 +71,13 @@ from peft.tuners.ia3 import Linear8bitLt as IA3Linear8bitLt from peft.tuners.lora import Linear8bitLt as LoraLinear8bitLt + from peft.tuners.randlora import Linear8bitLt as RandLoraLinear8bitLt from peft.tuners.vera import Linear8bitLt as VeraLinear8bitLt if is_bnb_4bit_available(): from peft.tuners.ia3 import Linear4bit as IA3Linear4bit from peft.tuners.lora import Linear4bit as LoraLinear4bit + from peft.tuners.randlora import Linear4bit as RandLoraLinear4bit from peft.tuners.vera import Linear4bit as VeraLinear4bit @@ -199,6 +202,54 @@ def test_vera_bnb_8bit_quantization(self): whisper_8bit = get_peft_model(whisper_8bit, config) assert isinstance(whisper_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, VeraLinear8bitLt) + @require_bitsandbytes + @pytest.mark.multi_gpu_tests + @pytest.mark.single_gpu_tests + def test_randlora_bnb_8bit_quantization(self): + r""" + Test that tests if the 8bit quantization using RandLora works as expected + """ + whisper_8bit = WhisperForConditionalGeneration.from_pretrained( + self.audio_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + ) + + opt_8bit = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + ) + + flan_8bit = AutoModelForSeq2SeqLM.from_pretrained( + self.seq2seq_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + ) + + flan_randlora_config = RandLoraConfig( + r=16, target_modules=["q", "v"], randlora_dropout=0.05, bias="none", task_type="SEQ_2_SEQ_LM" + ) + + opt_randlora_config = RandLoraConfig( + r=10, + target_modules=["q_proj", "v_proj"], + randlora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + config = RandLoraConfig(r=5, target_modules=["q_proj", "v_proj"], randlora_dropout=0.05, bias="none") + + flan_8bit = get_peft_model(flan_8bit, flan_randlora_config) + assert isinstance(flan_8bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, RandLoraLinear8bitLt) + + opt_8bit = get_peft_model(opt_8bit, opt_randlora_config) + assert isinstance(opt_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RandLoraLinear8bitLt) + + whisper_8bit = get_peft_model(whisper_8bit, config) + assert isinstance(whisper_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RandLoraLinear8bitLt) + @require_bitsandbytes @pytest.mark.multi_gpu_tests @pytest.mark.single_gpu_tests @@ -347,6 +398,43 @@ def test_vera_bnb_quantization_from_pretrained_safetensors(self, quantization): assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.vera_A assert "adapter2" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.vera_A + @require_bitsandbytes + @pytest.mark.multi_gpu_tests + 
@pytest.mark.single_gpu_tests + @parameterized.expand(["4bit", "8bit"]) + def test_randlora_bnb_quantization_from_pretrained_safetensors(self, quantization): + r""" + Tests that the bnb quantization using RandLora works as expected with safetensors weights. + """ + model_id = "facebook/opt-350m" + kwargs = {"device_map": "auto"} + if quantization == "4bit": + kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True) + else: + kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + + model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs) + config = RandLoraConfig(task_type=TaskType.CAUSAL_LM) + peft_model = get_peft_model(model, config) + peft_model = prepare_model_for_kbit_training(peft_model) + peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0)) + + with tempfile.TemporaryDirectory() as tmp_dir: + peft_model.save_pretrained(tmp_dir) + model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs) + model = PeftModel.from_pretrained(model, tmp_dir) + model = prepare_model_for_kbit_training(model) + model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0)) + + # loading a 2nd adapter works, #1239 + model.load_adapter(tmp_dir, "adapter2") + model.set_adapter("adapter2") + model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0)) + + # check that both adapters are in the same layer + assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.randlora_A + assert "adapter2" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.randlora_A + @require_bitsandbytes @pytest.mark.multi_gpu_tests @pytest.mark.single_gpu_tests @@ -519,6 +607,54 @@ def test_vera_bnb_4bit_quantization(self): whisper_4bit = get_peft_model(whisper_4bit, config) assert isinstance(whisper_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, VeraLinear4bit) + @require_bitsandbytes + @pytest.mark.multi_gpu_tests + @pytest.mark.single_gpu_tests + def test_randlora_bnb_4bit_quantization(self): + r""" + Test that tests if the 4bit quantization using RandLoRA works as expected + """ + whisper_4bit = WhisperForConditionalGeneration.from_pretrained( + self.audio_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + ) + + opt_4bit = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + ) + + flan_4bit = AutoModelForSeq2SeqLM.from_pretrained( + self.seq2seq_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + ) + + flan_randlora_config = RandLoraConfig( + r=16, target_modules=["q", "v"], randlora_dropout=0.05, bias="none", task_type="SEQ_2_SEQ_LM" + ) + + opt_randlora_config = RandLoraConfig( + r=16, + target_modules=["q_proj", "v_proj"], + randlora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + config = RandLoraConfig(r=32, target_modules=["q_proj", "v_proj"], randlora_dropout=0.05, bias="none") + + flan_4bit = get_peft_model(flan_4bit, flan_randlora_config) + assert isinstance(flan_4bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, RandLoraLinear4bit) + + opt_4bit = get_peft_model(opt_4bit, opt_randlora_config) + assert isinstance(opt_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RandLoraLinear4bit) + + whisper_4bit = get_peft_model(whisper_4bit, config) + assert isinstance(whisper_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RandLoraLinear4bit) + @require_bitsandbytes 
@pytest.mark.multi_gpu_tests @pytest.mark.single_gpu_tests @@ -1787,6 +1923,28 @@ def test_vera_add_new_adapter_does_not_change_device(self, mlp): assert model.lin0.vera_A.other.device.type == self.device assert model.lin0.vera_lambda_d.other.device.type == self.device + def test_randlora_add_new_adapter_does_not_change_device(self, mlp): + # same as first test, but using RandLora + config = RandLoraConfig(target_modules=["lin0"]) + model = get_peft_model(mlp, config) + model = model.to(self.device) + model.lin0.randlora_A.cpu() + model.lin0.randlora_lambda.cpu() + + # check that the adapter is indeed on CPU and the base model on GPU + assert model.lin0.randlora_A.default.device.type == "cpu" + assert model.lin0.randlora_lambda.default.device.type == "cpu" + assert model.lin0.base_layer.weight.device.type == self.device + + model.add_adapter("other", config) + # check that after adding a new adapter, the old adapter is still on CPU + assert model.lin0.randlora_A.default.device.type == "cpu" + assert model.lin0.randlora_lambda.default.device.type == "cpu" + # the rest should be on GPU + assert model.lin0.base_layer.weight.device.type == self.device + assert model.lin0.randlora_A.other.device.type == self.device + assert model.lin0.randlora_lambda.other.device.type == self.device + def test_vblora_add_new_adapter_does_not_change_device(self, mlp): # same as first test, but using VBLoRA config = VBLoRAConfig(target_modules=["lin0"], vector_length=2) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index f8eed7892a..b1f0993e4c 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # coding=utf-8 -# Copyright 2025-present the HuggingFace Inc. team. +# Copyright 2023-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,7 +19,6 @@ import platform import re import shutil -import sys import tempfile import time import unittest @@ -34,8 +33,6 @@ from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification from transformers.pytorch_utils import Conv1D - -sys.path.append(os.path.join(os.getcwd(), "src")) from peft import ( AdaLoraConfig, BOFTConfig, diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py index 2bffb935ef..06fe3779d5 100644 --- a/tests/test_feature_extraction_models.py +++ b/tests/test_feature_extraction_models.py @@ -44,10 +44,10 @@ def skip_non_prompt_tuning(test_list): def skip_deberta_lora_tests(test_list): r""" - Skip tests that are checkpointing with lora/ia3/boft/vera/fourierft for Deberta models (couldn't find much info on + Skip tests that are checkpointing with lora/ia3/boft/vera/randlora/fourierft for Deberta models (couldn't find much info on the error) """ - to_skip = ["lora", "ia3", "boft", "vera", "fourierft", "hra", "bone"] + to_skip = ["lora", "ia3", "boft", "vera", "fourierft", "hra", "bone", "randlora"] return [test for test in test_list if not (any(k in test[0] for k in to_skip) and "Deberta" in test[0])] @@ -116,6 +116,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c "boft_kwargs": {"init_weights": [False]}, "oft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, + "randlora_kwargs": {"init_weights": [False]}, "hra_kwargs": {"init_weights": [False]}, "bone_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", @@ -171,6 +172,7 @@ def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_k "boft_kwargs": {"init_weights": [False]}, "oft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, + "randlora_kwargs": {"init_weights": [False]}, "hra_kwargs": {"init_weights": [False]}, "bone_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index f591d52826..bf7373a2a4 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -60,6 +60,7 @@ PeftModel, PrefixTuningConfig, PromptEncoderConfig, + RandLoraConfig, TaskType, VeraConfig, get_peft_model, @@ -1391,6 +1392,232 @@ def test_causal_lm_training_multi_gpu_4bit_vera(self): # assert loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None + @pytest.mark.single_gpu_tests + def test_causal_lm_training_randlora(self): + r""" + Same as test_causal_lm_training but with RandLora + """ + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + device_map="auto", + ) + + tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) + model = prepare_model_for_kbit_training(model) + + config = RandLoraConfig( + r=16, + target_modules=["q_proj", "v_proj"], + randlora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("ybelkada/english_quotes_copy") + data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + 
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.single_gpu_tests + def test_causal_lm_training_4bit_randlora(self): + r""" + Same as test_causal_lm_training_4bit but with RandLora + """ + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + device_map="auto", + ) + + tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) + model = prepare_model_for_kbit_training(model) + + config = RandLoraConfig( + r=16, + target_modules=["q_proj", "v_proj"], + randlora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("ybelkada/english_quotes_copy") + data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.multi_gpu_tests + def test_causal_lm_training_multi_gpu_RandLora(self): + r""" + Same as test_causal_lm_training_multi_gpu but with RandLoRA + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + ) + + assert set(model.hf_device_map.values()) == set(range(device_count)) + + model = prepare_model_for_kbit_training(model) + + setattr(model, "model_parallel", True) + setattr(model, "is_parallelizable", True) + + config = RandLoraConfig( + r=16, + target_modules=["q_proj", "v_proj"], + randlora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("Abirate/english_quotes") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.multi_gpu_tests + def test_causal_lm_training_multi_gpu_4bit_randlora(self): + r""" + Same as test_causal_lm_training_multi_gpu_4bit but with 
RandLora + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + ) + + assert set(model.hf_device_map.values()) == set(range(device_count)) + + model = prepare_model_for_kbit_training(model) + + setattr(model, "model_parallel", True) + setattr(model, "is_parallelizable", True) + + config = RandLoraConfig( + r=16, + target_modules=["q_proj", "v_proj"], + randlora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("Abirate/english_quotes") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + @pytest.mark.single_gpu_tests def test_causal_lm_training_lora_resize_embeddings_trainable_tokens(self): r""" diff --git a/tests/test_initialization.py b/tests/test_initialization.py index a4f86792e8..ad7253993e 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -2661,7 +2661,7 @@ def fn(x, *args): if prepare_layer_inputs_keys is None: prepare_layer_inputs_fn = fn else: - prepare_layer_inputs_fn = {k: fn for k in prepare_layer_inputs_keys} + prepare_layer_inputs_fn = dict.fromkeys(prepare_layer_inputs_keys, fn) shuffled_dataset = dataset.shuffle(seed=0) dataloader = self.get_dataloader(shuffled_dataset) diff --git a/tests/testing_common.py b/tests/testing_common.py index f0f0003e51..ba3a88db22 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -1,4 +1,4 @@ -# Copyright 202-present the HuggingFace Inc. team. +# Copyright 2023-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -50,6 +50,7 @@ PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, + RandLoraConfig, VBLoRAConfig, VeraConfig, get_peft_model, @@ -142,6 +143,16 @@ "bias": "none", "trainable_token_indices": [0, 1, 3], }, + # RandLoRA + { + "r": 10, + "randlora_alpha": 20, + "target_modules": None, + "randlora_dropout": 0.05, + "projection_prng_key": 0xFF, + "save_projection": True, + "bias": "none", + }, # CPT tuninig { "cpt_token_ids": [0, 1, 2, 3, 4, 5, 6, 7], # Example token IDs for testing @@ -165,9 +176,10 @@ "oft": (OFTConfig, CONFIG_TESTING_KWARGS[11]), "bone": (BoneConfig, CONFIG_TESTING_KWARGS[12]), "lora+trainable_tokens": (LoraConfig, CONFIG_TESTING_KWARGS[13]), + "randlora": (RandLoraConfig, CONFIG_TESTING_KWARGS[14]), } -DECODER_MODELS_EXTRA = {"cpt": (CPTConfig, CONFIG_TESTING_KWARGS[14])} +DECODER_MODELS_EXTRA = {"cpt": (CPTConfig, CONFIG_TESTING_KWARGS[15])} # Adapted from https://github.com/huggingface/transformers/blob/48327c57182fdade7f7797d1eaad2d166de5c55b/src/transformers/activations.py#LL166C7-L166C22 @@ -482,7 +494,7 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs, safe_serial if issubclass(config_cls, IA3Config): config_kwargs = config_kwargs.copy() config_kwargs["init_ia3_weights"] = False - if issubclass(config_cls, VeraConfig): + if hasattr(config_cls, "init_weights"): config_kwargs = config_kwargs.copy() config_kwargs["init_weights"] = False @@ -1642,6 +1654,7 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs): "FOURIERFT", "HRA", "VBLORA", + "RANDLORA", "BONE", ): with pytest.raises(AttributeError): From a1f053979548aeaf1100bff55581734ccd0d4be9 Mon Sep 17 00:00:00 2001 From: Paul Albert Date: Fri, 11 Apr 2025 18:18:54 +1000 Subject: [PATCH 06/13] reverting file permissions --- examples/boft_controlnet/eval.sh | 0 examples/boft_controlnet/test_controlnet.sh | 0 examples/boft_controlnet/train_controlnet.sh | 0 examples/boft_dreambooth/train_dreambooth.sh | 0 4 files changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 examples/boft_controlnet/eval.sh mode change 100644 => 100755 examples/boft_controlnet/test_controlnet.sh mode change 100644 => 100755 examples/boft_controlnet/train_controlnet.sh mode change 100644 => 100755 examples/boft_dreambooth/train_dreambooth.sh diff --git a/examples/boft_controlnet/eval.sh b/examples/boft_controlnet/eval.sh old mode 100644 new mode 100755 diff --git a/examples/boft_controlnet/test_controlnet.sh b/examples/boft_controlnet/test_controlnet.sh old mode 100644 new mode 100755 diff --git a/examples/boft_controlnet/train_controlnet.sh b/examples/boft_controlnet/train_controlnet.sh old mode 100644 new mode 100755 diff --git a/examples/boft_dreambooth/train_dreambooth.sh b/examples/boft_dreambooth/train_dreambooth.sh old mode 100644 new mode 100755 From 7942a9c369d17662a90caf09cdbb6c830bf559aa Mon Sep 17 00:00:00 2001 From: PaulAlbert31 Date: Thu, 17 Apr 2025 05:05:29 +0000 Subject: [PATCH 07/13] better hyper-parameters and comformity with new dtype casts --- src/peft/tuners/randlora/__init__.py | 1 + src/peft/tuners/randlora/bnb.py | 5 ++++- src/peft/tuners/randlora/config.py | 31 +++++++++++++-------------- src/peft/tuners/randlora/layer.py | 32 +++++++++++++++++----------- src/peft/tuners/randlora/model.py | 1 + tests/test_custom_models.py | 27 ++++++++++++++++------- tests/testing_common.py | 4 ++-- 7 files changed, 62 insertions(+), 39 deletions(-) diff --git a/src/peft/tuners/randlora/__init__.py b/src/peft/tuners/randlora/__init__.py index 92b8ffe5e0..fbad681aeb 
100644 --- a/src/peft/tuners/randlora/__init__.py +++ b/src/peft/tuners/randlora/__init__.py @@ -1,4 +1,5 @@ # Copyright 2025-present the HuggingFace Inc. team. + # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/peft/tuners/randlora/bnb.py b/src/peft/tuners/randlora/bnb.py index 8f2c3eb465..8e3be08229 100644 --- a/src/peft/tuners/randlora/bnb.py +++ b/src/peft/tuners/randlora/bnb.py @@ -162,6 +162,7 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]: # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. sliced_A = randlora_A[:, : self.num_bases, :min_dim] sliced_B = randlora_B[:max_dim, : self.num_bases, :] + # Flattening the matrices over the rank and number of bases dimensions is more memory efficient update_B = sliced_B.flatten(start_dim=1) update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) @@ -216,6 +217,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: continue update_B, update_A = self.get_scaled_bases(active_adapter) + requires_conversion = not torch.is_autocast_enabled() if requires_conversion: expected_dtype = result.dtype @@ -382,7 +384,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor: adapter (str): The name of the adapter for which the delta weight should be computed. """ - update_B, update_A = self.get_scaled_bases(adapter) update = update_B @ update_A @@ -405,7 +406,9 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: for active_adapter in self.active_adapters: if active_adapter not in self.randlora_lambda.keys(): continue + update_B, update_A = self.get_scaled_bases(active_adapter) + requires_conversion = not torch.is_autocast_enabled() if requires_conversion: expected_dtype = result.dtype diff --git a/src/peft/tuners/randlora/config.py b/src/peft/tuners/randlora/config.py index d0ff134978..341251af7e 100644 --- a/src/peft/tuners/randlora/config.py +++ b/src/peft/tuners/randlora/config.py @@ -14,7 +14,7 @@ import warnings from dataclasses import dataclass, field -from typing import List, Optional, Union +from typing import Optional, Union from peft.config import PeftConfig from peft.utils import PeftType @@ -28,10 +28,10 @@ class RandLoraConfig(PeftConfig): Paper: https://arxiv.org/pdf/2502.00987. Args: - r (`int`, *optional*, defaults to `10`): + r (`int`, *optional*, defaults to `32`): RandLora's random basis rank dimension. Contrary to Lora, this parameter is inversely proportional to the amount of trainable parameters as reducing it increases trainable parameters. - target_modules (`Union[List[str], str]`): + target_modules (`Union[list[str], str]`): The names of the modules to apply RandLora to. Only linear layers are supported. projection_prng_key (`int`): RandLora PRNG init key. Used for initialising basis_A and basis_B for new models or when loading a @@ -52,8 +52,7 @@ class RandLoraConfig(PeftConfig): randlora_dropout (`float`): The dropout probability for RandLora layers. randlora_alpha (`float`): - The scaling coefficient for RandLora layers, this would be typically be the same as LoRA, e.g. 2 times the - rank. + The scaling coefficient for RandLora layers, this would typically be 20 times the rank. fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out). 
For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`. @@ -61,12 +60,12 @@ class RandLoraConfig(PeftConfig): Bias type. Can be 'none', 'all' or 'randlora_only'. If 'all' or 'randlora_only', the corresponding biases will be updated during training. Be aware that this means that, even when disabling the adapters, the model will not produce the same output as the base model would have without adaptation. - modules_to_save (`List[str]`): - List of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint. + modules_to_save (`list[str]`): + list of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint. init_weights (`bool`): Whether to initialize the weights of the RandLora layers with their default initialization. Don't change this setting, except if you know exactly what you're doing. - layers_to_transform (`Union[List[int],int]`): + layers_to_transform (`Union[list[int],int]`): The layer indexes to transform, if this argument is specified, it will apply the RandLora transformations on the layer indexes that are specified in this list. If a single integer is passed, it will apply the RandLora transformations on the layer at this index. @@ -75,13 +74,13 @@ class RandLoraConfig(PeftConfig): pattern is not in the common layers pattern. """ - r: int = field(default=10, metadata={"help": "RandLora random basis rank"}) + r: int = field(default=32, metadata={"help": "RandLora random basis rank"}) - target_modules: Optional[Union[List[str], str]] = field( + target_modules: Optional[Union[list[str], str]] = field( default=None, metadata={ "help": ( - "List of module names or regex expression of the module names to replace with RandLora." + "list of module names or regex expression of the module names to replace with RandLora." "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. " "Only linear layers are supported." ) @@ -132,19 +131,19 @@ class RandLoraConfig(PeftConfig): metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, ) randlora_alpha: int = field( - default=20, + default=640, metadata={ - "help": "Scaling coefficient in the adapter layers, typically 2 times the rank of the random bases." + "help": "Scaling coefficient in the adapter layers, typically 20 times the rank of the random bases." }, ) bias: str = field( default="none", metadata={"help": "Bias type for RandLora. Can be 'none', 'all' or 'randlora_only'"} ) - modules_to_save: Optional[List[str]] = field( + modules_to_save: Optional[list[str]] = field( default=None, metadata={ "help": ( - "List of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint. For" + "list of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint. For" " example, in Sequence Classification or Token Classification tasks, the final layer" " `classifier/score` are randomly initialized and as such need to be trainable and saved." 
) @@ -159,7 +158,7 @@ class RandLoraConfig(PeftConfig): ), }, ) - layers_to_transform: Optional[Union[List[int], int]] = field( + layers_to_transform: Optional[Union[list[int], int]] = field( default=None, metadata={ "help": ( diff --git a/src/peft/tuners/randlora/layer.py b/src/peft/tuners/randlora/layer.py index 0e7d76c867..7719aa6e59 100644 --- a/src/peft/tuners/randlora/layer.py +++ b/src/peft/tuners/randlora/layer.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math import warnings -from typing import List, Optional +from typing import Optional import torch import torch.nn as nn @@ -72,6 +71,9 @@ def __init__(self, base_layer: nn.Module, **kwargs): self._disable_adapters = False self.merged_adapters = [] + # flag to enable/disable casting of input to weight dtype during forward call + self.cast_input_dtype_enabled = True + base_layer = self.get_base_layer() if isinstance(base_layer, nn.Linear): in_features, out_features = base_layer.in_features, base_layer.out_features @@ -118,7 +120,7 @@ def update_layer( requires_grad=True, ) - self.scaling[adapter_name] = randlora_alpha / r / math.sqrt(self.num_bases) + self.scaling[adapter_name] = randlora_alpha / r # non trainable references to randlora_A/B buffers self.randlora_A = randlora_A @@ -153,6 +155,7 @@ def update_layer( ) if randlora_A_param.shape[0] < self.r[adapter_name]: raise ValueError(error_tmpl.format("randlora_A", randlora_A_param.shape[0], self.r[adapter_name])) + if randlora_B_param.shape[-1] < self.r[adapter_name]: raise ValueError(error_tmpl.format("randlora_B", randlora_B_param.shape[-1], self.r[adapter_name])) @@ -169,9 +172,7 @@ def reset_randlora_parameters(self, adapter_name): if adapter_name in self.randlora_lambda.keys(): with torch.no_grad(): nn.init.zeros_(self.randlora_lambda[adapter_name]) - nn.init.ones_(self.randlora_gamma[adapter_name]).fill_( - 1 / max(self.randlora_gamma[adapter_name].shape) - ) + nn.init.constant_(self.randlora_gamma[adapter_name], 1 / max(self.randlora_gamma[adapter_name].shape)) class Linear(nn.Linear, RandLoraLayer): @@ -198,7 +199,7 @@ def __init__( self.update_layer(adapter_name, randlora_A, randlora_B, r, randlora_alpha, randlora_dropout, init_weights) self.is_target_conv_1d_layer = is_target_conv_1d_layer - def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: """ Merge the active adapter weights into the base weights @@ -207,7 +208,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N If True, the merge operation will be performed in a copy of the original weights and check for NaNs before merging the weights. This is useful if you want to check if the merge operation will produce NaNs. Defaults to `False`. - adapter_names (`List[str]`, *optional*): + adapter_names (`list[str]`, *optional*): The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults to `None`. """ @@ -219,6 +220,8 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N for active_adapter in adapter_names: if active_adapter in self.randlora_lambda.keys(): base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype + if safe_merge: # Note that safe_merge will be slower than the normal merge # because of the copy operation. 
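As a quick illustration of how the pieces documented above fit together (the config docstring recommends `randlora_alpha` at roughly 20 times the rank `r`, and `update_layer` turns that pair into the per-layer scaling `randlora_alpha / r`), here is an editor's usage sketch. It is not part of the diff; the model id and hyper-parameter values are borrowed from the test suite and are only illustrative.

# Editor's sketch (not part of the patch): minimal RandLora usage with the
# documented alpha ~ 20 * r relationship; values mirror the tests above.
import torch
from transformers import AutoModelForCausalLM
from peft import RandLoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

config = RandLoraConfig(
    r=32,                 # rank of the shared random bases
    randlora_alpha=640,   # roughly 20x the rank, per the config docstring
    target_modules=["q_proj", "v_proj"],
    randlora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, config)
model.print_trainable_parameters()

# Each adapted layer adds its update, scaled by randlora_alpha / r, on top of
# the frozen base weight during the forward pass.
with torch.no_grad():
    model(input_ids=torch.tensor([[0, 2, 3, 1]]))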
@@ -231,9 +234,11 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - base_layer.weight.data = orig_weights + base_layer.weight.data = orig_weights.to(orig_dtype) else: - base_layer.weight.data += self.get_delta_weight(active_adapter) + delta_weight = self.get_delta_weight(active_adapter) + base_layer.weight.data += delta_weight.to(orig_dtype) + self.merged_adapters.append(active_adapter) def unmerge(self) -> None: @@ -242,9 +247,12 @@ def unmerge(self) -> None: return while len(self.merged_adapters) > 0: + base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype active_adapter = self.merged_adapters.pop() if active_adapter in self.randlora_lambda.keys(): - self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + delta_weight = self.get_delta_weight(active_adapter) + base_layer.weight.data -= delta_weight.to(orig_dtype) def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -289,7 +297,7 @@ def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: update_B = sliced_B.flatten(start_dim=1) update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) - # Since update_A is applied on the smallest dimension, test whether update_A or update_B should applied first. This is done to reduce trainable parameters. + # Since update_A is applied on the smallest dimension, test whether update_A or update_B should be applied first. This is done to reduce trainable parameters. if min_dim == self.in_features: return update_A, update_B return update_B.T, update_A.T diff --git a/src/peft/tuners/randlora/model.py b/src/peft/tuners/randlora/model.py index 5d8e2dfc2c..154ba321b0 100644 --- a/src/peft/tuners/randlora/model.py +++ b/src/peft/tuners/randlora/model.py @@ -188,6 +188,7 @@ def _init_randlora_A_randlora_B(self, config: RandLoraConfig, adapter_name: str) # deterministic init of randlora_A and randlora_B if we know the key generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key) + # The gamma matrix is applied on A meaning it can be unique (shared) accross the n scaling matrices. # We also set randlora_A as the smallest matrix to reduce trainable parameters. 
randlora_A = _kaiming_init((config.r, 1, min_dim), generator=generator) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index b1f0993e4c..2a798fdbfa 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -515,17 +515,28 @@ ######## # RandLora # ######## - ("Vanilla MLP 1 RandLora", "MLP", RandLoraConfig, {"target_modules": "lin0"}), - ("Vanilla MLP 2 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0"]}), - ("Vanilla MLP 3 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin1"]}), - ("Vanilla MLP 4 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"]}), - ("Vanilla MLP 5 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"], "sparse": True}), - ("Vanilla MLP 6 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"], "very_sparse": True}), + # We have to reduce the default scaling parameter to avoid nans when using large learning rates + ("Vanilla MLP 1 RandLora", "MLP", RandLoraConfig, {"target_modules": "lin0", "randlora_alpha": 64}), + ("Vanilla MLP 2 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0"], "randlora_alpha": 64}), + ("Vanilla MLP 3 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin1"], "randlora_alpha": 64}), + ("Vanilla MLP 4 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"], "randlora_alpha": 64}), + ( + "Vanilla MLP 5 RandLora", + "MLP", + RandLoraConfig, + {"target_modules": ["lin0", "lin1"], "sparse": True, "randlora_alpha": 64}, + ), + ( + "Vanilla MLP 6 RandLora", + "MLP", + RandLoraConfig, + {"target_modules": ["lin0", "lin1"], "very_sparse": True, "randlora_alpha": 64}, + ), ( "Vanilla MLP 7 RandLora", "MLP", RandLoraConfig, - {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}, + {"target_modules": ["lin0"], "modules_to_save": ["lin1"], "randlora_alpha": 64}, ), ] @@ -1465,7 +1476,7 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c lr = 0.1 # otherwise we get nan elif "mha" in model_id.lower(): lr = 1e-3 # we get exploding gradients with MHA when learning rate is too high - elif issubclass(config_cls, VBLoRAConfig): + elif issubclass(config_cls, VBLoRAConfig) or issubclass(config_cls, RandLoraConfig): lr = 0.01 # otherwise we get nan optimizer = torch.optim.SGD(model.parameters(), lr=lr) diff --git a/tests/testing_common.py b/tests/testing_common.py index ba3a88db22..8a1acc6daf 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -145,8 +145,8 @@ }, # RandLoRA { - "r": 10, - "randlora_alpha": 20, + "r": 32, + "randlora_alpha": 64, "target_modules": None, "randlora_dropout": 0.05, "projection_prng_key": 0xFF, From 95c8fb1cd96ef724c04c9945177c09f02237712c Mon Sep 17 00:00:00 2001 From: PaulAlbert31 Date: Sun, 20 Apr 2025 09:17:18 +0000 Subject: [PATCH 08/13] new tests for shared weights, temp multi-gpu fix, better var names and docstrings --- src/peft/tuners/randlora/bnb.py | 71 ++++-- src/peft/tuners/randlora/config.py | 26 +- src/peft/tuners/randlora/layer.py | 23 +- src/peft/tuners/randlora/model.py | 29 +-- tests/test_custom_models.py | 2 +- tests/test_feature_extraction_models.py | 4 +- tests/test_gpu_examples.py | 4 +- tests/test_initialization.py | 2 +- tests/test_randlora.py | 301 ++++++++++++++++++++++++ 9 files changed, 396 insertions(+), 66 deletions(-) create mode 100644 tests/test_randlora.py diff --git a/src/peft/tuners/randlora/bnb.py b/src/peft/tuners/randlora/bnb.py index 8e3be08229..84983643ef 100644 --- a/src/peft/tuners/randlora/bnb.py +++ 
b/src/peft/tuners/randlora/bnb.py @@ -59,11 +59,18 @@ def __init__( ) def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: - if self.merged: - warnings.warn( - f"Already following adapters were merged {','.join(self.merged_adapters)}. " - f"You are now additionally merging {','.join(self.active_adapters)}." - ) + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. + Defaults to `None`. + """ adapter_names = check_adapters_to_merge(self, adapter_names) if not adapter_names: @@ -98,6 +105,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N self.merged_adapters.append(active_adapter) def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ if not self.merged: warnings.warn("Already unmerged. Nothing to do") return @@ -124,7 +134,7 @@ def unmerge(self) -> None: ).to(weight.device) state.reset_grads() - def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]: + def get_scaled_bases(self, adapter, device=None) -> list[torch.Tensor, torch.Tensor]: """ Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct order to fit the target layers' dimensions @@ -137,7 +147,8 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]: randlora_A = self.randlora_A[adapter] randlora_B = self.randlora_B[adapter] - device = randlora_B.device + if device is None: + device = randlora_B.device dtype = randlora_B.dtype # In case users wants to merge the adapter weights that are in @@ -145,8 +156,8 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]: # (b)float16 because some CPUs have slow bf16/fp16 matmuls. cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) - randlora_lambda = self.randlora_lambda[adapter] - randlora_gamma = self.randlora_gamma[adapter] + randlora_lambda = self.randlora_lambda[adapter].to(device) + randlora_gamma = self.randlora_gamma[adapter].to(device) if cast_to_fp32: randlora_A = randlora_A.float() @@ -160,8 +171,8 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]: # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, # we initialize these matrices with the largest required size for each dimension. # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. 
- sliced_A = randlora_A[:, : self.num_bases, :min_dim] - sliced_B = randlora_B[:max_dim, : self.num_bases, :] + sliced_A = randlora_A[:, : self.num_bases, :min_dim].to(device) + sliced_B = randlora_B[:max_dim, : self.num_bases, :].to(device) # Flattening the matrices over the rank and number of bases dimensions is more memory efficient update_B = sliced_B.flatten(start_dim=1) @@ -216,7 +227,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: if active_adapter not in self.randlora_lambda.keys(): continue - update_B, update_A = self.get_scaled_bases(active_adapter) + update_B, update_A = self.get_scaled_bases(active_adapter, device=x.device) requires_conversion = not torch.is_autocast_enabled() if requires_conversion: @@ -275,11 +286,18 @@ def __init__( ) def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: - if self.merged: - warnings.warn( - f"Already following adapters were merged {','.join(self.merged_adapters)}. " - f"You are now additionally merging {','.join(self.active_adapters)}." - ) + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. + Defaults to `None`. + """ adapter_names = check_adapters_to_merge(self, adapter_names) if not adapter_names: @@ -309,6 +327,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N self.merged_adapters.append(active_adapter) def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ if not self.merged: warnings.warn("Already unmerged. Nothing to do") return @@ -330,7 +351,7 @@ def unmerge(self) -> None: weight.device ) - def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]: + def get_scaled_bases(self, adapter, device=None) -> list[torch.Tensor, torch.Tensor]: """ Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct order to fit the target layers' dimensions @@ -342,8 +363,8 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]: randlora_A = self.randlora_A[adapter] randlora_B = self.randlora_B[adapter] - - device = randlora_B.device + if device is None: + device = randlora_B.device dtype = randlora_B.dtype # In case users wants to merge the adapter weights that are in @@ -351,8 +372,8 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]: # (b)float16 because some CPUs have slow bf16/fp16 matmuls. cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) - randlora_lambda = self.randlora_lambda[adapter] - randlora_gamma = self.randlora_gamma[adapter] + randlora_lambda = self.randlora_lambda[adapter].to(device) + randlora_gamma = self.randlora_gamma[adapter].to(device) if cast_to_fp32: randlora_A = randlora_A.float() @@ -366,8 +387,8 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]: # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, # we initialize these matrices with the largest required size for each dimension. 
# During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. - sliced_A = randlora_A[:, : self.num_bases, :min_dim] - sliced_B = randlora_B[:max_dim, : self.num_bases, :] + sliced_A = randlora_A[:, : self.num_bases, :min_dim].to(device) + sliced_B = randlora_B[:max_dim, : self.num_bases, :].to(device) # Flattening the matrices over the rank and number of bases dimensions is more memory efficient update_B = sliced_B.flatten(start_dim=1) update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1) @@ -407,7 +428,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: if active_adapter not in self.randlora_lambda.keys(): continue - update_B, update_A = self.get_scaled_bases(active_adapter) + update_B, update_A = self.get_scaled_bases(active_adapter, device=x.device) requires_conversion = not torch.is_autocast_enabled() if requires_conversion: diff --git a/src/peft/tuners/randlora/config.py b/src/peft/tuners/randlora/config.py index 341251af7e..06a739fbe6 100644 --- a/src/peft/tuners/randlora/config.py +++ b/src/peft/tuners/randlora/config.py @@ -29,8 +29,8 @@ class RandLoraConfig(PeftConfig): Args: r (`int`, *optional*, defaults to `32`): - RandLora's random basis rank dimension. Contrary to Lora, this parameter is inversely proportional to the amount of trainable - parameters as reducing it increases trainable parameters. + RandLora's random basis rank dimension. Contrary to Lora, this parameter is inversely proportional to the + amount of trainable parameters as reducing it increases trainable parameters. target_modules (`Union[list[str], str]`): The names of the modules to apply RandLora to. Only linear layers are supported. projection_prng_key (`int`): @@ -38,17 +38,21 @@ class RandLoraConfig(PeftConfig): checkpoint that did not include these projections. Defaults to `0`. save_projection (`bool`): Whether to save the global basis_A / basis_B random basis in the state dict alongside per layer lambda / - gamma diagonal matrices. This will increase the size of the checkpoint, but guarantee that we can - reload the checkpoint on all system configurations. Defaults to `True`. + gamma diagonal matrices. This will increase the size of the checkpoint, but guarantee that we can reload + the checkpoint on all system configurations. Defaults to `True`. sparse (`bool`): - Whether to use sparse random bases as described in the RandLora paper. The bases are ternary sparse bases (only containing -1, 0 and 1) where the attribution probability is 1/6 for -1 and 1 and 2/3 for 0. - These sparse matrices aim to be used for matmul free computation in the future, see https://arxiv.org/pdf/2406.02528v1 - The current implementation is a proof of concept however where the sparseness is not used to improve speed or memory usage. Using sparse matrices typically does not reduce performance and can even help reduce overfitting. - Defaults to `False`. + Whether to use sparse random bases as described in the RandLora paper. The bases are ternary sparse bases + (only containing -1, 0 and 1) where the attribution probability is 1/6 for -1 and 1 and 2/3 for 0. These + sparse matrices aim to be used for matmul free computation in the future, see + https://arxiv.org/pdf/2406.02528v1 The current implementation is a proof of concept however where the + sparseness is not used to improve speed or memory usage. Using sparse matrices typically does not reduce + performance and can even help reduce overfitting. Defaults to `False`. 
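(Illustrative aside, not part of the patch: the ternary sampling described for `sparse` above can be reproduced in isolation. A uniform draw below 1/(2s) becomes -1 and one above 1 - 1/(2s) becomes +1, which is the thresholding that `_init_randlora_A_randlora_B_sparse` applies later in this series, with s=3 for `sparse` and s=sqrt(min_dim) for `very_sparse`; the tensor shape below is arbitrary.)

```python
import torch

s = 3  # sparsity factor: P(-1) = P(+1) = 1/(2*s), P(0) = 1 - 1/s (s=3 -> 1/6, 1/6, 2/3)
u = torch.rand(256, 128)  # uniform draws standing in for a random basis (arbitrary shape)

ternary = torch.zeros_like(u)
ternary[u < 1 / (2 * s)] = -1.0      # lowest 1/(2s) quantile -> -1
ternary[u > 1 - 1 / (2 * s)] = 1.0   # highest 1/(2s) quantile -> +1

# the patch then rescales the sparse basis by its standard deviation ("std normalization")
ternary = ternary / ternary.std()
```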
very_sparse (`bool`): - Whether to use highly sparse random bases as described in the RandLora paper. The very sparse bases are ternary sparse bases (only containing -1, 0 and 1) given a matrix with smallest dimension d, the attribution probability is 1/√D for -1 and 1 and 1- 2/√D for 0. - Using these sparse matrices can further reduce overfitting over the `sparse` alternatives but will most likely decrease performance as a results. Use carefully. - Defaults to `False`. + Whether to use highly sparse random bases as described in the RandLora paper. The very sparse bases are + ternary sparse bases (only containing -1, 0 and 1) given a matrix with smallest dimension d, the + attribution probability is 1/√D for -1 and 1 and 1- 2/√D for 0. Using these sparse matrices can further + reduce overfitting over the `sparse` alternatives but will most likely decrease performance as a results. + Use carefully. Defaults to `False`. randlora_dropout (`float`): The dropout probability for RandLora layers. randlora_alpha (`float`): diff --git a/src/peft/tuners/randlora/layer.py b/src/peft/tuners/randlora/layer.py index 7719aa6e59..40e5aeab6b 100644 --- a/src/peft/tuners/randlora/layer.py +++ b/src/peft/tuners/randlora/layer.py @@ -30,9 +30,9 @@ class UniqueBaseGrad(torch.autograd.Function): # Memory efficent for a unique base @staticmethod def forward(ctx, randlora_A, randlora_lambda, randlora_gamma): - Out = randlora_lambda[:, :, None] * randlora_A * randlora_gamma[None,] + out = randlora_lambda[:, :, None] * randlora_A * randlora_gamma[None,] ctx.save_for_backward(randlora_A, randlora_lambda, randlora_gamma) - return Out + return out @staticmethod def backward(ctx, grad_output): @@ -242,6 +242,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N self.merged_adapters.append(active_adapter) def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return @@ -254,7 +257,7 @@ def unmerge(self) -> None: delta_weight = self.get_delta_weight(active_adapter) base_layer.weight.data -= delta_weight.to(orig_dtype) - def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: + def get_scaled_bases(self, adapter, device=None) -> tuple[torch.Tensor, torch.Tensor]: """ Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct order to fit the target layers' dimensions @@ -266,8 +269,8 @@ def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: randlora_A = self.randlora_A[adapter] randlora_B = self.randlora_B[adapter] - - device = randlora_B.device + if device is None: + device = randlora_B.device dtype = randlora_B.dtype # In case users wants to merge the adapter weights that are in @@ -275,8 +278,8 @@ def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: # (b)float16 because some CPUs have slow bf16/fp16 matmuls. 
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) - randlora_lambda = self.randlora_lambda[adapter] - randlora_gamma = self.randlora_gamma[adapter] + randlora_lambda = self.randlora_lambda[adapter].to(device) + randlora_gamma = self.randlora_gamma[adapter].to(device) if cast_to_fp32: randlora_A = randlora_A.float() @@ -290,8 +293,8 @@ def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices, # we initialize these matrices with the largest required size for each dimension. # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B. - sliced_A = randlora_A[:, : self.num_bases, :min_dim] - sliced_B = randlora_B[:max_dim, : self.num_bases, :] + sliced_A = randlora_A[:, : self.num_bases, :min_dim].to(device) + sliced_B = randlora_B[:max_dim, : self.num_bases, :].to(device) # Flattening the matrices over the rank and number of bases dimensions is more memory efficient update_B = sliced_B.flatten(start_dim=1) @@ -334,7 +337,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: if active_adapter not in self.randlora_lambda.keys(): continue dropout = self.randlora_dropout[active_adapter] - update_B, update_A = self.get_scaled_bases(active_adapter) + update_B, update_A = self.get_scaled_bases(active_adapter, device=x.device) x = x.to(update_A.dtype) scaling = self.scaling[active_adapter] result = result + F.linear(F.linear(dropout(x), update_B), update_A) * scaling diff --git a/src/peft/tuners/randlora/model.py b/src/peft/tuners/randlora/model.py index 154ba321b0..1373d581fd 100644 --- a/src/peft/tuners/randlora/model.py +++ b/src/peft/tuners/randlora/model.py @@ -137,7 +137,7 @@ def _find_dim(self, config) -> tuple[int, int]: return largest_shape - def _init_randlora_A_randlora_B_sparse(self, config: RandLoraConfig, adapter_name: str, s: int = 3) -> None: + def _init_randlora_A_randlora_B_sparse(self, config: RandLoraConfig, adapter_name: str, sparsity: int = 3) -> None: """ Sparse random projections as described in https://cs-people.bu.edu/evimaria/cs565/kdd-rp.pdf """ @@ -156,19 +156,19 @@ def _init_randlora_A_randlora_B_sparse(self, config: RandLoraConfig, adapter_nam # We also set randlora_A as the smallest matrix to reduce trainable parameters. 
randlora_A = torch.rand((config.r, 1, min_dim), generator=generator) - # Ensure full rank - n = min_dim / config.r - n = int(n) if n.is_integer() else int(n) + 1 # Ensure full rank - randlora_B = torch.rand((max_dim, n, config.r), generator=generator) + # Number of bases to ensure full rank + num_bases = min_dim / config.r + num_bases = int(n_bases) if num_bases.is_integer() else int(num_bases) + 1 # Ensure full rank + randlora_B = torch.rand((max_dim, num_bases, config.r), generator=generator) # The current implementation is a proof of concept and does take into consideration # the sparsity to reduce memory usage or speed up compute randlora_B_sparse = torch.zeros(randlora_B.shape) randlora_A_sparse = torch.zeros(randlora_A.shape) - randlora_B_sparse[randlora_B < 1 / (2 * s)] = -1 - randlora_B_sparse[randlora_B > 1 - 1 / (2 * s)] = 1 - randlora_A_sparse[randlora_A < 1 / (2 * s)] = -1 - randlora_A_sparse[randlora_A > 1 - 1 / (2 * s)] = 1 + randlora_B_sparse[randlora_B < 1 / (2 * sparsity)] = -1 + randlora_B_sparse[randlora_B > 1 - 1 / (2 * sparsity)] = 1 + randlora_A_sparse[randlora_A < 1 / (2 * sparsity)] = -1 + randlora_A_sparse[randlora_A > 1 - 1 / (2 * sparsity)] = 1 # Std normalization is empirically found to be the best randlora_A, randlora_B = ( @@ -194,9 +194,9 @@ def _init_randlora_A_randlora_B(self, config: RandLoraConfig, adapter_name: str) randlora_A = _kaiming_init((config.r, 1, min_dim), generator=generator) # Ensure full rank - n = min(linear_out_dim, linear_in_dim) / config.r - n = int(n) if n.is_integer() else int(n) + 1 - randlora_B = torch.cat([_kaiming_init((max_dim, 1, config.r), generator=generator) for _ in range(n)], dim=1) + num_bases = min(linear_out_dim, linear_in_dim) / config.r + num_bases = int(num_bases) if num_bases.is_integer() else int(num_bases) + 1 + randlora_B = torch.cat([_kaiming_init((max_dim, 1, config.r), generator=generator) for _ in range(num_bases)], dim=1) # Std normalization is empirically found to be the best randlora_A, randlora_B = randlora_A / randlora_A.std(), randlora_B / randlora_B.std() @@ -207,10 +207,10 @@ def _pre_injection_hook(self, model: nn.Module, config: RandLoraConfig, adapter_ if config.very_sparse: linear_out_dim, linear_in_dim = self._find_dim(config) self._init_randlora_A_randlora_B_sparse( - config, adapter_name, s=math.sqrt(min(linear_out_dim, linear_in_dim)) + config, adapter_name, sparsity=math.sqrt(min(linear_out_dim, linear_in_dim)) ) elif config.sparse: - self._init_randlora_A_randlora_B_sparse(config, adapter_name, s=3) + self._init_randlora_A_randlora_B_sparse(config, adapter_name, sparsity=3) else: self._init_randlora_A_randlora_B(config, adapter_name) @@ -521,6 +521,7 @@ def delete_adapter(self, adapter_name: str): new_adapter = target.active_adapter[:] self.active_adapter = new_adapter or [] + self._delete_auxiliary_adapter(adapter_name, new_active_adapters=new_adapter) def merge_and_unload( self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 2a798fdbfa..2cfeb715dd 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -3816,7 +3816,7 @@ def forward(self, X): # active adapter is still "default" self.check_requires_grad( peft_model, - "nobase_model.model.lin1.vera_lambda_b.default", + "base_model.model.lin1.vera_lambda_b.default", "base_model.model.lin1.vera_lambda_d.default", ) diff --git a/tests/test_feature_extraction_models.py 
b/tests/test_feature_extraction_models.py index 06fe3779d5..2a4c5b88bd 100644 --- a/tests/test_feature_extraction_models.py +++ b/tests/test_feature_extraction_models.py @@ -44,8 +44,8 @@ def skip_non_prompt_tuning(test_list): def skip_deberta_lora_tests(test_list): r""" - Skip tests that are checkpointing with lora/ia3/boft/vera/randlora/fourierft for Deberta models (couldn't find much info on - the error) + Skip tests that are checkpointing with lora/ia3/boft/vera/randlora/fourierft for Deberta models (couldn't find much + info on the error) """ to_skip = ["lora", "ia3", "boft", "vera", "fourierft", "hra", "bone", "randlora"] return [test for test in test_list if not (any(k in test[0] for k in to_skip) and "Deberta" in test[0])] diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index bf7373a2a4..4e81e2c6a8 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -1393,7 +1393,7 @@ def test_causal_lm_training_multi_gpu_4bit_vera(self): assert trainer.state.log_history[-1]["train_loss"] is not None @pytest.mark.single_gpu_tests - def test_causal_lm_training_randlora(self): + def test_causal_lm_training_8bit_randlora(self): r""" Same as test_causal_lm_training but with RandLora """ @@ -1501,7 +1501,7 @@ def test_causal_lm_training_4bit_randlora(self): assert trainer.state.log_history[-1]["train_loss"] is not None @pytest.mark.multi_gpu_tests - def test_causal_lm_training_multi_gpu_RandLora(self): + def test_causal_lm_training_multi_gpu_8bit_randlora(self): r""" Same as test_causal_lm_training_multi_gpu but with RandLoRA """ diff --git a/tests/test_initialization.py b/tests/test_initialization.py index ad7253993e..a4f86792e8 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -2661,7 +2661,7 @@ def fn(x, *args): if prepare_layer_inputs_keys is None: prepare_layer_inputs_fn = fn else: - prepare_layer_inputs_fn = dict.fromkeys(prepare_layer_inputs_keys, fn) + prepare_layer_inputs_fn = {k: fn for k in prepare_layer_inputs_keys} shuffled_dataset = dataset.shuffle(seed=0) dataloader = self.get_dataloader(shuffled_dataset) diff --git a/tests/test_randlora.py b/tests/test_randlora.py new file mode 100644 index 0000000000..03017d0bd7 --- /dev/null +++ b/tests/test_randlora.py @@ -0,0 +1,301 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This test file is for tests specific to RandLora, since Randlora has some specific challenges due to the shared weights. 
+# These tests are copied from the test_vera.py file + +import os + +import pytest +import torch +from safetensors import safe_open +from torch import nn + +from peft import PeftModel, RandLoraConfig, get_peft_model +from peft.utils import infer_device + + +class MLP(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.relu = nn.ReLU() + self.lin0 = nn.Linear(10, 20, bias=bias) + self.lin1 = nn.Linear(20, 20, bias=bias) # lin1 and lin2 have same shape + self.lin2 = nn.Linear(20, 20, bias=bias) + self.lin3 = nn.Linear(20, 2, bias=bias) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = self.lin0(X) + X = self.relu(X) + X = self.lin1(X) + X = self.relu(X) + X = self.lin2(X) + X = self.relu(X) + X = self.lin3(X) + X = self.sm(X) + return X + + +class TestRandLora: + @pytest.fixture + def mlp(self): + torch.manual_seed(0) + model = MLP() + return model + + @pytest.fixture + def mlp_same_prng(self, mlp): + torch.manual_seed(0) + + config = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False) + # creates a default RandLora adapter + peft_model = get_peft_model(mlp, config) + config2 = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False) + peft_model.add_adapter("other", config2) + return peft_model + + def test_multiple_adapters_same_prng_weights(self, mlp_same_prng): + # we can have multiple adapters with the same prng key, in which case the weights should be shared + assert ( + mlp_same_prng.base_model.model.lin1.randlora_A["default"] + is mlp_same_prng.base_model.model.lin1.randlora_A["other"] + ) + assert ( + mlp_same_prng.base_model.model.lin1.randlora_B["default"] + is mlp_same_prng.base_model.model.lin1.randlora_B["other"] + ) + assert ( + mlp_same_prng.base_model.model.lin2.randlora_A["default"] + is mlp_same_prng.base_model.model.lin2.randlora_A["other"] + ) + assert ( + mlp_same_prng.base_model.model.lin2.randlora_B["default"] + is mlp_same_prng.base_model.model.lin2.randlora_B["other"] + ) + + input = torch.randn(5, 10) + mlp_same_prng.set_adapter("default") + output_default = mlp_same_prng(input) + mlp_same_prng.set_adapter("other") + output_other = mlp_same_prng(input) + assert not torch.allclose(output_default, output_other, atol=1e-3, rtol=1e-3) + + def test_multiple_adapters_different_prng_raises(self): + # we cannot have multiple adapters with different prng keys + model = MLP() + config = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False) + # creates a default RandLora adapter + peft_model = get_peft_model(model, config) + config2 = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False, projection_prng_key=123) + + msg = ( + r"RandLora PRNG initialisation key must be the same for all adapters. 
Got config.projection_prng_key=123 but " + r"previous config had 0" + ) + with pytest.raises(ValueError, match=msg): + peft_model.add_adapter("other", config2) + + def test_multiple_adapters_save_load_save_projection_true(self, mlp_same_prng, tmp_path): + # check saving and loading works with multiple adapters and saved projection weights + torch.manual_seed(0) + input = torch.randn(5, 10) + mlp_same_prng.set_adapter("default") + output_default = mlp_same_prng(input) + mlp_same_prng.set_adapter("other") + output_other = mlp_same_prng(input) + + # sanity check + assert not torch.allclose(output_default, output_other, atol=1e-3, rtol=1e-3) + + save_path = tmp_path / "randlora" + mlp_same_prng.save_pretrained(save_path) + assert os.path.exists(save_path / "adapter_config.json") + assert os.path.exists(save_path / "other" / "adapter_config.json") + + torch.manual_seed(0) + mlp = MLP() + peft_model = PeftModel.from_pretrained(mlp, save_path) + peft_model.load_adapter(save_path / "other", "other") + + peft_model.set_adapter("default") + output_default_loaded = peft_model(input) + peft_model.set_adapter("other") + output_other_loaded = peft_model(input) + + assert torch.allclose(output_default, output_default_loaded, atol=1e-3, rtol=1e-3) + assert torch.allclose(output_other, output_other_loaded, atol=1e-3, rtol=1e-3) + + def test_multiple_adapters_save_load_save_projection_false(self, mlp, tmp_path): + # check saving and loading works with multiple adapters without saved projection weights + torch.manual_seed(1) + config = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False, save_projection=False) + # creates a default RandLora adapter + peft_model = get_peft_model(mlp, config, adapter_name="first") + config2 = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False, save_projection=False) + peft_model.add_adapter("second", config2) + + input = torch.randn(5, 10) + peft_model.set_adapter("first") + output_first = peft_model(input) + peft_model.set_adapter("second") + output_second = peft_model(input) + + # sanity check + assert not torch.allclose(output_first, output_second, atol=1e-3, rtol=1e-3) + + save_path = tmp_path / "randlora" + peft_model.save_pretrained(save_path) + assert os.path.exists(save_path / "first" / "adapter_config.json") + assert os.path.exists(save_path / "second" / "adapter_config.json") + + torch.manual_seed(0) + mlp = MLP() + peft_model = PeftModel.from_pretrained(mlp, save_path / "first", adapter_name="first") + peft_model.load_adapter(save_path / "second", "second") + + peft_model.set_adapter("first") + output_first_loaded = peft_model(input) + peft_model.set_adapter("second") + output_second_loaded = peft_model(input) + + assert torch.allclose(output_first, output_first_loaded, atol=1e-3, rtol=1e-3) + assert torch.allclose(output_second, output_second_loaded, atol=1e-3, rtol=1e-3) + + def test_multiple_adapters_save_projection_true_contains_randlora_A_randlora_B(self, mlp_same_prng, tmp_path): + # check that the state_dicts don't contain the projection weights + save_path = tmp_path / "randlora" + mlp_same_prng.save_pretrained(save_path) + + sd_default = {} + with safe_open(save_path / "adapter_model.safetensors", framework="pt", device="cpu") as f: + for key in f.keys(): + sd_default[key] = f.get_tensor(key) + + assert any("randlora_A" in key for key in sd_default) + assert any("randlora_B" in key for key in sd_default) + # default rank for RandLora is 32 + assert sd_default["base_model.randlora_A"].shape == (32, 1, 20) + assert 
sd_default["base_model.randlora_B"].shape == (20, 1, 32) + + sd_other = {} + with safe_open(save_path / "other" / "adapter_model.safetensors", framework="pt", device="cpu") as f: + for key in f.keys(): + sd_other[key] = f.get_tensor(key) + + assert any("randlora_A" in key for key in sd_other) + assert any("randlora_B" in key for key in sd_other) + assert sd_other["base_model.randlora_A"].shape == (32, 1, 20) + assert sd_other["base_model.randlora_B"].shape == (20, 1, 32) + + def test_multiple_adapters_save_projection_false_contains_no_randlora_A_randlora_B(self, mlp, tmp_path): + torch.manual_seed(1) + config = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False, save_projection=False) + # creates a default RandLora adapter + peft_model = get_peft_model(mlp, config, adapter_name="first") + config2 = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False, save_projection=False) + peft_model.add_adapter("second", config2) + + save_path = tmp_path / "randlora" + peft_model.save_pretrained(save_path) + + sd_default = {} + with safe_open(save_path / "first" / "adapter_model.safetensors", framework="pt", device="cpu") as f: + for key in f.keys(): + sd_default[key] = f.get_tensor(key) + + assert not any("randlora_A" in key for key in sd_default) + assert not any("randlora_B" in key for key in sd_default) + + sd_other = {} + with safe_open(save_path / "second" / "adapter_model.safetensors", framework="pt", device="cpu") as f: + for key in f.keys(): + sd_other[key] = f.get_tensor(key) + + assert not any("randlora_A" in key for key in sd_other) + assert not any("randlora_B" in key for key in sd_other) + + def test_randlora_A_randlora_B_share_memory(self, mlp_same_prng): + randlora_A = mlp_same_prng.randlora_A["default"] + randlora_B = mlp_same_prng.randlora_B["default"] + + # these tensors should share the same data + assert randlora_A.data_ptr() == mlp_same_prng.base_model.model.lin1.randlora_A["default"].data_ptr() + assert randlora_B.data_ptr() == mlp_same_prng.base_model.model.lin1.randlora_B["default"].data_ptr() + assert randlora_A.data_ptr() == mlp_same_prng.base_model.model.lin2.randlora_A["default"].data_ptr() + assert randlora_B.data_ptr() == mlp_same_prng.base_model.model.lin2.randlora_B["default"].data_ptr() + # sanity check: these tensors shouldn't share the same data + assert randlora_A.data_ptr() != randlora_B.data_ptr() + + def test_randlora_lambda_dont_share_memory(self, mlp_same_prng): + # sanity check: these tensors shouldn't share the same data + assert ( + mlp_same_prng.base_model.model.lin1.randlora_lambda["default"].data_ptr() + != mlp_same_prng.base_model.model.lin1.randlora_lambda["other"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.randlora_lambda["default"].data_ptr() + != mlp_same_prng.base_model.model.lin2.randlora_lambda["default"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.randlora_lambda["other"].data_ptr() + != mlp_same_prng.base_model.model.lin2.randlora_lambda["other"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.randlora_gamma["default"].data_ptr() + != mlp_same_prng.base_model.model.lin1.randlora_gamma["other"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.randlora_gamma["default"].data_ptr() + != mlp_same_prng.base_model.model.lin2.randlora_gamma["default"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.randlora_gamma["other"].data_ptr() + != mlp_same_prng.base_model.model.lin2.randlora_gamma["other"].data_ptr() + ) + + def 
test_randlora_different_shapes(self, mlp): + config = RandLoraConfig(target_modules=["lin0", "lin3"], init_weights=False) + mlp_different_shapes = get_peft_model(mlp, config) + + randlora_A = mlp_different_shapes.randlora_A["default"] + randlora_B = mlp_different_shapes.randlora_B["default"] + + # sanity check + assert mlp.lin0.base_layer.weight.shape != mlp.lin3.base_layer.weight.shape + + # lin0 has the largest output dimension, lin3 has the largest input dimension + # randlora_A should have the shape of (rank, largest_in), randlora_B should have the shape of (largest_out, rank) + assert randlora_A.shape == (config.r, 1, mlp.lin3.in_features) + assert randlora_B.shape == (mlp.lin0.out_features, 1, config.r) + + # should not raise + input = torch.randn(5, 10) + mlp_different_shapes(input) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_randlora_dtypes(self, dtype): + if dtype == torch.bfloat16: + # skip if bf16 is not supported on hardware, see #1872 + is_xpu = infer_device() == "xpu" + is_cuda_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported() + if not (is_xpu or is_cuda_bf16): + pytest.skip("bfloat16 not supported on this system, skipping the test") + + model = MLP().to(dtype) + config = RandLoraConfig(target_modules=["lin1", "lin2"], init_weights=False) + peft_model = get_peft_model(model, config) + inputs = torch.randn(5, 10).to(dtype) + output = peft_model(inputs) # should not raise + assert output.dtype == dtype From 6c6d4414651519c9e559caa74635381b56eef0f3 Mon Sep 17 00:00:00 2001 From: Paul Albert Date: Wed, 23 Apr 2025 08:52:02 +1000 Subject: [PATCH 09/13] Update src/peft/tuners/randlora/model.py Bug fixing Co-authored-by: Benjamin Bossan --- src/peft/tuners/randlora/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/randlora/model.py b/src/peft/tuners/randlora/model.py index 1373d581fd..21f1740f71 100644 --- a/src/peft/tuners/randlora/model.py +++ b/src/peft/tuners/randlora/model.py @@ -158,7 +158,7 @@ def _init_randlora_A_randlora_B_sparse(self, config: RandLoraConfig, adapter_nam # Number of bases to ensure full rank num_bases = min_dim / config.r - num_bases = int(n_bases) if num_bases.is_integer() else int(num_bases) + 1 # Ensure full rank + num_bases = int(num_bases) if num_bases.is_integer() else int(num_bases) + 1 # Ensure full rank randlora_B = torch.rand((max_dim, num_bases, config.r), generator=generator) # The current implementation is a proof of concept and does take into consideration From 409f64760381bb3428e58f77125a16548101ea12 Mon Sep 17 00:00:00 2001 From: PaulAlbert31 Date: Tue, 22 Apr 2025 22:53:34 +0000 Subject: [PATCH 10/13] comment on TesRandLora --- tests/test_randlora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_randlora.py b/tests/test_randlora.py index 03017d0bd7..a3b208ba57 100644 --- a/tests/test_randlora.py +++ b/tests/test_randlora.py @@ -47,7 +47,8 @@ def forward(self, X): X = self.sm(X) return X - +# Tests copied from the TestVera class in test_vera.py. +# Changes to the code file should be reflected here. 
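(Illustrative aside, not part of the patch: the `num_bases` arithmetic fixed in PATCH 09 above is a ceiling division, so enough rank-r bases are stacked for num_bases * r to cover the smallest adapted dimension. This is also why a smaller r increases the number of bases and, as the config docstring notes, the number of trainable parameters. The values below are hypothetical.)

```python
import math

r = 32         # RandLoraConfig.r, the rank of each shared random base
min_dim = 768  # hypothetical smallest dimension across the adapted layers

# Equivalent to: n = min_dim / r; n = int(n) if n.is_integer() else int(n) + 1
num_bases = math.ceil(min_dim / r)

assert num_bases * r >= min_dim  # stacked bases span the full smaller dimension ("full rank")
print(num_bases)  # 24 for these example values; halving r to 16 would double it to 48
```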
class TestRandLora: @pytest.fixture def mlp(self): From 9d280aff423425e5b226c4398f00e1fe76c844a7 Mon Sep 17 00:00:00 2001 From: PaulAlbert31 Date: Wed, 23 Apr 2025 23:25:44 +0000 Subject: [PATCH 11/13] make style --- src/peft/tuners/randlora/model.py | 4 +++- tests/test_randlora.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/randlora/model.py b/src/peft/tuners/randlora/model.py index 21f1740f71..f0d7498d79 100644 --- a/src/peft/tuners/randlora/model.py +++ b/src/peft/tuners/randlora/model.py @@ -196,7 +196,9 @@ def _init_randlora_A_randlora_B(self, config: RandLoraConfig, adapter_name: str) # Ensure full rank num_bases = min(linear_out_dim, linear_in_dim) / config.r num_bases = int(num_bases) if num_bases.is_integer() else int(num_bases) + 1 - randlora_B = torch.cat([_kaiming_init((max_dim, 1, config.r), generator=generator) for _ in range(num_bases)], dim=1) + randlora_B = torch.cat( + [_kaiming_init((max_dim, 1, config.r), generator=generator) for _ in range(num_bases)], dim=1 + ) # Std normalization is empirically found to be the best randlora_A, randlora_B = randlora_A / randlora_A.std(), randlora_B / randlora_B.std() diff --git a/tests/test_randlora.py b/tests/test_randlora.py index a3b208ba57..b553177ea9 100644 --- a/tests/test_randlora.py +++ b/tests/test_randlora.py @@ -47,7 +47,8 @@ def forward(self, X): X = self.sm(X) return X -# Tests copied from the TestVera class in test_vera.py. + +# Tests copied from the TestVera class in test_vera.py. # Changes to the code file should be reflected here. class TestRandLora: @pytest.fixture From 321fb5c95a59ee294b1eead395986c25ee1c011d Mon Sep 17 00:00:00 2001 From: PaulAlbert31 Date: Thu, 24 Apr 2025 10:45:37 +0000 Subject: [PATCH 12/13] reduce randlora_alpha in tests --- tests/test_custom_models.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 2cfeb715dd..ebe3f09847 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -516,27 +516,27 @@ # RandLora # ######## # We have to reduce the default scaling parameter to avoid nans when using large learning rates - ("Vanilla MLP 1 RandLora", "MLP", RandLoraConfig, {"target_modules": "lin0", "randlora_alpha": 64}), - ("Vanilla MLP 2 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0"], "randlora_alpha": 64}), - ("Vanilla MLP 3 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin1"], "randlora_alpha": 64}), - ("Vanilla MLP 4 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"], "randlora_alpha": 64}), + ("Vanilla MLP 1 RandLora", "MLP", RandLoraConfig, {"target_modules": "lin0", "randlora_alpha": 1}), + ("Vanilla MLP 2 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0"], "randlora_alpha": 1}), + ("Vanilla MLP 3 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin1"], "randlora_alpha": 1}), + ("Vanilla MLP 4 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"], "randlora_alpha": 1}), ( "Vanilla MLP 5 RandLora", "MLP", RandLoraConfig, - {"target_modules": ["lin0", "lin1"], "sparse": True, "randlora_alpha": 64}, + {"target_modules": ["lin0", "lin1"], "sparse": True, "randlora_alpha": 1}, ), ( "Vanilla MLP 6 RandLora", "MLP", RandLoraConfig, - {"target_modules": ["lin0", "lin1"], "very_sparse": True, "randlora_alpha": 64}, + {"target_modules": ["lin0", "lin1"], "very_sparse": True, "randlora_alpha": 1}, ), ( "Vanilla MLP 7 RandLora", "MLP", RandLoraConfig, - {"target_modules": 
["lin0"], "modules_to_save": ["lin1"], "randlora_alpha": 64}, + {"target_modules": ["lin0"], "modules_to_save": ["lin1"], "randlora_alpha": 1}, ), ] From 43093c24bdf32ae3b71a36e512830c092b9fbd73 Mon Sep 17 00:00:00 2001 From: PaulAlbert31 Date: Fri, 25 Apr 2025 01:08:21 +0000 Subject: [PATCH 13/13] add instability warining for larger randlora_alpha values --- src/peft/tuners/randlora/config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/peft/tuners/randlora/config.py b/src/peft/tuners/randlora/config.py index 06a739fbe6..2eac42bf08 100644 --- a/src/peft/tuners/randlora/config.py +++ b/src/peft/tuners/randlora/config.py @@ -56,7 +56,10 @@ class RandLoraConfig(PeftConfig): randlora_dropout (`float`): The dropout probability for RandLora layers. randlora_alpha (`float`): - The scaling coefficient for RandLora layers, this would typically be 20 times the rank. + The scaling coefficient for RandLora layers, this would typically be 20 times the rank. Because the + `randlora_alpha` coefficient is large by default, it can lead to numerical instabilities especially when + learning rates are high. If training is unstable, consider reducing the learning rate or the + `randlora_alpha` coefficient. fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.