 from peft.utils.integrations import dequantize_bnb_weight
 from peft.utils.other import transpose

-from .layer import LoraLayer
+from .layer import LoraLayer, LoraVariant


 if is_bnb_available():
@@ -60,6 +60,14 @@ def __init__(
                 lora_bias=lora_bias,
             )

+        def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
+            if not use_dora:
+                return None
+
+            from .variants import DoraLinearVariant
+
+            return DoraLinearVariant()
+
         def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
             """
             Merge the active adapter weights into the base weights
@@ -85,7 +93,6 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                 warnings.warn(
                     "Merge lora module to 8-bit linear may get different generations due to rounding errors."
                 )
-                lora_data = self.get_delta_weight(active_adapter)

                 weight = self.get_base_layer().weight
                 state = self.get_base_layer().state
@@ -95,22 +102,11 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                 # Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8
                 # dequantization directly
                 output = dequantize_bnb_weight(weight, state=state)
-                if not self.use_dora[active_adapter]:
+                if active_adapter not in self.lora_variant:  # vanilla LoRA
+                    lora_data = self.get_delta_weight(active_adapter)
                     w_data = output.to(lora_data.dtype).to(lora_data.device) + lora_data
                 else:
-                    # handle dora
-                    # since output already includes scaling, set it to 1 here
-                    weight_norm = (
-                        self.lora_magnitude_vector[active_adapter]
-                        .get_weight_norm(output, lora_data, scaling=1)
-                        .detach()
-                    )
-                    # We need to cache weight_norm because it has to be based on the original weights. We
-                    # cannot calculate it on the fly based on the merged weights when unmerging because its a
-                    # different value
-                    self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
-                    dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
-                    w_data = dora_factor.view(-1, 1) * (output + lora_data)
+                    w_data = self.lora_variant[active_adapter].merge_safe(self, active_adapter, output)

                 if safe_merge and not torch.isfinite(w_data).all():
                     raise ValueError(
@@ -120,6 +116,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                 self.get_base_layer().weight = bnb.nn.Int8Params(
                     w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
                 ).to(weight.device)
+
                 if self.lora_bias[active_adapter]:
                     bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias
                     if safe_merge and not torch.isfinite(bias_data):
@@ -146,20 +143,18 @@ def unmerge(self) -> None:
                 warnings.warn(
                     "Unmerge lora module to 8-bit linear may get different generations due to rounding errors."
                 )
-                lora_data = self.get_delta_weight(active_adapter)

                 weight = self.get_base_layer().weight
                 state = self.get_base_layer().state
                 if state.SCB is None:
                     state.SCB = weight.SCB
                 output = dequantize_bnb_weight(weight, state=state)

-                if not self.use_dora[active_adapter]:
+                if active_adapter not in self.lora_variant:  # vanilla LoRA
+                    lora_data = self.get_delta_weight(active_adapter)
                     w_data = output.to(lora_data.dtype).to(lora_data.device) - lora_data
                 else:
-                    weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
-                    dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
-                    w_data = output.data / dora_factor.view(-1, 1) - lora_data
+                    w_data = self.lora_variant[active_adapter].unmerge(self, active_adapter, output)

                 self.get_base_layer().weight = bnb.nn.Int8Params(
                     w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
@@ -243,26 +238,20 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                         expected_dtype = result.dtype
                         x = self._cast_input_dtype(x, lora_A.weight.dtype)

-                    if not self.use_dora[active_adapter]:
+                    if active_adapter not in self.lora_variant:  # vanilla LoRA
                         output = lora_B(lora_A(dropout(x))) * scaling
+                        if requires_conversion:
+                            output = output.to(expected_dtype)
+                        result = result + output
                     else:
-                        if isinstance(dropout, torch.nn.Identity) or not self.training:
-                            base_result = result
-                        else:
-                            x = dropout(x)
-                            base_result = None
-
-                        output = self.lora_magnitude_vector[active_adapter](
-                            x,
-                            lora_A=lora_A,
-                            lora_B=lora_B,
-                            scaling=scaling,
-                            base_layer=self.get_base_layer(),
-                            base_result=base_result,
+                        result = self.lora_variant[active_adapter].forward(
+                            self,
+                            active_adapter=active_adapter,
+                            x=x,
+                            result=result,
                         )
-                    if requires_conversion:
-                        output = output.to(expected_dtype)
-                    result = result + output
+                        if requires_conversion:
+                            result = result.to(expected_dtype)

             return result

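Annotation (not part of the diff): the inline DoRA handling deleted above is the logic that `DoraLinearVariant.merge_safe` and `.unmerge` are now expected to encapsulate. As a reference, here is a minimal sketch reconstructed from the removed lines; the free-function form and the `module` / `orig_weights` parameter names are illustrative and not taken from `variants.py`.

# Sketch of the DoRA merge/unmerge math removed from the 8-bit layer above.
# `module` stands for the LoRA layer, `orig_weights` for the dequantized base
# weight (which already includes scaling, hence scaling=1 below).

def dora_merge_safe(module, active_adapter, orig_weights):
    lora_data = module.get_delta_weight(active_adapter)
    weight_norm = (
        module.lora_magnitude_vector[active_adapter]
        .get_weight_norm(orig_weights, lora_data, scaling=1)
        .detach()
    )
    # Cache the norm of the *original* weights so unmerge can undo the exact factor;
    # it cannot be recomputed from the merged weights later.
    module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
    dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
    return dora_factor.view(-1, 1) * (orig_weights + lora_data)


def dora_unmerge(module, active_adapter, orig_weights):
    lora_data = module.get_delta_weight(active_adapter)
    weight_norm = module._cache_pop(f"{active_adapter}-weight_norm")
    dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
    return orig_weights.data / dora_factor.view(-1, 1) - lora_data
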
@@ -326,6 +315,14 @@ def __init__(
                 lora_bias=lora_bias,
             )

+        def resolve_lora_variant(self, *, use_dora: bool) -> Optional[LoraVariant]:
+            if not use_dora:
+                return None
+
+            from .variants import DoraLinearVariant
+
+            return DoraLinearVariant()
+
         def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
             """
             Merge the active adapter weights into the base weights
@@ -354,37 +351,27 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                 # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
                 weight = self.get_base_layer().weight
                 kwargs = weight.__dict__
-                lora_data = self.get_delta_weight(active_adapter)

                 output = dequantize_bnb_weight(weight, state=weight.quant_state)
-                if not self.use_dora[active_adapter]:
+                if active_adapter not in self.lora_variant:  # vanilla LoRA
+                    lora_data = self.get_delta_weight(active_adapter)
                     w_data = output + lora_data
                 else:
-                    # handle dora
-                    # since output already includes scaling, set it to 1 here
-                    weight_norm = (
-                        self.lora_magnitude_vector[active_adapter]
-                        .get_weight_norm(output, lora_data, scaling=1)
-                        .detach()
-                    )
-                    # We need to cache weight_norm because it has to be based on the original weights. We
-                    # cannot calculate it on the fly based on the merged weights when unmerging because its a
-                    # different value
-                    self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
-                    dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
-                    w_data = dora_factor.view(-1, 1) * (output + lora_data)
+                    w_data = self.lora_variant[active_adapter].merge_safe(self, active_adapter, output)

                 if safe_merge and not torch.isfinite(w_data).all():
                     raise ValueError(
                         f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                     )
+
                 if "bnb_quantized" in kwargs:
                     kwargs["bnb_quantized"] = False
                 kwargs["requires_grad"] = False
                 kwargs.pop("data", None)
                 # torch.compile can introduce attributes preceded by '_', remove them
                 kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
                 self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
+
                 if self.lora_bias[active_adapter]:
                     bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias
                     if safe_merge and not torch.isfinite(bias_data):
@@ -411,23 +398,22 @@ def unmerge(self) -> None:
                     "Unmerge lora module to 4-bit linear may get different generations due to rounding errors."
                 )

-                lora_data = self.get_delta_weight(active_adapter)
                 weight = self.get_base_layer().weight
                 kwargs = weight.__dict__
                 output = dequantize_bnb_weight(weight, state=weight.quant_state)

-                if not self.use_dora[active_adapter]:
+                if active_adapter not in self.lora_variant:  # vanilla LoRA
+                    lora_data = self.get_delta_weight(active_adapter)
                     w_data = output - lora_data
                 else:
-                    weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
-                    dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
-                    w_data = output.data / dora_factor.view(-1, 1) - lora_data
+                    w_data = self.lora_variant[active_adapter].unmerge(self, active_adapter, output)

                 if "bnb_quantized" in kwargs:
                     kwargs["bnb_quantized"] = False
                 kwargs["requires_grad"] = False
                 kwargs.pop("data", None)
                 self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
+
                 if self.lora_bias[active_adapter]:
                     self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias

@@ -512,26 +498,20 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                         expected_dtype = result.dtype
                         x = self._cast_input_dtype(x, lora_A.weight.dtype)

-                    if not self.use_dora[active_adapter]:
+                    if active_adapter not in self.lora_variant:  # vanilla LoRA
                         output = lora_B(lora_A(dropout(x))) * scaling
+                        if requires_conversion:
+                            output = output.to(expected_dtype)
+                        result = result + output
                     else:
-                        if isinstance(dropout, torch.nn.Identity) or not self.training:
-                            base_result = result
-                        else:
-                            x = dropout(x)
-                            base_result = None
-
-                        output = self.lora_magnitude_vector[active_adapter](
-                            x,
-                            lora_A=lora_A,
-                            lora_B=lora_B,
-                            scaling=scaling,
-                            base_layer=self.get_base_layer(),
-                            base_result=base_result,
+                        result = self.lora_variant[active_adapter].forward(
+                            self,
+                            active_adapter=active_adapter,
+                            x=x,
+                            result=result,
                         )
-                    if requires_conversion:
-                        output = output.to(expected_dtype)
-                    result = result + output
+                        if requires_conversion:
+                            result = result.to(expected_dtype)

             return result

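Taken together, both quantized layers now delegate non-vanilla (DoRA) behavior to a per-adapter variant object: `resolve_lora_variant` returns a `LoraVariant` instance (or `None` for plain LoRA), which is presumably stored under the adapter name in `self.lora_variant`, so `active_adapter not in self.lora_variant` identifies the vanilla path. The hook surface implied by the call sites in this diff looks roughly like the sketch below; the real base class lives in `peft/tuners/lora/variants.py` and is not shown here, so treat the signatures as inferred rather than authoritative.

# Inferred hook surface for the per-adapter variant objects used above.
# This is a sketch based only on the call sites in this diff, not on the
# actual contents of peft/tuners/lora/variants.py.
import torch


class LoraVariant:
    """Hooks a LoRA layer dispatches to when an adapter is not vanilla LoRA."""

    @staticmethod
    def merge_safe(module, active_adapter: str, orig_weights: torch.Tensor) -> torch.Tensor:
        # Return the merged weight; the layer re-quantizes it into Int8Params/Params4bit.
        raise NotImplementedError

    @staticmethod
    def unmerge(module, active_adapter: str, orig_weights: torch.Tensor) -> torch.Tensor:
        # Return the restored base weight, given the dequantized merged weight.
        raise NotImplementedError

    @staticmethod
    def forward(module, *, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
        # Return `result` with this adapter's contribution added.
        raise NotImplementedError

Because the layer passes itself in as the first argument, the variant object can stay lightweight and reuse layer-side helpers such as `get_delta_weight`, `_cache_store`, and `lora_magnitude_vector` instead of duplicating that state.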