src/peft/tuners/randlora/config.py
12 additions & 9 deletions
```diff
@@ -28,9 +28,9 @@ class RandLoraConfig(PeftConfig):
     Paper: https://arxiv.org/pdf/2502.00987.
 
     Args:
-        r (`int`, *optional*, defaults to `32`):
-            RandLora's random basis rank dimension. This parameter is inversely proportional to the amount of trainable
-            parameters.
+        r (`int`, *optional*, defaults to `10`):
+            RandLora's random basis rank dimension. Contrary to LoRA, this parameter is inversely proportional to the
+            number of trainable parameters: reducing it increases the trainable parameter count.
         target_modules (`Union[List[str], str]`):
             The names of the modules to apply RandLora to. Only linear layers are supported.
         projection_prng_key (`int`):
```
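Because `r` behaves opposite to LoRA's rank, it is worth checking the effect on parameter count directly. The sketch below is illustrative only and not part of this diff: it assumes a peft version that ships RandLora, and the toy model and module names (`"0"`, `"2"`) are made up for the example.

```python
# Hedged sketch: confirm that a smaller `r` yields MORE trainable parameters.
import torch.nn as nn
from peft import RandLoraConfig, get_peft_model

for r in (32, 10):
    base = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
    # target_modules names match the Sequential's auto-assigned child names.
    config = RandLoraConfig(r=r, target_modules=["0", "2"])
    peft_model = get_peft_model(base, config)
    trainable = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
    print(f"r={r}: {trainable} trainable parameters")
```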
```diff
@@ -41,11 +41,14 @@ class RandLoraConfig(PeftConfig):
             gamma diagonal matrices. This will increase the size of the checkpoint, but guarantee that we can
             reload the checkpoint on all system configurations. Defaults to `True`.
         sparse (`bool`):
-            Whether to use sparse random bases as described in the RandLora paper. The current implementation is a
-            proof of concept where the sparseness is not used to improve speed or memory usage. Defaults to `False`.
+            Whether to use sparse random bases as described in the RandLora paper. The bases are ternary sparse bases (containing only -1, 0 and 1) where the probability of drawing -1 or 1 is 1/6 each and the probability of drawing 0 is 2/3.
+            These sparse matrices aim to enable matmul-free computation in the future, see https://arxiv.org/pdf/2406.02528v1.
+            The current implementation is, however, a proof of concept where the sparseness is not used to improve speed or memory usage. Using sparse matrices typically does not reduce performance and can even help reduce overfitting.
+            Defaults to `False`.
         very_sparse (`bool`):
-            Whether to use very sparse random bases. The current implementation is a proof of concept where the
-            sparseness is not used to improve speed or memory usage. Defaults to `False`.
+            Whether to use highly sparse random bases as described in the RandLora paper. The very sparse bases are ternary sparse bases (containing only -1, 0 and 1); given a matrix whose smallest dimension is D, the probability of drawing -1 or 1 is 1/√D each and the probability of drawing 0 is 1 - 2/√D.
+            Using these bases can further reduce overfitting compared to the `sparse` alternative but will most likely decrease performance as a result. Use carefully.
+            Defaults to `False`.
         randlora_dropout (`float`):
             The dropout probability for RandLora layers.
         randlora_alpha (`float`):
```
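The two ternary distributions described above are easy to sketch. The snippet below is my own illustration of the stated probabilities, not the PR's actual basis-generation code:

```python
# Hedged sketch of the `sparse` and `very_sparse` ternary distributions.
import numpy as np

rng = np.random.default_rng(0)

def sparse_basis(rows, cols):
    # `sparse`: P(-1) = P(1) = 1/6, P(0) = 2/3.
    return rng.choice([-1.0, 0.0, 1.0], size=(rows, cols), p=[1 / 6, 2 / 3, 1 / 6])

def very_sparse_basis(rows, cols):
    # `very_sparse`: with D = min(rows, cols), P(-1) = P(1) = 1/sqrt(D) and
    # P(0) = 1 - 2/sqrt(D). Requires D >= 4 so the probabilities stay valid.
    p = 1.0 / np.sqrt(min(rows, cols))
    return rng.choice([-1.0, 0.0, 1.0], size=(rows, cols), p=[p, 1.0 - 2.0 * p, p])

print((sparse_basis(768, 64) == 0).mean())       # ~2/3 zeros
print((very_sparse_basis(768, 64) == 0).mean())  # ~0.75 zeros for D = 64
```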
```diff
@@ -72,7 +75,7 @@ class RandLoraConfig(PeftConfig):
             pattern is not in the common layers pattern.
     """
 
-    r: int = field(default=32, metadata={"help": "RandLora random basis rank"})
+    r: int = field(default=10, metadata={"help": "RandLora random basis rank"})
     # Since update_A is applied on the smallest dimension, test whether update_A or update_B should be applied first. This is done to reduce trainable parameters.
```
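A config built without an explicit `r` now picks up the new default. A quick hedged example (the `q_proj`/`v_proj` module names are placeholders, not from this diff):

```python
from peft import RandLoraConfig

# Only linear layers are supported as targets; the names here are placeholders.
config = RandLoraConfig(target_modules=["q_proj", "v_proj"], sparse=True)
print(config.r)  # -> 10, the new default introduced by this change
```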