
Commit 7942a9c

better hyper-parameters and conformity with new dtype casts

1 parent a1f0539

File tree

7 files changed (+62, -39 lines)


src/peft/tuners/randlora/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,4 +1,5 @@
 # Copyright 2025-present the HuggingFace Inc. team.
+
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```

src/peft/tuners/randlora/bnb.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -162,6 +162,7 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]:
         # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B.
         sliced_A = randlora_A[:, : self.num_bases, :min_dim]
         sliced_B = randlora_B[:max_dim, : self.num_bases, :]
+
         # Flattening the matrices over the rank and number of bases dimensions is more memory efficient
         update_B = sliced_B.flatten(start_dim=1)
         update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1)
@@ -216,6 +217,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 continue

             update_B, update_A = self.get_scaled_bases(active_adapter)
+
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
@@ -382,7 +384,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
             adapter (str):
                 The name of the adapter for which the delta weight should be computed.
         """
-
         update_B, update_A = self.get_scaled_bases(adapter)

         update = update_B @ update_A
@@ -405,7 +406,9 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         for active_adapter in self.active_adapters:
             if active_adapter not in self.randlora_lambda.keys():
                 continue
+
             update_B, update_A = self.get_scaled_bases(active_adapter)
+
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
```
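The two `forward` hunks above share a dtype-conversion pattern: when autocast is not active, the low-rank update is computed in the dtype of the random bases and cast back to the dtype of the quantized base layer's output before being added. A minimal, self-contained sketch of that pattern follows; the toy tensors, shapes, and the identity matrix standing in for the quantized base layer are assumptions for illustration, not PEFT code.

```python
import torch

x = torch.randn(2, 8, dtype=torch.float16)         # input in half precision
update_A = torch.randn(8, 4, dtype=torch.float32)  # stand-ins for the scaled random bases
update_B = torch.randn(4, 8, dtype=torch.float32)
scaling = 640 / 32                                  # randlora_alpha / r

result = x @ torch.eye(8, dtype=x.dtype)            # stand-in for the quantized base layer output

requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
    expected_dtype = result.dtype
    x = x.to(update_A.dtype)                        # compute the update in the bases' dtype

output = (x @ update_A @ update_B) * scaling
if requires_conversion:
    output = output.to(expected_dtype)              # cast back before adding to the result
result = result + output
```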

src/peft/tuners/randlora/config.py

Lines changed: 15 additions & 16 deletions
```diff
@@ -14,7 +14,7 @@

 import warnings
 from dataclasses import dataclass, field
-from typing import List, Optional, Union
+from typing import Optional, Union

 from peft.config import PeftConfig
 from peft.utils import PeftType
@@ -28,10 +28,10 @@ class RandLoraConfig(PeftConfig):
     Paper: https://arxiv.org/pdf/2502.00987.

     Args:
-        r (`int`, *optional*, defaults to `10`):
+        r (`int`, *optional*, defaults to `32`):
             RandLora's random basis rank dimension. Contrary to Lora, this parameter is inversely proportional to the amount of trainable
             parameters as reducing it increases trainable parameters.
-        target_modules (`Union[List[str], str]`):
+        target_modules (`Union[list[str], str]`):
             The names of the modules to apply RandLora to. Only linear layers are supported.
         projection_prng_key (`int`):
             RandLora PRNG init key. Used for initialising basis_A and basis_B for new models or when loading a
@@ -52,21 +52,20 @@ class RandLoraConfig(PeftConfig):
         randlora_dropout (`float`):
             The dropout probability for RandLora layers.
         randlora_alpha (`float`):
-            The scaling coefficient for RandLora layers, this would be typically be the same as LoRA, e.g. 2 times the
-            rank.
+            The scaling coefficient for RandLora layers, this would typically be 20 times the rank.
         fan_in_fan_out (`bool`):
             Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses
             `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.
         bias (`str`):
             Bias type. Can be 'none', 'all' or 'randlora_only'. If 'all' or 'randlora_only', the corresponding biases
             will be updated during training. Be aware that this means that, even when disabling the adapters, the model
             will not produce the same output as the base model would have without adaptation.
-        modules_to_save (`List[str]`):
-            List of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint.
+        modules_to_save (`list[str]`):
+            list of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint.
         init_weights (`bool`):
             Whether to initialize the weights of the RandLora layers with their default initialization. Don't change
             this setting, except if you know exactly what you're doing.
-        layers_to_transform (`Union[List[int],int]`):
+        layers_to_transform (`Union[list[int],int]`):
            The layer indexes to transform, if this argument is specified, it will apply the RandLora transformations
            on the layer indexes that are specified in this list. If a single integer is passed, it will apply the
            RandLora transformations on the layer at this index.
@@ -75,13 +74,13 @@ class RandLoraConfig(PeftConfig):
         pattern is not in the common layers pattern.
     """

-    r: int = field(default=10, metadata={"help": "RandLora random basis rank"})
+    r: int = field(default=32, metadata={"help": "RandLora random basis rank"})

-    target_modules: Optional[Union[List[str], str]] = field(
+    target_modules: Optional[Union[list[str], str]] = field(
         default=None,
         metadata={
             "help": (
-                "List of module names or regex expression of the module names to replace with RandLora."
+                "list of module names or regex expression of the module names to replace with RandLora."
                 "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. "
                 "Only linear layers are supported."
             )
@@ -132,19 +131,19 @@ class RandLoraConfig(PeftConfig):
         metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
     )
     randlora_alpha: int = field(
-        default=20,
+        default=640,
         metadata={
-            "help": "Scaling coefficient in the adapter layers, typically 2 times the rank of the random bases."
+            "help": "Scaling coefficient in the adapter layers, typically 20 times the rank of the random bases."
         },
     )
     bias: str = field(
         default="none", metadata={"help": "Bias type for RandLora. Can be 'none', 'all' or 'randlora_only'"}
     )
-    modules_to_save: Optional[List[str]] = field(
+    modules_to_save: Optional[list[str]] = field(
         default=None,
         metadata={
             "help": (
-                "List of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint. For"
+                "list of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint. For"
                 " example, in Sequence Classification or Token Classification tasks, the final layer"
                 " `classifier/score` are randomly initialized and as such need to be trainable and saved."
             )
@@ -159,7 +158,7 @@ class RandLoraConfig(PeftConfig):
             ),
         },
     )
-    layers_to_transform: Optional[Union[List[int], int]] = field(
+    layers_to_transform: Optional[Union[list[int], int]] = field(
         default=None,
         metadata={
             "help": (
```

src/peft/tuners/randlora/layer.py

Lines changed: 20 additions & 12 deletions
```diff
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import math
 import warnings
-from typing import List, Optional
+from typing import Optional

 import torch
 import torch.nn as nn
@@ -72,6 +71,9 @@ def __init__(self, base_layer: nn.Module, **kwargs):
         self._disable_adapters = False
         self.merged_adapters = []

+        # flag to enable/disable casting of input to weight dtype during forward call
+        self.cast_input_dtype_enabled = True
+
         base_layer = self.get_base_layer()
         if isinstance(base_layer, nn.Linear):
             in_features, out_features = base_layer.in_features, base_layer.out_features
@@ -118,7 +120,7 @@ def update_layer(
             requires_grad=True,
         )

-        self.scaling[adapter_name] = randlora_alpha / r / math.sqrt(self.num_bases)
+        self.scaling[adapter_name] = randlora_alpha / r

         # non trainable references to randlora_A/B buffers
         self.randlora_A = randlora_A
@@ -153,6 +155,7 @@ def update_layer(
         )
         if randlora_A_param.shape[0] < self.r[adapter_name]:
             raise ValueError(error_tmpl.format("randlora_A", randlora_A_param.shape[0], self.r[adapter_name]))
+
         if randlora_B_param.shape[-1] < self.r[adapter_name]:
             raise ValueError(error_tmpl.format("randlora_B", randlora_B_param.shape[-1], self.r[adapter_name]))

@@ -169,9 +172,7 @@ def reset_randlora_parameters(self, adapter_name):
         if adapter_name in self.randlora_lambda.keys():
             with torch.no_grad():
                 nn.init.zeros_(self.randlora_lambda[adapter_name])
-                nn.init.ones_(self.randlora_gamma[adapter_name]).fill_(
-                    1 / max(self.randlora_gamma[adapter_name].shape)
-                )
+                nn.init.constant_(self.randlora_gamma[adapter_name], 1 / max(self.randlora_gamma[adapter_name].shape))


 class Linear(nn.Linear, RandLoraLayer):
@@ -198,7 +199,7 @@ def __init__(
         self.update_layer(adapter_name, randlora_A, randlora_B, r, randlora_alpha, randlora_dropout, init_weights)
         self.is_target_conv_1d_layer = is_target_conv_1d_layer

-    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
+    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
         """
         Merge the active adapter weights into the base weights

@@ -207,7 +208,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
                 If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                 before merging the weights. This is useful if you want to check if the merge operation will produce
                 NaNs. Defaults to `False`.
-            adapter_names (`List[str]`, *optional*):
+            adapter_names (`list[str]`, *optional*):
                 The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                 to `None`.
         """
@@ -219,6 +220,8 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
         for active_adapter in adapter_names:
             if active_adapter in self.randlora_lambda.keys():
                 base_layer = self.get_base_layer()
+                orig_dtype = base_layer.weight.dtype
+
                 if safe_merge:
                     # Note that safe_merge will be slower than the normal merge
                     # because of the copy operation.
@@ -231,9 +234,11 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
                             f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                         )

-                    base_layer.weight.data = orig_weights
+                    base_layer.weight.data = orig_weights.to(orig_dtype)
                 else:
-                    base_layer.weight.data += self.get_delta_weight(active_adapter)
+                    delta_weight = self.get_delta_weight(active_adapter)
+                    base_layer.weight.data += delta_weight.to(orig_dtype)
+
                 self.merged_adapters.append(active_adapter)

     def unmerge(self) -> None:
@@ -242,9 +247,12 @@ def unmerge(self) -> None:
             return

         while len(self.merged_adapters) > 0:
+            base_layer = self.get_base_layer()
+            orig_dtype = base_layer.weight.dtype
             active_adapter = self.merged_adapters.pop()
             if active_adapter in self.randlora_lambda.keys():
-                self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
+                delta_weight = self.get_delta_weight(active_adapter)
+                base_layer.weight.data -= delta_weight.to(orig_dtype)

     def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
         """
@@ -289,7 +297,7 @@ def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
         update_B = sliced_B.flatten(start_dim=1)
         update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1)

-        # Since update_A is applied on the smallest dimension, test whether update_A or update_B should applied first. This is done to reduce trainable parameters.
+        # Since update_A is applied on the smallest dimension, test whether update_A or update_B should be applied first. This is done to reduce trainable parameters.
         if min_dim == self.in_features:
             return update_A, update_B
         return update_B.T, update_A.T
```
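Two behavioural changes stand out in this file: the adapter scaling is simplified to `randlora_alpha / r` (the `1 / sqrt(num_bases)` factor and the `math` import are gone), and `merge`/`unmerge` now cast the delta weight back to the base weight's dtype before the in-place update. A minimal sketch of both ideas, not the library code itself; the toy linear layer and the random stand-in for `get_delta_weight` are assumptions.

```python
import torch
import torch.nn as nn

randlora_alpha, r = 640, 32
scaling = randlora_alpha / r   # 20.0; previously randlora_alpha / r / sqrt(num_bases)

base_layer = nn.Linear(16, 16).to(torch.bfloat16)
orig_dtype = base_layer.weight.dtype

# stand-in for get_delta_weight(); in PEFT it is built from the shared random bases
delta_weight = scaling * 0.01 * torch.randn(16, 16, dtype=torch.float32)

base_layer.weight.data += delta_weight.to(orig_dtype)  # merge: cast keeps the weight in bf16
base_layer.weight.data -= delta_weight.to(orig_dtype)  # unmerge: same cast on the way out
assert base_layer.weight.dtype == orig_dtype
```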

src/peft/tuners/randlora/model.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -188,6 +188,7 @@ def _init_randlora_A_randlora_B(self, config: RandLoraConfig, adapter_name: str)

         # deterministic init of randlora_A and randlora_B if we know the key
         generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key)
+
         # The gamma matrix is applied on A meaning it can be unique (shared) accross the n scaling matrices.
         # We also set randlora_A as the smallest matrix to reduce trainable parameters.
         randlora_A = _kaiming_init((config.r, 1, min_dim), generator=generator)
```
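Apart from the added blank line, this hunk shows how the shared bases are initialised deterministically from `projection_prng_key`, which is why they never need to be stored in the checkpoint. A small sketch of that idea under stated assumptions: `make_basis`, the shapes, and the use of `torch.randn` in place of the library's `_kaiming_init` are all illustrative.

```python
import torch

def make_basis(key: int) -> torch.Tensor:
    # the same key always yields the same CPU generator state, hence the same basis
    generator = torch.Generator(device="cpu").manual_seed(key)
    return torch.randn(32, 1, 64, generator=generator)

projection_prng_key = 0xFF  # illustrative key, mirroring the test config further down
assert torch.equal(make_basis(projection_prng_key), make_basis(projection_prng_key))
```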

tests/test_custom_models.py

Lines changed: 19 additions & 8 deletions
```diff
@@ -515,17 +515,28 @@
     ########
     # RandLora #
     ########
-    ("Vanilla MLP 1 RandLora", "MLP", RandLoraConfig, {"target_modules": "lin0"}),
-    ("Vanilla MLP 2 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0"]}),
-    ("Vanilla MLP 3 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin1"]}),
-    ("Vanilla MLP 4 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"]}),
-    ("Vanilla MLP 5 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"], "sparse": True}),
-    ("Vanilla MLP 6 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"], "very_sparse": True}),
+    # We have to reduce the default scaling parameter to avoid nans when using large learning rates
+    ("Vanilla MLP 1 RandLora", "MLP", RandLoraConfig, {"target_modules": "lin0", "randlora_alpha": 64}),
+    ("Vanilla MLP 2 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0"], "randlora_alpha": 64}),
+    ("Vanilla MLP 3 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin1"], "randlora_alpha": 64}),
+    ("Vanilla MLP 4 RandLora", "MLP", RandLoraConfig, {"target_modules": ["lin0", "lin1"], "randlora_alpha": 64}),
+    (
+        "Vanilla MLP 5 RandLora",
+        "MLP",
+        RandLoraConfig,
+        {"target_modules": ["lin0", "lin1"], "sparse": True, "randlora_alpha": 64},
+    ),
+    (
+        "Vanilla MLP 6 RandLora",
+        "MLP",
+        RandLoraConfig,
+        {"target_modules": ["lin0", "lin1"], "very_sparse": True, "randlora_alpha": 64},
+    ),
     (
         "Vanilla MLP 7 RandLora",
         "MLP",
         RandLoraConfig,
-        {"target_modules": ["lin0"], "modules_to_save": ["lin1"]},
+        {"target_modules": ["lin0"], "modules_to_save": ["lin1"], "randlora_alpha": 64},
     ),
 ]

@@ -1465,7 +1476,7 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c
             lr = 0.1  # otherwise we get nan
         elif "mha" in model_id.lower():
             lr = 1e-3  # we get exploding gradients with MHA when learning rate is too high
-        elif issubclass(config_cls, VBLoRAConfig):
+        elif issubclass(config_cls, VBLoRAConfig) or issubclass(config_cls, RandLoraConfig):
             lr = 0.01  # otherwise we get nan
         optimizer = torch.optim.SGD(model.parameters(), lr=lr)

```
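The test parametrisation keeps `randlora_alpha` at 64 rather than the new default of 640, and the learning-rate branch now treats RandLora like VB-LoRA. The arithmetic behind the comment in the diff, as a toy illustration with the values taken from the diff itself:

```python
r = 32
alpha_default, alpha_tests = 640, 64
print(alpha_default / r)  # 20.0 -> large update scaling, prone to NaNs with big SGD steps
print(alpha_tests / r)    #  2.0 -> tamer scaling used in the tests, with lr capped at 0.01
```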

tests/testing_common.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -145,8 +145,8 @@
     },
     # RandLoRA
     {
-        "r": 10,
-        "randlora_alpha": 20,
+        "r": 32,
+        "randlora_alpha": 64,
         "target_modules": None,
         "randlora_dropout": 0.05,
         "projection_prng_key": 0xFF,
```
