src/peft/tuners/randlora/config.py (31 additions, 20 deletions)
@@ -1,4 +1,4 @@
-# Copyright 2023-present the HuggingFace Inc. team.
+# Copyright 2025-present the HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,40 +13,44 @@
 # limitations under the License.

 import warnings
-import math
 from dataclasses import dataclass, field
 from typing import List, Optional, Union

 from peft.config import PeftConfig
 from peft.utils import PeftType

+
 @dataclass
 class RandLoraConfig(PeftConfig):
     """
     This is the configuration class to store the configuration of a [`RandLoraModel`].

-    Paper: {}.
+    Paper: https://arxiv.org/pdf/2502.00987.

     Args:
         r (`int`, *optional*, defaults to `32`):
-            RandLora's random basis rank dimension. This parameter is inversely proportional to the amount of trainable parameters.
+            RandLora's random basis rank dimension. This parameter is inversely proportional to the amount of trainable
+            parameters.
         target_modules (`Union[List[str], str]`):
             The names of the modules to apply RandLora to. Only linear layers are supported.
         projection_prng_key (`int`):
-            RandLora PRNG init key. Used for initialising basis_A and basis_B for new models or when loading a checkpoint
-            that did not include these projections. Defaults to `int(math.exp(1)*3.1415*1000)`.
+            RandLora PRNG init key. Used for initialising basis_A and basis_B for new models or when loading a
+            checkpoint that did not include these projections. Defaults to `0`.
         save_projection (`bool`):
-            Whether to save the global basis_A / basis_B random basis in the state dict alongside per layer lambda / gamma diagonal matrices.
-            weights. This will increase the size of the checkpoint, but guarantee that we can reload the checkpoint on
-            all system configurations. Defaults to `True`.
+            Whether to save the global basis_A / basis_B random basis in the state dict alongside per layer lambda /
+            gamma diagonal matrices. This will increase the size of the checkpoint, but guarantee that we can
+            reload the checkpoint on all system configurations. Defaults to `True`.
         sparse (`bool`):
-            Whether to use sparse random bases as described in the RandLora paper. The current implementation is a proof of concept where the sparseness is not used to improve speed or memory usage. Defaults to `False`.
+            Whether to use sparse random bases as described in the RandLora paper. The current implementation is a
+            proof of concept where the sparseness is not used to improve speed or memory usage. Defaults to `False`.
         very_sparse (`bool`):
-            Whether to use very sparse random bases. The current implementation is a proof of concept where the sparseness is not used to improve speed or memory usage. Defaults to `False`.
+            Whether to use very sparse random bases. The current implementation is a proof of concept where the
+            sparseness is not used to improve speed or memory usage. Defaults to `False`.
         randlora_dropout (`float`):
             The dropout probability for RandLora layers.
         randlora_alpha (`float`):
-            The scaling coefficient for RandLora layers, this would be typically be the same as LoRA, e.g. 2 times the rank.
+            The scaling coefficient for RandLora layers, this would typically be the same as LoRA, e.g. 2 times the
+            rank.
         fan_in_fan_out (`bool`):
             Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses
             `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.
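For illustration, a minimal sketch of constructing the config described by the docstring above. The import path simply mirrors the file shown in this diff, and the target module names are placeholders that depend on the base model.

from peft.tuners.randlora.config import RandLoraConfig

# Values mirror the defaults and examples documented above.
config = RandLoraConfig(
    r=32,                                 # random basis rank; larger r means fewer trainable parameters
    target_modules=["q_proj", "v_proj"],  # placeholder names; only linear layers are supported
    projection_prng_key=0,                # seed for basis_A / basis_B (new default in this diff)
    save_projection=True,                 # keep the shared random bases in the checkpoint
    randlora_dropout=0.0,
    randlora_alpha=64,                    # typically 2 times the rank of the random bases
)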
@@ -57,12 +61,12 @@ class RandLoraConfig(PeftConfig):
         modules_to_save (`List[str]`):
             List of modules apart from RandLora layers to be set as trainable and saved in the final checkpoint.
         init_weights (`bool`):
-            Whether to initialize the weights of the RandLora layers with their default initialization. Don't change this
-            setting, except if you know exactly what you're doing.
+            Whether to initialize the weights of the RandLora layers with their default initialization. Don't change
+            this setting, except if you know exactly what you're doing.
         layers_to_transform (`Union[List[int],int]`):
-            The layer indexes to transform, if this argument is specified, it will apply the RandLora transformations on
-            the layer indexes that are specified in this list. If a single integer is passed, it will apply the RandLora
-            transformations on the layer at this index.
+            The layer indexes to transform, if this argument is specified, it will apply the RandLora transformations
+            on the layer indexes that are specified in this list. If a single integer is passed, it will apply the
+            RandLora transformations on the layer at this index.
         layers_pattern (`str`):
             The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer
             pattern is not in the common layers pattern.
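As a hedged sketch of the layer-restriction options above (module names and the pattern string are placeholders; a single integer can also be passed to target one layer):

from peft.tuners.randlora.config import RandLoraConfig

# Restrict RandLora to the first four layers only (illustrative values).
config = RandLoraConfig(
    target_modules=["q_proj", "v_proj"],  # placeholder module names
    layers_to_transform=[0, 1, 2, 3],     # or layers_to_transform=5 for a single layer
    layers_pattern="layers",              # placeholder; only needed for uncommon layer naming
)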
@@ -81,7 +85,7 @@ class RandLoraConfig(PeftConfig):
         },
     )
     projection_prng_key: int = field(
-        default=int(math.exp(1)*3.1415*1000),
+        default=0,
         metadata={
             "help": (
                 "RandLora PRNG init key. Used for initialising basis_A and basis_B for new models or when loading a "
@@ -124,8 +128,15 @@ class RandLoraConfig(PeftConfig):
         default=False,
         metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
     )
-    randlora_alpha: int = field(default=64, metadata={"help": "Scaling coefficient in the adapter layers, typically 2 times the rank of the random bases."})
-    bias: str = field(default="none", metadata={"help": "Bias type for RandLora. Can be 'none', 'all' or 'randlora_only'"})
+    randlora_alpha: int = field(
+        default=64,
+        metadata={
+            "help": "Scaling coefficient in the adapter layers, typically 2 times the rank of the random bases."
+        },
+    )
+    bias: str = field(
+        default="none", metadata={"help": "Bias type for RandLora. Can be 'none', 'all' or 'randlora_only'"}
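Finally, a hedged end-to-end sketch of how a config like this is typically attached to a model with PEFT's get_peft_model; the base checkpoint and target modules are placeholders, and it assumes RandLora is registered with get_peft_model as the rest of this PR implies.

from transformers import AutoModelForCausalLM

from peft import get_peft_model
from peft.tuners.randlora.config import RandLoraConfig

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder checkpoint
config = RandLoraConfig(r=32, randlora_alpha=64, target_modules=["q_proj", "v_proj"])
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()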