
Commit 7ba3aee

FIX Multiple issues with low precision base models
As a user, it should be possible to manually cast the base model to a lower precision dtype, float16 or bfloat16, and still have the different PEFT methods work correctly. Currently, this is not the case for many PEFT methods, as can be replicated by the added tests.

To understand the problem, it helps to take a step back. By default, PEFT will treat the adapter weights with high precision, i.e. with float32. When the base model is lower precision, the user needs to pass inputs in lower precision too, as otherwise self.base_layer(x) would fail. However, this low precision input clashes with the high precision adapter weights.

The solution implemented in this PR is to cast the input to a higher dtype [1]. That way, the whole adapter operation is conducted in high precision. Only once that has finished will the final result be cast to the original dtype. This should lead to better results, but it may require more memory. Note that this is how LoRA is implemented, so the changes in this PR bring the other methods more in line with what LoRA does.

If the user does not want the adapter to be in float32, they can always pass autocast_adapter_dtype=False when calling get_peft_model or PeftModel.from_pretrained. This is also tested.

Besides adjusting the forward method to account for these changes, the merge and unmerge methods also often had to be adjusted, as they did not correctly account for the base model dtype. Now, those methods should always conserve the original dtype of the base model.

Note that if, for whatever reason, the input casting in [1] is not desired, users can use the disable_input_dtype_casting context manager to disable it (more context information on this feature can be found in PR huggingface#2353). I updated the corresponding code to be agnostic to the specific PEFT method (beforehand, it was only for LoRA).
1 parent 7320bb9 commit 7ba3aee
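For concreteness, here is a minimal usage sketch of the scenario this commit targets (the model name and LoRA settings are illustrative, not taken from the commit): the base model is manually cast to bfloat16, and the adapter dtype behaviour is controlled via autocast_adapter_dtype.

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# The user manually casts the base model to a low-precision dtype.
base_model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.bfloat16
)
config = LoraConfig(task_type="CAUSAL_LM")

# Default behaviour: adapter weights stay in float32, the input is upcast
# inside the adapter, and the final result is cast back to bfloat16.
model = get_peft_model(base_model, config)

# To keep the adapter weights in bfloat16 as well, opt out of the autocast:
#   model = get_peft_model(base_model, config, autocast_adapter_dtype=False)
```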

File tree

14 files changed: +339, -129 lines


src/peft/helpers.py

Lines changed: 3 additions & 4 deletions
@@ -22,6 +22,7 @@
 
 from .peft_model import PeftConfig, PeftModel
 from .tuners.lora import LoraLayer
+from .tuners.tuners_utils import BaseTunerLayer
 
 
 def update_forward_signature(model: PeftModel) -> None:
@@ -218,8 +219,6 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):
     """
     Context manager disables input dtype casting to the dtype of the weight.
 
-    Currently specifically works for LoRA.
-
     Parameters:
         model (nn.Module):
             The model containing PEFT modules whose input dtype casting is to be adjusted.
@@ -237,7 +236,7 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):
 
     original_values = {}
     for name, module in model.named_modules():
-        if not isinstance(module, LoraLayer):
+        if not isinstance(module, BaseTunerLayer):
             continue
         original_values[name] = module.cast_input_dtype_enabled
         module.cast_input_dtype_enabled = False
@@ -246,7 +245,7 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):
         yield
     finally:
         for name, module in model.named_modules():
-            if not isinstance(module, LoraLayer):
+            if not isinstance(module, BaseTunerLayer):
                 continue
             if name in original_values:
                 module.cast_input_dtype_enabled = original_values[name]
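With this change the context manager covers any BaseTunerLayer instead of only LoRA layers. A minimal usage sketch, assuming `model` is a PeftModel on a low-precision base model and `inputs` is a dict of tensors prepared elsewhere (both hypothetical here):

```python
from peft.helpers import disable_input_dtype_casting

# Inside the context, adapter layers no longer cast the input to the weight
# dtype, so the caller is responsible for passing matching dtypes.
with disable_input_dtype_casting(model):
    output = model(**inputs)
```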

src/peft/tuners/boft/layer.py

Lines changed: 29 additions & 18 deletions
@@ -220,6 +220,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
         # Mark the weight as unmerged
         self._disable_adapters = False
         self.merged_adapters = []
+        # flag to enable/disable casting of input to weight dtype during forward call
+        self.cast_input_dtype_enabled = True
         self.kwargs = kwargs
 
         base_layer = self.get_base_layer()
@@ -503,13 +505,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
         for active_adapter in adapter_names:
             if active_adapter in self.boft_R.keys():
                 base_layer = self.get_base_layer()
+                orig_dtype = base_layer.weight.dtype
                 if safe_merge:
                     # Note that safe_merge will be slower than the normal merge
                     # because of the copy operation.
                     orig_weight = base_layer.weight.data.clone()
                     butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)
                     orig_weight = torch.transpose(orig_weight, 0, 1)
-                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
+                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
                     orig_weight = torch.transpose(orig_weight, 0, 1)
                     orig_weight = orig_weight * boft_s
 
@@ -518,16 +521,16 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                             f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                         )
 
-                    self.base_layer.weight.data = orig_weight.contiguous()
+                    self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)
                 else:
                     butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)
                     orig_weight = base_layer.weight.data.clone()
                     orig_weight = torch.transpose(orig_weight, 0, 1)
-                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
+                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
                     orig_weight = torch.transpose(orig_weight, 0, 1)
                     orig_weight = orig_weight * boft_s
 
-                    self.base_layer.weight.data = orig_weight.contiguous()
+                    self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)
 
                 self.merged_adapters.append(active_adapter)
 
@@ -538,17 +541,20 @@ def unmerge(self) -> None:
         if not self.merged:
             warnings.warn("Already unmerged. Nothing to do.")
             return
+
         while len(self.merged_adapters) > 0:
             active_adapter = self.merged_adapters.pop()
+            base_layer = self.get_base_layer()
+            orig_dtype = base_layer.weight.dtype
             if active_adapter in self.boft_R.keys():
                 butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)
 
-                orig_weight = self.get_base_layer().weight.data.clone()
+                orig_weight = base_layer.weight.data.clone()
                 orig_weight = torch.transpose(orig_weight, 0, 1)
-                orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight)
+                orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight.to(butterfly_oft_mat.dtype))
                 orig_weight = torch.transpose(orig_weight, 0, 1)
 
-                self.get_base_layer().weight.data = orig_weight * (1 / boft_s)
+                base_layer.weight.data = (orig_weight * (1 / boft_s)).to(orig_dtype)
 
     def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
         """
@@ -804,6 +810,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
         for active_adapter in adapter_names:
             if active_adapter in self.boft_R.keys():
                 base_layer = self.get_base_layer()
+                orig_dtype = base_layer.weight.dtype
                 if safe_merge:
                     # Note that safe_merge will be slower than the normal merge
                     # because of the copy operation.
@@ -814,14 +821,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                         self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
                     )
                     orig_weight = torch.transpose(orig_weight, 0, 1)
-                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
+                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
                     orig_weight = torch.transpose(orig_weight, 0, 1)
                     orig_weight = orig_weight * boft_s
                     orig_weight = orig_weight.view(
                         self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
                     )
 
-                    self.base_layer.weight.data = orig_weight.contiguous()
+                    self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)
                 else:
                     butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)
 
@@ -830,14 +837,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                         self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
                     )
                     orig_weight = torch.transpose(orig_weight, 0, 1)
-                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
+                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
                     orig_weight = torch.transpose(orig_weight, 0, 1)
                     orig_weight = orig_weight * boft_s
                     orig_weight = orig_weight.view(
                         self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
                     )
 
-                    self.base_layer.weight.data = orig_weight.contiguous()
+                    self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)
 
                 self.merged_adapters.append(active_adapter)
 
@@ -850,26 +857,28 @@ def unmerge(self) -> None:
             return
         while len(self.merged_adapters) > 0:
             active_adapter = self.merged_adapters.pop()
+            base_layer = self.get_base_layer()
+            orig_dtype = base_layer.weight.dtype
             if active_adapter in self.boft_R.keys():
                 butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)
 
-                orig_weight = self.get_base_layer().weight.data.clone()
+                orig_weight = base_layer.weight.data.clone()
                 orig_weight = orig_weight.view(
                     self.out_features,
-                    self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0],
+                    self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0],
                 )
                 orig_weight = torch.transpose(orig_weight, 0, 1)
-                orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight)
+                orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight.to(butterfly_oft_mat.dtype))
                 orig_weight = torch.transpose(orig_weight, 0, 1)
                 orig_weight = orig_weight * (1 / boft_s)
                 orig_weight = orig_weight.view(
                     self.out_features,
                     self.in_features,
-                    self.get_base_layer().kernel_size[0],
-                    self.get_base_layer().kernel_size[0],
+                    base_layer.kernel_size[0],
+                    base_layer.kernel_size[0],
                 )
 
-                self.get_base_layer().weight.data = orig_weight
+                self.get_base_layer().weight.data = orig_weight.to(orig_dtype)
 
     def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
         """
@@ -968,10 +977,12 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             scaled_rotated_weight = scaled_rotated_weight.view(
                 self.out_features, self.in_features, self.base_layer.kernel_size[0], self.base_layer.kernel_size[0]
             )
+            x = self._cast_input_dtype(x, scaled_rotated_weight.dtype)
+            bias = self._cast_input_dtype(self.base_layer.bias, scaled_rotated_weight.dtype)
             result = F.conv2d(
                 input=x,
                 weight=scaled_rotated_weight,
-                bias=self.base_layer.bias,
+                bias=bias,
                 padding=self.base_layer.padding[0],
                 stride=self.base_layer.stride[0],
             )
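The merge and unmerge hunks above all follow the same convention: remember the base weight's dtype, perform the matrix arithmetic in the delta's (usually higher) precision, and only cast back at the very end. A standalone sketch of that pattern, with a made-up helper name for illustration:

```python
import torch

def merge_preserving_dtype(base_weight: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
    orig_dtype = base_weight.dtype                # e.g. torch.bfloat16
    merged = base_weight.to(delta.dtype) + delta  # arithmetic in float32
    return merged.to(orig_dtype)                  # the base model keeps its dtype

# A bfloat16 base weight merged with a float32 delta stays bfloat16.
w = torch.randn(4, 4, dtype=torch.bfloat16)
d = torch.randn(4, 4, dtype=torch.float32)
assert merge_preserving_dtype(w, d).dtype == torch.bfloat16
```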

src/peft/tuners/bone/layer.py

Lines changed: 17 additions & 5 deletions
@@ -35,6 +35,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
         # Mark the weight as unmerged
         self._disable_adapters = False
         self.merged_adapters = []
+        # flag to enable/disable casting of input to weight dtype during forward call
+        self.cast_input_dtype_enabled = True
         self.kwargs = kwargs
 
         base_layer = self.get_base_layer()
@@ -150,6 +152,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
         for active_adapter in adapter_names:
             if active_adapter in self.bone_block.keys():
                 base_layer = self.get_base_layer()
+                orig_dtype = base_layer.weight.dtype
                 if safe_merge:
                     # Note that safe_merge will be slower than the normal merge
                     # because of the copy operation.
@@ -166,14 +169,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
                             f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                         )
 
-                    self.base_layer.weight.data = orig_weight
+                    self.base_layer.weight.data = orig_weight.to(orig_dtype)
                 else:
                     if self.bone_fn == "bat":
                         delta_weight = self.get_delta_weight(active_adapter, self.base_layer.weight.data)
-                        self.base_layer.weight.data += delta_weight
+                        self.base_layer.weight.data += delta_weight.to(orig_dtype)
                     else:
                         delta_weight = self.get_delta_weight_bone(active_adapter, self.base_layer.weight.data)
-                        self.base_layer.weight.data = delta_weight
+                        self.base_layer.weight.data = delta_weight.to(orig_dtype)
                 self.merged_adapters.append(active_adapter)
 
     def unmerge(self) -> None:
@@ -183,16 +186,19 @@ def unmerge(self) -> None:
         if not self.merged:
             warnings.warn("Already unmerged. Nothing to do.")
             return
+
         while len(self.merged_adapters) > 0:
             active_adapter = self.merged_adapters.pop()
+            base_layer = self.get_base_layer()
+            orig_dtype = base_layer.weight.dtype
             if active_adapter in self.bone_block.keys():
                 orig_weight = self.get_base_layer().weight.data.clone()
                 if self.bone_fn == "bat":
                     delta_weight = self.get_delta_weight(active_adapter, orig_weight, re=True)
                 else:
                     delta_weight = self.get_delta_weight_bone(active_adapter, orig_weight, re=True)
 
-                self.get_base_layer().weight.data = delta_weight
+                base_layer.weight.data = delta_weight.to(orig_dtype)
 
     def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
         """
@@ -213,12 +219,15 @@ def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tens
 
         if cast_to_fp32:
             weight_bone = weight_bone.float()
+            orig_weight = orig_weight.to(weight_bone.dtype)
 
         r = weight_bone.size(-1)
         if re:
             o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
             one = torch.eye(weight_bone.size(-1)).to(weight_bone.device)
+            # inverse must be in float32, after that the dtype can be adjusted if needed
             inv_I_plus_b = torch.inverse(one + weight_bone)
+            inv_I_plus_b = inv_I_plus_b.to(weight_bone.dtype)
             w = (o - weight_bone) @ inv_I_plus_b
             output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
         else:
@@ -318,7 +327,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                     delta_weight = self.get_delta_weight(active_adapter, orig_weight)
                     orig_weight = orig_weight + delta_weight
 
-                result = F.linear(input=x, weight=orig_weight, bias=self.base_layer.bias)
+                x = self._cast_input_dtype(x, orig_weight.dtype)
+                bias = self.base_layer.bias.to(orig_weight.dtype)
+                result = F.linear(input=x, weight=orig_weight, bias=bias)
             else:
                 result = self.base_layer(x, *args, **kwargs)
                 for active_adapter in self.active_adapters:
@@ -329,6 +340,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                     if x.size(-1) % r != 0:
                         padding_size = (r - x.size(-1) % r) % r
                         x = F.pad(x, (0, padding_size))
+                    x = self._cast_input_dtype(x, bone.dtype)
                     result = result + torch.sum(x.reshape(*x.shape[:-1], x.size(-1) // r, r), dim=-2) @ bone
 
         result = result.to(previous_dtype)
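The forward changes follow the mirror-image convention: upcast the input to the adapter weight's dtype, run the adapter computation in that precision, and cast the result back to the input's original dtype only at the end. A self-contained sketch of the idea (the function is hypothetical, not PEFT API):

```python
import torch

def adapter_forward(x: torch.Tensor, adapter_weight: torch.Tensor) -> torch.Tensor:
    previous_dtype = x.dtype          # e.g. torch.float16
    x = x.to(adapter_weight.dtype)    # upcast the input, e.g. to torch.float32
    result = x @ adapter_weight.T     # adapter computation in high precision
    return result.to(previous_dtype)  # cast back only at the very end

x = torch.randn(2, 8, dtype=torch.float16)
w = torch.randn(8, 8, dtype=torch.float32)
assert adapter_forward(x, w).dtype == torch.float16
```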

src/peft/tuners/fourierft/layer.py

Lines changed: 4 additions & 3 deletions
@@ -84,12 +84,13 @@ def reset_fourier_parameters(self, adapter_name):
         nn.init.zeros_(self.fourierft_spectrum[adapter_name])
 
     def get_delta_weight(self, adapter) -> torch.Tensor:
+        # careful: ifft2 does not work with float16 or bfloat16
         spectrum = self.fourierft_spectrum[adapter]
         indices = self.indices[adapter].to(spectrum.device)
-        dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device, dtype=spectrum.dtype)
-        dense_spectrum[indices[0, :], indices[1, :]] = spectrum
+        dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device)
+        dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
         delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter]
-        return delta_weight
+        return delta_weight.to(spectrum.dtype)
 
 
 class FourierFTLinear(nn.Module, FourierFTLayer):
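A simplified standalone version of the dtype handling above (names and shapes are illustrative, not the actual FourierFT layer): the dense spectrum is built in float32 because, as the new comment notes, ifft2 does not work with float16 or bfloat16, and the resulting delta weight is cast back to the spectrum's dtype at the end.

```python
import torch

def delta_from_spectrum(spectrum: torch.Tensor, indices: torch.Tensor,
                        out_features: int, in_features: int, scaling: float) -> torch.Tensor:
    dense = torch.zeros(out_features, in_features, device=spectrum.device)  # float32
    dense[indices[0, :], indices[1, :]] = spectrum.float()
    delta = torch.fft.ifft2(dense).real * scaling
    return delta.to(spectrum.dtype)  # e.g. back to bfloat16

spectrum = torch.randn(5, dtype=torch.bfloat16)
indices = torch.randint(0, 8, (2, 5))
assert delta_from_spectrum(spectrum, indices, 8, 8, 1.0).dtype == torch.bfloat16
```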
