Commit 68c0d1b

Refactor to better support LoRA variants (#2443)
Provide a framework for adding LoRA variants like DoRA. This way, it will be easier in the future to add variants to LoRA without having to either copy all of the LoRA code and make small changes, or clutter the existing LoRA code with countless if statements. Adding more LoRA variants in the future will therefore not balloon the size of the core LoRA implementation.

The new approach is to add a LoraVariant subclass to peft/tuners/lora/variants.py. Typically, this requires one subclass per supported layer type. The subclass should be essentially stateless, implementing only static methods (which facilitates composition, e.g. if a new variant can be combined with DoRA). The subclass needs to provide a set of methods (init, merge_safe, merge_unsafe, unmerge, forward). In the LoRA code itself, these methods are called if the corresponding adapter uses the subclass.

The choice of which variant to dispatch to is determined by the resolve_lora_variant method. It is called during update_layer and can be overridden by each LoRA layer (so e.g. lora.Linear can dispatch to a different class than lora.Embedding, and bnb LoRA layers could theoretically dispatch to a different class than the normal lora.Linear).

For now, the only LoRA variant is DoRA, which has been refactored to use the new approach.
1 parent 40e7071 commit 68c0d1b
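
To make the framework concrete, here is a rough sketch of what a variant subclass could look like. The method names (init, merge_safe, merge_unsafe, unmerge, forward) come from the commit message and the signatures are inferred from the call sites in the diff below; the class name, type hints, and method bodies are illustrative assumptions, not the actual variants.py code.

```python
import torch

from peft.tuners.lora.layer import LoraLayer, LoraVariant


class MyLinearVariant(LoraVariant):
    """Illustrative variant for Linear layers; stateless, so it only defines static methods."""

    @staticmethod
    def init(module: LoraLayer, adapter_name: str, **kwargs) -> None:
        # Create any extra parameters or buffers the variant needs on the layer (assumed hook).
        pass

    @staticmethod
    def merge_safe(module: LoraLayer, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        # Return the merged weight without modifying orig_weight in place.
        return orig_weight + module.get_delta_weight(active_adapter)

    @staticmethod
    def merge_unsafe(module: LoraLayer, active_adapter: str, orig_weight: torch.Tensor) -> None:
        # Merge into orig_weight in place.
        orig_weight.data += module.get_delta_weight(active_adapter)

    @staticmethod
    def unmerge(module: LoraLayer, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        # Undo the merge and return the restored weight.
        return orig_weight - module.get_delta_weight(active_adapter)

    @staticmethod
    def forward(module: LoraLayer, *, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
        # Add this adapter's contribution to the base layer's output.
        lora_A = module.lora_A[active_adapter]
        lora_B = module.lora_B[active_adapter]
        dropout = module.lora_dropout[active_adapter]
        scaling = module.scaling[active_adapter]
        return result + lora_B(lora_A(dropout(x))) * scaling
```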

File tree: 6 files changed, +576 −406 lines


src/peft/tuners/lora/bnb.py

Lines changed: 55 additions & 75 deletions
@@ -24,7 +24,7 @@
 from peft.utils.integrations import dequantize_bnb_weight
 from peft.utils.other import transpose

-from .layer import LoraLayer
+from .layer import LoraLayer, LoraVariant


 if is_bnb_available():
@@ -60,6 +60,14 @@ def __init__(
             lora_bias=lora_bias,
         )

+    def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
+        if not use_dora:
+            return None
+
+        from .variants import DoraLinearVariant
+
+        return DoraLinearVariant()
+
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
         """
         Merge the active adapter weights into the base weights
@@ -85,7 +93,6 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
             warnings.warn(
                 "Merge lora module to 8-bit linear may get different generations due to rounding errors."
             )
-            lora_data = self.get_delta_weight(active_adapter)

             weight = self.get_base_layer().weight
             state = self.get_base_layer().state
@@ -95,22 +102,11 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
             # Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8
             # dequantization directly
             output = dequantize_bnb_weight(weight, state=state)
-            if not self.use_dora[active_adapter]:
+            if active_adapter not in self.lora_variant:  # vanilla LoRA
+                lora_data = self.get_delta_weight(active_adapter)
                 w_data = output.to(lora_data.dtype).to(lora_data.device) + lora_data
             else:
-                # handle dora
-                # since output already includes scaling, set it to 1 here
-                weight_norm = (
-                    self.lora_magnitude_vector[active_adapter]
-                    .get_weight_norm(output, lora_data, scaling=1)
-                    .detach()
-                )
-                # We need to cache weight_norm because it has to be based on the original weights. We
-                # cannot calculate it on the fly based on the merged weights when unmerging because its a
-                # different value
-                self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
-                dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
-                w_data = dora_factor.view(-1, 1) * (output + lora_data)
+                w_data = self.lora_variant[active_adapter].merge_safe(self, active_adapter, output)

             if safe_merge and not torch.isfinite(w_data).all():
                 raise ValueError(
@@ -120,6 +116,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
             self.get_base_layer().weight = bnb.nn.Int8Params(
                 w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
             ).to(weight.device)
+
             if self.lora_bias[active_adapter]:
                 bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias
                 if safe_merge and not torch.isfinite(bias_data):
@@ -146,20 +143,18 @@ def unmerge(self) -> None:
             warnings.warn(
                 "Unmerge lora module to 8-bit linear may get different generations due to rounding errors."
             )
-            lora_data = self.get_delta_weight(active_adapter)

             weight = self.get_base_layer().weight
             state = self.get_base_layer().state
             if state.SCB is None:
                 state.SCB = weight.SCB
             output = dequantize_bnb_weight(weight, state=state)

-            if not self.use_dora[active_adapter]:
+            if active_adapter not in self.lora_variant:  # vanilla LoRA
+                lora_data = self.get_delta_weight(active_adapter)
                 w_data = output.to(lora_data.dtype).to(lora_data.device) - lora_data
             else:
-                weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
-                dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
-                w_data = output.data / dora_factor.view(-1, 1) - lora_data
+                w_data = self.lora_variant[active_adapter].unmerge(self, active_adapter, output)

             self.get_base_layer().weight = bnb.nn.Int8Params(
                 w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
@@ -243,26 +238,20 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                     expected_dtype = result.dtype
                     x = self._cast_input_dtype(x, lora_A.weight.dtype)

-                if not self.use_dora[active_adapter]:
+                if active_adapter not in self.lora_variant:  # vanilla LoRA
                     output = lora_B(lora_A(dropout(x))) * scaling
+                    if requires_conversion:
+                        output = output.to(expected_dtype)
+                    result = result + output
                 else:
-                    if isinstance(dropout, torch.nn.Identity) or not self.training:
-                        base_result = result
-                    else:
-                        x = dropout(x)
-                        base_result = None
-
-                    output = self.lora_magnitude_vector[active_adapter](
-                        x,
-                        lora_A=lora_A,
-                        lora_B=lora_B,
-                        scaling=scaling,
-                        base_layer=self.get_base_layer(),
-                        base_result=base_result,
+                    result = self.lora_variant[active_adapter].forward(
+                        self,
+                        active_adapter=active_adapter,
+                        x=x,
+                        result=result,
                     )
-                if requires_conversion:
-                    output = output.to(expected_dtype)
-                result = result + output
+                    if requires_conversion:
+                        result = result.to(expected_dtype)

         return result

@@ -326,6 +315,14 @@ def __init__(
             lora_bias=lora_bias,
         )

+    def resolve_lora_variant(self, *, use_dora: bool) -> Optional[LoraVariant]:
+        if not use_dora:
+            return None
+
+        from .variants import DoraLinearVariant
+
+        return DoraLinearVariant()
+
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
         """
         Merge the active adapter weights into the base weights
@@ -354,37 +351,27 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
             # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
             weight = self.get_base_layer().weight
             kwargs = weight.__dict__
-            lora_data = self.get_delta_weight(active_adapter)

             output = dequantize_bnb_weight(weight, state=weight.quant_state)
-            if not self.use_dora[active_adapter]:
+            if active_adapter not in self.lora_variant:  # vanilla LoRA
+                lora_data = self.get_delta_weight(active_adapter)
                 w_data = output + lora_data
             else:
-                # handle dora
-                # since output already includes scaling, set it to 1 here
-                weight_norm = (
-                    self.lora_magnitude_vector[active_adapter]
-                    .get_weight_norm(output, lora_data, scaling=1)
-                    .detach()
-                )
-                # We need to cache weight_norm because it has to be based on the original weights. We
-                # cannot calculate it on the fly based on the merged weights when unmerging because its a
-                # different value
-                self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
-                dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
-                w_data = dora_factor.view(-1, 1) * (output + lora_data)
+                w_data = self.lora_variant[active_adapter].merge_safe(self, active_adapter, output)

             if safe_merge and not torch.isfinite(w_data).all():
                 raise ValueError(
                     f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                 )
+
             if "bnb_quantized" in kwargs:
                 kwargs["bnb_quantized"] = False
             kwargs["requires_grad"] = False
             kwargs.pop("data", None)
             # torch.compile can introduce attributes preceded by '_', remove them
             kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
             self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
+
             if self.lora_bias[active_adapter]:
                 bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias
                 if safe_merge and not torch.isfinite(bias_data):
@@ -411,23 +398,22 @@ def unmerge(self) -> None:
                 "Unmerge lora module to 4-bit linear may get different generations due to rounding errors."
             )

-            lora_data = self.get_delta_weight(active_adapter)
             weight = self.get_base_layer().weight
             kwargs = weight.__dict__
             output = dequantize_bnb_weight(weight, state=weight.quant_state)

-            if not self.use_dora[active_adapter]:
+            if active_adapter not in self.lora_variant:  # vanilla LoRA
+                lora_data = self.get_delta_weight(active_adapter)
                 w_data = output - lora_data
             else:
-                weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
-                dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
-                w_data = output.data / dora_factor.view(-1, 1) - lora_data
+                w_data = self.lora_variant[active_adapter].unmerge(self, active_adapter, output)

             if "bnb_quantized" in kwargs:
                 kwargs["bnb_quantized"] = False
             kwargs["requires_grad"] = False
             kwargs.pop("data", None)
             self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
+
             if self.lora_bias[active_adapter]:
                 self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias

@@ -512,26 +498,20 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                     expected_dtype = result.dtype
                     x = self._cast_input_dtype(x, lora_A.weight.dtype)

-                if not self.use_dora[active_adapter]:
+                if active_adapter not in self.lora_variant:  # vanilla LoRA
                     output = lora_B(lora_A(dropout(x))) * scaling
+                    if requires_conversion:
+                        output = output.to(expected_dtype)
+                    result = result + output
                 else:
-                    if isinstance(dropout, torch.nn.Identity) or not self.training:
-                        base_result = result
-                    else:
-                        x = dropout(x)
-                        base_result = None
-
-                    output = self.lora_magnitude_vector[active_adapter](
-                        x,
-                        lora_A=lora_A,
-                        lora_B=lora_B,
-                        scaling=scaling,
-                        base_layer=self.get_base_layer(),
-                        base_result=base_result,
+                    result = self.lora_variant[active_adapter].forward(
+                        self,
+                        active_adapter=active_adapter,
+                        x=x,
+                        result=result,
                     )
-                if requires_conversion:
-                    output = output.to(expected_dtype)
-                result = result + output
+                    if requires_conversion:
+                        result = result.to(expected_dtype)

         return result

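The registration side of the dispatch lives in layer.py, which is part of this commit but not shown in this excerpt. Based on the commit message ("resolve_lora_variant ... is called during update_layer") and the lora_variant lookups in the hunks above, it presumably works roughly like the sketch below; apart from the names lora_variant and resolve_lora_variant, everything here (class name, parameters, the init hook signature) is an assumption.

```python
from typing import Optional

from peft.tuners.lora.layer import LoraVariant


class LoraLayerSketch:
    """Simplified stand-in for a LoRA layer, showing only the variant dispatch."""

    def __init__(self) -> None:
        # One entry per adapter that uses a variant; vanilla LoRA adapters get no entry,
        # which is why the layer code checks `active_adapter not in self.lora_variant`.
        self.lora_variant: dict[str, LoraVariant] = {}

    def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
        # Default: no variant. Layer subclasses (lora.Linear, lora.Embedding, the bnb/hqq
        # layers in this diff, ...) override this to pick the variant class that fits them.
        return None

    def update_layer(self, adapter_name: str, r: int, *, use_dora: bool = False, **kwargs) -> None:
        variant = self.resolve_lora_variant(use_dora=use_dora, **kwargs)
        if variant is not None:
            self.lora_variant[adapter_name] = variant
            variant.init(self, adapter_name, **kwargs)  # assumed hook for variant-specific setup
        # ... the rest of the usual update_layer logic (lora_A/lora_B, scaling, dropout, ...)
```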
src/peft/tuners/lora/hqq.py

Lines changed: 24 additions & 32 deletions
@@ -23,7 +23,7 @@
 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
 from peft.utils.other import transpose

-from .layer import LoraLayer
+from .layer import LoraLayer, LoraVariant


 if is_hqq_available():
@@ -63,6 +63,14 @@ def __init__(
             lora_bias=lora_bias,
         )

+    def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
+        if not use_dora:
+            return None
+
+        from .variants import DoraLinearVariant
+
+        return DoraLinearVariant()
+
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
         """
         Merge the active adapter weights into the base weights
@@ -87,26 +95,19 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N

             layer = self.get_base_layer()
             quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta}
-            lora_data = self.get_delta_weight(active_adapter)

             output = layer.dequantize()
-            if not self.use_dora[active_adapter]:
+            if active_adapter not in self.lora_variant:  # vanilla LoRA
+                lora_data = self.get_delta_weight(active_adapter)
                 w_data = output + lora_data
             else:
-                # handle dora
-                # since output already includes scaling, set it to 1 here
-                weight_norm = self._get_weight_norm(output, lora_data, scaling=1).detach()
-                # We need to cache weight_norm because it has to be based on the original weights. We
-                # cannot calculate it on the fly based on the merged weights when unmerging because its a
-                # different value
-                self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
-                dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
-                w_data = dora_factor.view(-1, 1) * (output + lora_data)
+                w_data = self.lora_variant[active_adapter].merge_safe(self, active_adapter, output)

             if safe_merge and not torch.isfinite(w_data).all():
                 raise ValueError(
                     f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                 )
+
             new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device)
             quant_config.pop("offload_meta", None)
             new_hqq_layer.quantize(w_data, **quant_config)
@@ -126,17 +127,15 @@ def unmerge(self) -> None:
             if active_adapter not in self.lora_A.keys():
                 continue

-            lora_data = self.get_delta_weight(active_adapter)
             layer = self.get_base_layer()
             quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta}
             output = layer.dequantize()

-            if not self.use_dora[active_adapter]:
-                w_data = output - lora_data
+            if active_adapter not in self.lora_variant:  # vanilla LoRA
+                lora_data = self.get_delta_weight(active_adapter)
+                w_data = output.to(lora_data.dtype).to(lora_data.device) - lora_data
             else:
-                weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
-                dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
-                w_data = output.data / dora_factor.view(-1, 1) - lora_data
+                w_data = self.lora_variant[active_adapter].unmerge(self, active_adapter, output)

             new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device)
             quant_config.pop("offload_meta", None)
@@ -218,23 +217,16 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                     expected_dtype = result.dtype
                     x = self._cast_input_dtype(x, lora_A.weight.dtype)

-                if not self.use_dora[active_adapter]:
+                if active_adapter not in self.lora_variant:  # vanilla LoRA
                     result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    if isinstance(dropout, torch.nn.Identity) or not self.training:
-                        base_result = result
-                    else:
-                        x = dropout(x)
-                        base_result = None
-
-                    result = result + self.lora_magnitude_vector[active_adapter](
-                        x,
-                        lora_A=lora_A,
-                        lora_B=lora_B,
-                        scaling=scaling,
-                        base_layer=self.get_base_layer(),
-                        base_result=base_result,
+                    result = self.lora_variant[active_adapter].forward(
+                        self,
+                        active_adapter=active_adapter,
+                        x=x,
+                        result=result,
                     )
+
                 if requires_conversion:
                     result = result.to(expected_dtype)

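From the user's perspective nothing changes: DoRA is still enabled through use_dora=True on LoraConfig, and only the internal dispatch differs. A minimal usage sketch (the base model and target modules are arbitrary examples, not prescribed by this commit):

```python
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    use_dora=True,  # after this refactor, routed through DoraLinearVariant via resolve_lora_variant
)
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()
```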