Commit 95c8fb1

Commit message: new tests for shared weights, temp multi-gpu fix, better var names and docstrings
1 parent: 7942a9c


9 files changed: +396 −66 lines changed


src/peft/tuners/randlora/bnb.py

Lines changed: 46 additions & 25 deletions
@@ -59,11 +59,18 @@ def __init__(
         )
 
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
-        if self.merged:
-            warnings.warn(
-                f"Already following adapters were merged {','.join(self.merged_adapters)}. "
-                f"You are now additionally merging {','.join(self.active_adapters)}."
-            )
+        """
+        Merge the active adapter weights into the base weights
+
+        Args:
+            safe_merge (`bool`, *optional*):
+                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
+                before merging the weights. This is useful if you want to check if the merge operation will produce
+                NaNs. Defaults to `False`.
+            adapter_names (`list[str]`, *optional*):
+                The list of adapter names that should be merged. If None, all active adapters will be merged.
+                Defaults to `None`.
+        """
 
         adapter_names = check_adapters_to_merge(self, adapter_names)
         if not adapter_names:
@@ -98,6 +105,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
             self.merged_adapters.append(active_adapter)
 
     def unmerge(self) -> None:
+        """
+        This method unmerges all merged adapter layers from the base weights.
+        """
         if not self.merged:
             warnings.warn("Already unmerged. Nothing to do")
             return
@@ -124,7 +134,7 @@ def unmerge(self) -> None:
             ).to(weight.device)
             state.reset_grads()
 
-    def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]:
+    def get_scaled_bases(self, adapter, device=None) -> list[torch.Tensor, torch.Tensor]:
         """
         Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the
         correct order to fit the target layers' dimensions
@@ -137,16 +147,17 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]:
         randlora_A = self.randlora_A[adapter]
         randlora_B = self.randlora_B[adapter]
 
-        device = randlora_B.device
+        if device is None:
+            device = randlora_B.device
         dtype = randlora_B.dtype
 
         # In case users wants to merge the adapter weights that are in
         # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
         # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
         cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
 
-        randlora_lambda = self.randlora_lambda[adapter]
-        randlora_gamma = self.randlora_gamma[adapter]
+        randlora_lambda = self.randlora_lambda[adapter].to(device)
+        randlora_gamma = self.randlora_gamma[adapter].to(device)
 
         if cast_to_fp32:
             randlora_A = randlora_A.float()
@@ -160,8 +171,8 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]:
         # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices,
         # we initialize these matrices with the largest required size for each dimension.
         # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B.
-        sliced_A = randlora_A[:, : self.num_bases, :min_dim]
-        sliced_B = randlora_B[:max_dim, : self.num_bases, :]
+        sliced_A = randlora_A[:, : self.num_bases, :min_dim].to(device)
+        sliced_B = randlora_B[:max_dim, : self.num_bases, :].to(device)
 
         # Flattening the matrices over the rank and number of bases dimensions is more memory efficient
         update_B = sliced_B.flatten(start_dim=1)
@@ -216,7 +227,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 if active_adapter not in self.randlora_lambda.keys():
                     continue
 
-                update_B, update_A = self.get_scaled_bases(active_adapter)
+                update_B, update_A = self.get_scaled_bases(active_adapter, device=x.device)
 
                 requires_conversion = not torch.is_autocast_enabled()
                 if requires_conversion:
@@ -275,11 +286,18 @@ def __init__(
         )
 
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
-        if self.merged:
-            warnings.warn(
-                f"Already following adapters were merged {','.join(self.merged_adapters)}. "
-                f"You are now additionally merging {','.join(self.active_adapters)}."
-            )
+        """
+        Merge the active adapter weights into the base weights
+
+        Args:
+            safe_merge (`bool`, *optional*):
+                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
+                before merging the weights. This is useful if you want to check if the merge operation will produce
+                NaNs. Defaults to `False`.
+            adapter_names (`list[str]`, *optional*):
+                The list of adapter names that should be merged. If None, all active adapters will be merged.
+                Defaults to `None`.
+        """
 
         adapter_names = check_adapters_to_merge(self, adapter_names)
         if not adapter_names:
@@ -309,6 +327,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
             self.merged_adapters.append(active_adapter)
 
     def unmerge(self) -> None:
+        """
+        This method unmerges all merged adapter layers from the base weights.
+        """
         if not self.merged:
             warnings.warn("Already unmerged. Nothing to do")
             return
@@ -330,7 +351,7 @@ def unmerge(self) -> None:
                 weight.device
             )
 
-    def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]:
+    def get_scaled_bases(self, adapter, device=None) -> list[torch.Tensor, torch.Tensor]:
         """
         Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the
         correct order to fit the target layers' dimensions
@@ -342,17 +363,17 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]:
 
         randlora_A = self.randlora_A[adapter]
         randlora_B = self.randlora_B[adapter]
-
-        device = randlora_B.device
+        if device is None:
+            device = randlora_B.device
         dtype = randlora_B.dtype
 
         # In case users wants to merge the adapter weights that are in
         # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
         # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
         cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
 
-        randlora_lambda = self.randlora_lambda[adapter]
-        randlora_gamma = self.randlora_gamma[adapter]
+        randlora_lambda = self.randlora_lambda[adapter].to(device)
+        randlora_gamma = self.randlora_gamma[adapter].to(device)
 
         if cast_to_fp32:
             randlora_A = randlora_A.float()
@@ -366,8 +387,8 @@ def get_scaled_bases(self, adapter) -> list[torch.Tensor, torch.Tensor]:
         # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices,
         # we initialize these matrices with the largest required size for each dimension.
        # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B.
-        sliced_A = randlora_A[:, : self.num_bases, :min_dim]
-        sliced_B = randlora_B[:max_dim, : self.num_bases, :]
+        sliced_A = randlora_A[:, : self.num_bases, :min_dim].to(device)
+        sliced_B = randlora_B[:max_dim, : self.num_bases, :].to(device)
         # Flattening the matrices over the rank and number of bases dimensions is more memory efficient
         update_B = sliced_B.flatten(start_dim=1)
         update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1)
@@ -407,7 +428,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 if active_adapter not in self.randlora_lambda.keys():
                     continue
 
-                update_B, update_A = self.get_scaled_bases(active_adapter)
+                update_B, update_A = self.get_scaled_bases(active_adapter, device=x.device)
 
                 requires_conversion = not torch.is_autocast_enabled()
                 if requires_conversion:
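
The recurring change in this file is the temporary multi-GPU fix: get_scaled_bases now accepts a device argument and the forward passes hand it x.device, so the per-adapter lambda/gamma parameters and the sliced shared bases are moved onto the device of the incoming activations before the update is computed. Below is a minimal, self-contained sketch of that pattern, assuming simplified tensor shapes; it is not the actual PEFT implementation, only an illustration of the device-handling logic.

import torch

def get_scaled_bases_sketch(randlora_A, randlora_B, randlora_lambda, randlora_gamma, device=None):
    # Single-device case: fall back to wherever the shared bases already live.
    if device is None:
        device = randlora_B.device
    # Multi-GPU case: `device` is the device of the current input batch (x.device),
    # so everything participating in the update is moved there first.
    randlora_lambda = randlora_lambda.to(device)
    randlora_gamma = randlora_gamma.to(device)
    sliced_A = randlora_A.to(device)
    sliced_B = randlora_B.to(device)
    # Scale the shared base A by the trainable lambda/gamma factors (same broadcasting
    # pattern as UniqueBaseGrad.forward), then flatten for the two linear maps.
    update_A = (randlora_lambda[:, :, None] * sliced_A * randlora_gamma[None,]).flatten(end_dim=1)
    update_B = sliced_B.flatten(start_dim=1)
    return update_B, update_A

# Schematic call site inside forward: update_B, update_A = get_scaled_bases_sketch(A, B, lam, gam, device=x.device)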

src/peft/tuners/randlora/config.py

Lines changed: 15 additions & 11 deletions
@@ -29,26 +29,30 @@ class RandLoraConfig(PeftConfig):
 
     Args:
         r (`int`, *optional*, defaults to `32`):
-            RandLora's random basis rank dimension. Contrary to Lora, this parameter is inversely proportional to the amount of trainable
-            parameters as reducing it increases trainable parameters.
+            RandLora's random basis rank dimension. Contrary to Lora, this parameter is inversely proportional to the
+            amount of trainable parameters as reducing it increases trainable parameters.
         target_modules (`Union[list[str], str]`):
             The names of the modules to apply RandLora to. Only linear layers are supported.
         projection_prng_key (`int`):
             RandLora PRNG init key. Used for initialising basis_A and basis_B for new models or when loading a
             checkpoint that did not include these projections. Defaults to `0`.
         save_projection (`bool`):
             Whether to save the global basis_A / basis_B random basis in the state dict alongside per layer lambda /
-            gamma diagonal matrices. This will increase the size of the checkpoint, but guarantee that we can
-            reload the checkpoint on all system configurations. Defaults to `True`.
+            gamma diagonal matrices. This will increase the size of the checkpoint, but guarantee that we can reload
+            the checkpoint on all system configurations. Defaults to `True`.
         sparse (`bool`):
-            Whether to use sparse random bases as described in the RandLora paper. The bases are ternary sparse bases (only containing -1, 0 and 1) where the attribution probability is 1/6 for -1 and 1 and 2/3 for 0.
-            These sparse matrices aim to be used for matmul free computation in the future, see https://arxiv.org/pdf/2406.02528v1
-            The current implementation is a proof of concept however where the sparseness is not used to improve speed or memory usage. Using sparse matrices typically does not reduce performance and can even help reduce overfitting.
-            Defaults to `False`.
+            Whether to use sparse random bases as described in the RandLora paper. The bases are ternary sparse bases
+            (only containing -1, 0 and 1) where the attribution probability is 1/6 for -1 and 1 and 2/3 for 0. These
+            sparse matrices aim to be used for matmul free computation in the future, see
+            https://arxiv.org/pdf/2406.02528v1 The current implementation is a proof of concept however where the
+            sparseness is not used to improve speed or memory usage. Using sparse matrices typically does not reduce
+            performance and can even help reduce overfitting. Defaults to `False`.
         very_sparse (`bool`):
-            Whether to use highly sparse random bases as described in the RandLora paper. The very sparse bases are ternary sparse bases (only containing -1, 0 and 1) given a matrix with smallest dimension d, the attribution probability is 1/√D for -1 and 1 and 1- 2/√D for 0.
-            Using these sparse matrices can further reduce overfitting over the `sparse` alternatives but will most likely decrease performance as a results. Use carefully.
-            Defaults to `False`.
+            Whether to use highly sparse random bases as described in the RandLora paper. The very sparse bases are
+            ternary sparse bases (only containing -1, 0 and 1) given a matrix with smallest dimension d, the
+            attribution probability is 1/√D for -1 and 1 and 1- 2/√D for 0. Using these sparse matrices can further
+            reduce overfitting over the `sparse` alternatives but will most likely decrease performance as a results.
+            Use carefully. Defaults to `False`.
         randlora_dropout (`float`):
             The dropout probability for RandLora layers.
         randlora_alpha (`float`):
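
Since this hunk only reflows the RandLoraConfig docstring, the documented parameters stay the same. The snippet below is a hedged usage sketch built only from the fields documented above; the target module names are placeholders, and the actual constructor may accept additional arguments or different defaults.

from peft import RandLoraConfig, get_peft_model  # assumes RandLoraConfig is exported at the package top level

config = RandLoraConfig(
    r=32,                                  # random basis rank; lowering r INCREASES trainable parameters
    target_modules=["q_proj", "v_proj"],   # placeholder names; only linear layers are supported
    projection_prng_key=0,                 # PRNG key used to (re)generate basis_A / basis_B
    save_projection=True,                  # store the shared bases in the checkpoint for portability
    sparse=False,                          # ternary sparse bases (-1, 0, 1); proof of concept only
    very_sparse=False,                     # even sparser bases; may trade performance for less overfitting
    randlora_dropout=0.0,
)
# model = get_peft_model(base_model, config)  # base_model: a supported transformers model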

src/peft/tuners/randlora/layer.py

Lines changed: 13 additions & 10 deletions
@@ -30,9 +30,9 @@ class UniqueBaseGrad(torch.autograd.Function):
     # Memory efficent for a unique base
     @staticmethod
     def forward(ctx, randlora_A, randlora_lambda, randlora_gamma):
-        Out = randlora_lambda[:, :, None] * randlora_A * randlora_gamma[None,]
+        out = randlora_lambda[:, :, None] * randlora_A * randlora_gamma[None,]
         ctx.save_for_backward(randlora_A, randlora_lambda, randlora_gamma)
-        return Out
+        return out
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -242,6 +242,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                 self.merged_adapters.append(active_adapter)
 
     def unmerge(self) -> None:
+        """
+        This method unmerges all merged adapter layers from the base weights.
+        """
         if not self.merged:
             warnings.warn("Already unmerged. Nothing to do.")
             return
@@ -254,7 +257,7 @@ def unmerge(self) -> None:
                 delta_weight = self.get_delta_weight(active_adapter)
                 base_layer.weight.data -= delta_weight.to(orig_dtype)
 
-    def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
+    def get_scaled_bases(self, adapter, device=None) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Performs scaling on the smallest random base (randlora_A) and returns randlora_A and randlora_B in the correct
         order to fit the target layers' dimensions
@@ -266,17 +269,17 @@ def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
 
         randlora_A = self.randlora_A[adapter]
         randlora_B = self.randlora_B[adapter]
-
-        device = randlora_B.device
+        if device is None:
+            device = randlora_B.device
         dtype = randlora_B.dtype
 
         # In case users wants to merge the adapter weights that are in
         # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
         # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
         cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
 
-        randlora_lambda = self.randlora_lambda[adapter]
-        randlora_gamma = self.randlora_gamma[adapter]
+        randlora_lambda = self.randlora_lambda[adapter].to(device)
+        randlora_gamma = self.randlora_gamma[adapter].to(device)
 
         if cast_to_fp32:
             randlora_A = randlora_A.float()
@@ -290,8 +293,8 @@ def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
         # As adapted layers may have different shapes and RandLora contains a single shared pair of A and B matrices,
         # we initialize these matrices with the largest required size for each dimension.
         # During the forward pass, required submatrices are sliced out from the shared randlora_A and randlora_B.
-        sliced_A = randlora_A[:, : self.num_bases, :min_dim]
-        sliced_B = randlora_B[:max_dim, : self.num_bases, :]
+        sliced_A = randlora_A[:, : self.num_bases, :min_dim].to(device)
+        sliced_B = randlora_B[:max_dim, : self.num_bases, :].to(device)
 
         # Flattening the matrices over the rank and number of bases dimensions is more memory efficient
         update_B = sliced_B.flatten(start_dim=1)
@@ -334,7 +337,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 if active_adapter not in self.randlora_lambda.keys():
                     continue
                 dropout = self.randlora_dropout[active_adapter]
-                update_B, update_A = self.get_scaled_bases(active_adapter)
+                update_B, update_A = self.get_scaled_bases(active_adapter, device=x.device)
                 x = x.to(update_A.dtype)
                 scaling = self.scaling[active_adapter]
                 result = result + F.linear(F.linear(dropout(x), update_B), update_A) * scaling
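For context on the unchanged surrounding code: the adapter delta is applied as two chained linear maps over the (possibly dropped-out) activations, then scaled and added to the base layer's output. The sketch below, with made-up dimensions, only illustrates that call pattern; it does not reproduce the real shared-basis slicing or the RandLora parameterization.

import torch
import torch.nn.functional as F

# Made-up dimensions for illustration only.
batch, in_features, out_features, proj = 2, 16, 32, 8

x = torch.randn(batch, in_features)
update_B = torch.randn(proj, in_features)    # projects activations down to the shared-basis space
update_A = torch.randn(out_features, proj)   # projects back up to the layer's output size
dropout = torch.nn.Dropout(p=0.1)
scaling = 2.0

base_result = torch.zeros(batch, out_features)  # stand-in for self.base_layer(x)
# Same call pattern as the forward pass above: two chained F.linear calls, then scale and add.
result = base_result + F.linear(F.linear(dropout(x), update_B), update_A) * scaling
print(result.shape)  # torch.Size([2, 32])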
