
Commit a1ccc82

Fix: Multiple PEFT methods have issues with models loaded in float16 or bfloat16 (huggingface#2433)
As a user, it should be possible to manually cast the base model to a lower precision dtype, float16 or bfloat16, and still have the different PEFT methods work correctly. Currently, this is not the case for many PEFT methods, as can be replicated by the added tests.

To understand the problem, it helps to take a step back. By default, PEFT treats the adapter weights with high precision, i.e. float32. When the base model is in lower precision, the user needs to pass inputs in lower precision too, as otherwise self.base_layer(x) would fail. However, this low precision input clashes with the high precision adapter weights. The solution implemented in this PR is to cast the input to a higher dtype [1]. That way, the whole adapter operation is conducted in high precision. Only once it has finished is the final result cast back to the original dtype. This should lead to better results, but it may require more memory. Note that this is how LoRA is implemented, so the changes in this PR bring the other methods more in line with what LoRA does.

If the user does not want the adapter to be in float32, they can always pass autocast_adapter_dtype=False when calling get_peft_model or PeftModel.from_pretrained. This is also tested.

Besides adjusting the forward methods to account for these changes, the merge and unmerge methods also often had to be adjusted, as they did not correctly account for the base model dtype. Now, those methods always conserve the original dtype of the base model.

Note that if, for whatever reason, the input casting in [1] is not desired, users can use the disable_input_dtype_casting context manager to disable it (more context on this feature can be found in PR huggingface#2353). I updated the corresponding code to be agnostic to the specific PEFT method (beforehand, it only worked for LoRA).

Note that model.merge_adapter(safe_merge=True) did not work so far: even though the argument was documented, it was not actually there. This is now fixed.
1 parent c375f99 commit a1ccc82
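For illustration, here is a minimal sketch of the user-facing behavior described in the commit message; the base model and the adapter config are placeholders, not part of this commit:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Base model manually cast to a lower precision dtype.
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.bfloat16)

# Default: adapter weights are promoted to float32; inputs are upcast for the
# adapter computation and the result is cast back to bfloat16.
model = get_peft_model(base, LoraConfig())

# Opt out of the float32 promotion and keep the adapter in bfloat16 instead.
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.bfloat16)
model = get_peft_model(base, LoraConfig(), autocast_adapter_dtype=False)

# merge_adapter(safe_merge=True) now works as documented and checks for NaNs,
# while conserving the base model's original dtype.
model.merge_adapter(safe_merge=True)
```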

File tree

19 files changed: +459, -153 lines


docs/source/developer_guides/troubleshooting.md (20 additions & 1 deletion)

@@ -39,7 +39,9 @@ Installing PEFT from source is useful for keeping up with the latest development
 python -m pip install git+https://github.com/huggingface/peft
 ```

-## ValueError: Attempting to unscale FP16 gradients
+## Dtype-related issues
+
+### ValueError: Attempting to unscale FP16 gradients

 This error probably occurred because the model was loaded with `torch_dtype=torch.float16` and then used in an automatic mixed precision (AMP) context, e.g. by setting `fp16=True` in the [`~transformers.Trainer`] class from 🤗 Transformers. The reason is that when using AMP, trainable weights should never use fp16. To make this work without loading the whole model in fp32, add the following to your code:

@@ -75,6 +77,23 @@ Starting from PEFT verion v0.12.0, PEFT automatically promotes the dtype of adap

 </Tip>

+### Selecting the dtype of the adapter
+
+Most PEFT methods, like LoRA, work by adding trainable adapter weights. By default, those weights are stored in float32 dtype (fp32), i.e. at a relatively high precision. Therefore, even if the base model is loaded in float16 (fp16) or bfloat16 (bf16), the adapter weights are float32. When the adapter results are calculated during the forward pass, the input will typically be in the dtype of the base model, thus it will be upcast to float32 if necessary, then cast back to the original dtype.
+
+If you prefer to have the adapter weights in the lower precision of the base model, i.e. in float16 or bfloat16, you can pass `autocast_adapter_dtype=False` when creating the model ([`~get_peft_model`]) or loading the model ([`~PeftModel.from_pretrained`]). There are some advantages and disadvantages to this:
+
+Advantages of half precision adapter:
+- computation slightly faster
+- slightly less memory
+- smaller file size of checkpoint (half the size)
+
+Disadvantages of half precision adapter:
+- slightly worse loss
+- higher risk of overflow or underflow
+
+Note that for most use cases, overall runtime and memory cost will be determined by the size of the base model and by the dataset, while the dtype of the PEFT adapter will only have a small impact.
+
 ## Bad results from a loaded PEFT model

 There can be several reasons for getting a poor result from a loaded PEFT model which are listed below. If you're still unable to troubleshoot the problem, see if anyone else had a similar [issue](https://github.com/huggingface/peft/issues) on GitHub, and if you can't find any, open a new issue.

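To make the new troubleshooting section concrete, here is a small sketch (the model name and the `lora_dtypes` helper are illustrative) that inspects the adapter dtype with and without `autocast_adapter_dtype`:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

def lora_dtypes(model):
    # Collect the distinct dtypes of the LoRA adapter parameters.
    return {p.dtype for n, p in model.named_parameters() if "lora_" in n}

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16)
model = get_peft_model(base, LoraConfig())
print(lora_dtypes(model))  # {torch.float32}: adapter promoted to fp32 by default

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16)
model = get_peft_model(base, LoraConfig(), autocast_adapter_dtype=False)
print(lora_dtypes(model))  # {torch.float16}: adapter kept in the base model's dtype
```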
src/peft/helpers.py (3 additions & 4 deletions)

@@ -22,6 +22,7 @@

 from .peft_model import PeftConfig, PeftModel
 from .tuners.lora import LoraLayer
+from .tuners.tuners_utils import BaseTunerLayer


 def update_forward_signature(model: PeftModel) -> None:
@@ -218,8 +219,6 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):
     """
     Context manager disables input dtype casting to the dtype of the weight.

-    Currently specifically works for LoRA.
-
     Parameters:
         model (nn.Module):
             The model containing PEFT modules whose input dtype casting is to be adjusted.
@@ -237,7 +236,7 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):

     original_values = {}
     for name, module in model.named_modules():
-        if not isinstance(module, LoraLayer):
+        if not isinstance(module, BaseTunerLayer):
             continue
         original_values[name] = module.cast_input_dtype_enabled
         module.cast_input_dtype_enabled = False
@@ -246,7 +245,7 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):
         yield
     finally:
         for name, module in model.named_modules():
-            if not isinstance(module, LoraLayer):
+            if not isinstance(module, BaseTunerLayer):
                 continue
            if name in original_values:
                module.cast_input_dtype_enabled = original_values[name]

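A usage sketch of the context manager under the new, method-agnostic behavior (the model setup is illustrative; the adapter is kept in the base dtype here so that disabling the casting does not change any dtypes):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from peft.helpers import disable_input_dtype_casting

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.bfloat16)
model = get_peft_model(base, LoraConfig(), autocast_adapter_dtype=False)

inputs = torch.tensor([[1, 2, 3]])
# Inside the context manager, inputs are no longer cast to the dtype of the adapter
# weight. After this commit, this applies to any layer deriving from BaseTunerLayer,
# not just LoRA layers. Typically used when something else (e.g. a forward hook)
# already takes care of dtype handling.
with disable_input_dtype_casting(model):
    output = model(inputs)
```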
src/peft/tuners/boft/layer.py (29 additions & 18 deletions)

@@ -220,6 +220,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
         # Mark the weight as unmerged
         self._disable_adapters = False
         self.merged_adapters = []
+        # flag to enable/disable casting of input to weight dtype during forward call
+        self.cast_input_dtype_enabled = True
         self.kwargs = kwargs

         base_layer = self.get_base_layer()
@@ -503,13 +505,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
         for active_adapter in adapter_names:
             if active_adapter in self.boft_R.keys():
                 base_layer = self.get_base_layer()
+                orig_dtype = base_layer.weight.dtype
                 if safe_merge:
                     # Note that safe_merge will be slower than the normal merge
                     # because of the copy operation.
                     orig_weight = base_layer.weight.data.clone()
                     butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)
                     orig_weight = torch.transpose(orig_weight, 0, 1)
-                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
+                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
                     orig_weight = torch.transpose(orig_weight, 0, 1)
                     orig_weight = orig_weight * boft_s

@@ -518,16 +521,16 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                             f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                         )

-                    self.base_layer.weight.data = orig_weight.contiguous()
+                    self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)
                 else:
                     butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)
                     orig_weight = base_layer.weight.data.clone()
                     orig_weight = torch.transpose(orig_weight, 0, 1)
-                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
+                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
                     orig_weight = torch.transpose(orig_weight, 0, 1)
                     orig_weight = orig_weight * boft_s

-                    self.base_layer.weight.data = orig_weight.contiguous()
+                    self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)

                 self.merged_adapters.append(active_adapter)

@@ -538,17 +541,20 @@ def unmerge(self) -> None:
         if not self.merged:
             warnings.warn("Already unmerged. Nothing to do.")
             return
+
         while len(self.merged_adapters) > 0:
             active_adapter = self.merged_adapters.pop()
+            base_layer = self.get_base_layer()
+            orig_dtype = base_layer.weight.dtype
             if active_adapter in self.boft_R.keys():
                 butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)

-                orig_weight = self.get_base_layer().weight.data.clone()
+                orig_weight = base_layer.weight.data.clone()
                 orig_weight = torch.transpose(orig_weight, 0, 1)
-                orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight)
+                orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight.to(butterfly_oft_mat.dtype))
                 orig_weight = torch.transpose(orig_weight, 0, 1)

-                self.get_base_layer().weight.data = orig_weight * (1 / boft_s)
+                base_layer.weight.data = (orig_weight * (1 / boft_s)).to(orig_dtype)

     def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
         """
@@ -804,6 +810,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
         for active_adapter in adapter_names:
             if active_adapter in self.boft_R.keys():
                 base_layer = self.get_base_layer()
+                orig_dtype = base_layer.weight.dtype
                 if safe_merge:
                     # Note that safe_merge will be slower than the normal merge
                     # because of the copy operation.
@@ -814,14 +821,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                         self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
                     )
                     orig_weight = torch.transpose(orig_weight, 0, 1)
-                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
+                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
                     orig_weight = torch.transpose(orig_weight, 0, 1)
                     orig_weight = orig_weight * boft_s
                     orig_weight = orig_weight.view(
                         self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
                     )

-                    self.base_layer.weight.data = orig_weight.contiguous()
+                    self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)
                 else:
                     butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)

@@ -830,14 +837,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
                         self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
                     )
                     orig_weight = torch.transpose(orig_weight, 0, 1)
-                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
+                    orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
                     orig_weight = torch.transpose(orig_weight, 0, 1)
                     orig_weight = orig_weight * boft_s
                     orig_weight = orig_weight.view(
                         self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
                     )

-                    self.base_layer.weight.data = orig_weight.contiguous()
+                    self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)

                 self.merged_adapters.append(active_adapter)

@@ -850,26 +857,28 @@ def unmerge(self) -> None:
             return
         while len(self.merged_adapters) > 0:
             active_adapter = self.merged_adapters.pop()
+            base_layer = self.get_base_layer()
+            orig_dtype = base_layer.weight.dtype
             if active_adapter in self.boft_R.keys():
                 butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)

-                orig_weight = self.get_base_layer().weight.data.clone()
+                orig_weight = base_layer.weight.data.clone()
                 orig_weight = orig_weight.view(
                     self.out_features,
-                    self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0],
+                    self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0],
                 )
                 orig_weight = torch.transpose(orig_weight, 0, 1)
-                orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight)
+                orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight.to(butterfly_oft_mat.dtype))
                 orig_weight = torch.transpose(orig_weight, 0, 1)
                 orig_weight = orig_weight * (1 / boft_s)
                 orig_weight = orig_weight.view(
                     self.out_features,
                     self.in_features,
-                    self.get_base_layer().kernel_size[0],
-                    self.get_base_layer().kernel_size[0],
+                    base_layer.kernel_size[0],
+                    base_layer.kernel_size[0],
                 )

-                self.get_base_layer().weight.data = orig_weight
+                base_layer.weight.data = orig_weight.to(orig_dtype)

     def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
         """
@@ -968,10 +977,12 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             scaled_rotated_weight = scaled_rotated_weight.view(
                 self.out_features, self.in_features, self.base_layer.kernel_size[0], self.base_layer.kernel_size[0]
             )
+            x = self._cast_input_dtype(x, scaled_rotated_weight.dtype)
+            bias = self._cast_input_dtype(self.base_layer.bias, scaled_rotated_weight.dtype)
             result = F.conv2d(
                 input=x,
                 weight=scaled_rotated_weight,
-                bias=self.base_layer.bias,
+                bias=bias,
                 padding=self.base_layer.padding[0],
                 stride=self.base_layer.stride[0],
             )

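The recurring pattern in these merge/unmerge changes: remember the base weight's dtype, perform the matmul in the (typically higher) precision of the adapter, and cast the merged weight back. A standalone sketch of just that dtype handling (the `merge_delta` function is illustrative, not the BOFT code itself):

```python
import torch

def merge_delta(base_weight: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
    # base_weight may be float16/bfloat16; delta (e.g. an orthogonal rotation or
    # low-rank update) is typically float32.
    orig_dtype = base_weight.dtype
    merged = delta @ base_weight.to(delta.dtype)  # compute in the higher precision
    return merged.to(orig_dtype)                  # conserve the base model's dtype

base_weight = torch.randn(4, 4, dtype=torch.bfloat16)
delta = torch.eye(4, dtype=torch.float32)  # identity rotation leaves the weight unchanged
merged = merge_delta(base_weight, delta)
assert merged.dtype == torch.bfloat16
```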
src/peft/tuners/bone/layer.py (17 additions & 5 deletions)

@@ -35,6 +35,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
         # Mark the weight as unmerged
         self._disable_adapters = False
         self.merged_adapters = []
+        # flag to enable/disable casting of input to weight dtype during forward call
+        self.cast_input_dtype_enabled = True
         self.kwargs = kwargs

         base_layer = self.get_base_layer()
@@ -150,6 +152,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
         for active_adapter in adapter_names:
             if active_adapter in self.bone_block.keys():
                 base_layer = self.get_base_layer()
+                orig_dtype = base_layer.weight.dtype
                 if safe_merge:
                     # Note that safe_merge will be slower than the normal merge
                     # because of the copy operation.
@@ -166,14 +169,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
                             f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                         )

-                    self.base_layer.weight.data = orig_weight
+                    base_layer.weight.data = orig_weight.to(orig_dtype)
                 else:
                     if self.bone_fn == "bat":
                         delta_weight = self.get_delta_weight(active_adapter, self.base_layer.weight.data)
-                        self.base_layer.weight.data += delta_weight
+                        base_layer.weight.data += delta_weight.to(orig_dtype)
                     else:
                         delta_weight = self.get_delta_weight_bone(active_adapter, self.base_layer.weight.data)
-                        self.base_layer.weight.data = delta_weight
+                        base_layer.weight.data = delta_weight.to(orig_dtype)
                 self.merged_adapters.append(active_adapter)

     def unmerge(self) -> None:
@@ -183,16 +186,19 @@ def unmerge(self) -> None:
         if not self.merged:
             warnings.warn("Already unmerged. Nothing to do.")
             return
+
         while len(self.merged_adapters) > 0:
             active_adapter = self.merged_adapters.pop()
+            base_layer = self.get_base_layer()
+            orig_dtype = base_layer.weight.dtype
             if active_adapter in self.bone_block.keys():
                 orig_weight = self.get_base_layer().weight.data.clone()
                 if self.bone_fn == "bat":
                     delta_weight = self.get_delta_weight(active_adapter, orig_weight, re=True)
                 else:
                     delta_weight = self.get_delta_weight_bone(active_adapter, orig_weight, re=True)

-                self.get_base_layer().weight.data = delta_weight
+                base_layer.weight.data = delta_weight.to(orig_dtype)

     def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
         """
@@ -213,12 +219,15 @@ def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tens

         if cast_to_fp32:
             weight_bone = weight_bone.float()
+            orig_weight = orig_weight.to(weight_bone.dtype)

         r = weight_bone.size(-1)
         if re:
             o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
             one = torch.eye(weight_bone.size(-1)).to(weight_bone.device)
+            # inverse must be in float32, after that the dtype can be adjusted if needed
             inv_I_plus_b = torch.inverse(one + weight_bone)
+            inv_I_plus_b = inv_I_plus_b.to(weight_bone.dtype)
             w = (o - weight_bone) @ inv_I_plus_b
             output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
         else:
@@ -318,7 +327,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                     delta_weight = self.get_delta_weight(active_adapter, orig_weight)
                     orig_weight = orig_weight + delta_weight

-                result = F.linear(input=x, weight=orig_weight, bias=self.base_layer.bias)
+                x = self._cast_input_dtype(x, orig_weight.dtype)
+                bias = self._cast_input_dtype(self.base_layer.bias, orig_weight.dtype)
+                result = F.linear(input=x, weight=orig_weight, bias=bias)
             else:
                 result = self.base_layer(x, *args, **kwargs)
                 for active_adapter in self.active_adapters:
@@ -329,6 +340,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                     if x.size(-1) % r != 0:
                         padding_size = (r - x.size(-1) % r) % r
                         x = F.pad(x, (0, padding_size))
+                    x = self._cast_input_dtype(x, bone.dtype)
                     result = result + torch.sum(x.reshape(*x.shape[:-1], x.size(-1) // r, r), dim=-2) @ bone

             result = result.to(previous_dtype)

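In the forward paths, the new `_cast_input_dtype` calls apply the same idea to activations: when casting is enabled, the input (and bias) are brought to the adapter weight's dtype before the op, and the result is cast back to the original dtype at the end. A simplified, self-contained sketch of that logic (`adapter_linear` is illustrative, not the PEFT helper itself):

```python
import torch
import torch.nn.functional as F

def adapter_linear(x, adapter_weight, bias=None, cast_input_dtype_enabled=True):
    # Simplified version of the input-dtype handling used in the adapted forward methods.
    previous_dtype = x.dtype
    if cast_input_dtype_enabled and x.dtype != adapter_weight.dtype:
        x = x.to(adapter_weight.dtype)  # e.g. bfloat16 input -> float32 adapter weight
        if bias is not None:
            bias = bias.to(adapter_weight.dtype)
    result = F.linear(x, adapter_weight, bias)
    return result.to(previous_dtype)    # hand back the base model's dtype

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.randn(8, 8, dtype=torch.float32)
print(adapter_linear(x, w).dtype)  # torch.bfloat16
```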
src/peft/tuners/fourierft/layer.py (4 additions & 3 deletions)

@@ -84,12 +84,13 @@ def reset_fourier_parameters(self, adapter_name):
         nn.init.zeros_(self.fourierft_spectrum[adapter_name])

     def get_delta_weight(self, adapter) -> torch.Tensor:
+        # careful: ifft2 does not work with float16 or bfloat16
         spectrum = self.fourierft_spectrum[adapter]
         indices = self.indices[adapter].to(spectrum.device)
-        dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device, dtype=spectrum.dtype)
-        dense_spectrum[indices[0, :], indices[1, :]] = spectrum
+        dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device)
+        dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
         delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter]
-        return delta_weight
+        return delta_weight.to(spectrum.dtype)


 class FourierFTLinear(nn.Module, FourierFTLayer):

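The FourierFT change follows the same principle for an op that only supports full precision: the inverse FFT is evaluated on a float32 tensor and only the final delta weight is cast back to the spectrum's dtype. A tiny sketch of that compute-in-fp32 pattern (`fp32_only_delta`, the shape, and the scaling are placeholders):

```python
import torch

def fp32_only_delta(spectrum: torch.Tensor, scaling: float = 1.0) -> torch.Tensor:
    # spectrum may be float16/bfloat16; ifft2 is evaluated in float32 and the
    # result is cast back so callers keep working in the original dtype.
    delta = torch.fft.ifft2(spectrum.float()).real * scaling
    return delta.to(spectrum.dtype)

spectrum = torch.randn(16, 16, dtype=torch.bfloat16)
print(fp32_only_delta(spectrum).dtype)  # torch.bfloat16
```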