21 changes: 20 additions & 1 deletion docs/source/developer_guides/troubleshooting.md
@@ -39,7 +39,9 @@ Installing PEFT from source is useful for keeping up with the latest development
python -m pip install git+https://github.com/huggingface/peft
```

## ValueError: Attempting to unscale FP16 gradients
## Dtype-related issues

### ValueError: Attempting to unscale FP16 gradients

This error probably occurred because the model was loaded with `torch_dtype=torch.float16` and then used in an automatic mixed precision (AMP) context, e.g. by setting `fp16=True` in the [`~transformers.Trainer`] class from 🤗 Transformers. The reason is that when using AMP, trainable weights should never use fp16. To make this work without loading the whole model in fp32, add the following to your code:
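(The snippet itself is collapsed in this diff view. As a rough sketch of the fix described above — keeping the base model in fp16 but upcasting only the trainable adapter parameters to float32 — it could look like the following; `peft_model` is a placeholder name, not necessarily the exact code from the docs:)

```python
import torch

# assumption: `peft_model` was created with get_peft_model(...) on a model loaded in float16
for param in peft_model.parameters():
    # AMP grad unscaling requires trainable weights in float32; frozen weights can stay in fp16
    if param.requires_grad:
        param.data = param.data.to(torch.float32)
```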

@@ -75,6 +77,23 @@ Starting from PEFT version v0.12.0, PEFT automatically promotes the dtype of adap

</Tip>

### Selecting the dtype of the adapter

Most PEFT methods, like LoRA, work by adding trainable adapter weights. By default, those weights are stored in float32 (fp32), i.e. at a relatively high precision. Therefore, even if the base model is loaded in float16 (fp16) or bfloat16 (bf16), the adapter weights are float32. During the forward pass, the input is typically in the dtype of the base model, so it is upcast to float32 if necessary before the adapter is applied, and the adapter output is cast back to the original dtype.

If you prefer to have the adapter weights in the lower precision of the base model, i.e. in float16 or bfloat16, you can pass `autocast_adapter_dtype=False` when creating the model ([`~get_peft_model`]) or loading the model ([`~PeftModel.from_pretrained`]). There are some advantages and disadvantages to this:

Advantages of a half-precision adapter:
- slightly faster computation
- slightly lower memory usage
- smaller checkpoint files (half the size)

Disadvantages of a half-precision adapter:
- slightly worse loss
- higher risk of overflow or underflow

Note that for most use cases, overall runtime and memory cost will be determined by the size of the base model and by the dataset, while the dtype of the PEFT adapter will only have a small impact.
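As a minimal sketch of how this looks in code (the model id and LoRA settings below are placeholder assumptions, not part of this PR):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# assumption: any causal LM works here; "facebook/opt-125m" is only an example
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.bfloat16)
lora_config = LoraConfig(task_type="CAUSAL_LM")

# default (autocast_adapter_dtype=True): adapter weights are promoted to float32
# with autocast_adapter_dtype=False, they stay in the base model's bfloat16
peft_model = get_peft_model(base_model, lora_config, autocast_adapter_dtype=False)
```

The same argument can be passed when loading a trained adapter, e.g. `PeftModel.from_pretrained(base_model, <adapter_path>, autocast_adapter_dtype=False)`.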

## Bad results from a loaded PEFT model

There can be several reasons for getting a poor result from a loaded PEFT model which are listed below. If you're still unable to troubleshoot the problem, see if anyone else had a similar [issue](https://github.com/huggingface/peft/issues) on GitHub, and if you can't find any, open a new issue.
7 changes: 3 additions & 4 deletions src/peft/helpers.py
@@ -22,6 +22,7 @@

from .peft_model import PeftConfig, PeftModel
from .tuners.lora import LoraLayer
from .tuners.tuners_utils import BaseTunerLayer


def update_forward_signature(model: PeftModel) -> None:
@@ -218,8 +219,6 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):
"""
Context manager disables input dtype casting to the dtype of the weight.

Currently specifically works for LoRA.

Parameters:
model (nn.Module):
The model containing PEFT modules whose input dtype casting is to be adjusted.
@@ -237,7 +236,7 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):

original_values = {}
for name, module in model.named_modules():
if not isinstance(module, LoraLayer):
if not isinstance(module, BaseTunerLayer):
continue
original_values[name] = module.cast_input_dtype_enabled
module.cast_input_dtype_enabled = False
@@ -246,7 +245,7 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):
yield
finally:
for name, module in model.named_modules():
if not isinstance(module, LoraLayer):
if not isinstance(module, BaseTunerLayer):
continue
if name in original_values:
module.cast_input_dtype_enabled = original_values[name]
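For context, a minimal usage sketch of the context manager adjusted above (the model and input names are placeholders); with this change it applies to any `BaseTunerLayer`, not only LoRA layers:

```python
from peft.helpers import disable_input_dtype_casting

# assumption: `peft_model` is a PEFT model and `inputs` is a dict of model inputs
with disable_input_dtype_casting(peft_model):
    # inside the block, PEFT layers no longer cast the input to the weight dtype
    outputs = peft_model(**inputs)
```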
47 changes: 29 additions & 18 deletions src/peft/tuners/boft/layer.py
@@ -220,6 +220,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
# flag to enable/disable casting of input to weight dtype during forward call
self.cast_input_dtype_enabled = True
self.kwargs = kwargs

base_layer = self.get_base_layer()
@@ -503,13 +505,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
for active_adapter in adapter_names:
if active_adapter in self.boft_R.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weight = base_layer.weight.data.clone()
butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * boft_s

@@ -518,16 +521,16 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

self.base_layer.weight.data = orig_weight.contiguous()
self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)
else:
butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)
orig_weight = base_layer.weight.data.clone()
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * boft_s

self.base_layer.weight.data = orig_weight.contiguous()
self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)

self.merged_adapters.append(active_adapter)

@@ -538,17 +541,20 @@ def unmerge(self) -> None:
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return

while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if active_adapter in self.boft_R.keys():
butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)

orig_weight = self.get_base_layer().weight.data.clone()
orig_weight = base_layer.weight.data.clone()
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight)
orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight.to(butterfly_oft_mat.dtype))
orig_weight = torch.transpose(orig_weight, 0, 1)

self.get_base_layer().weight.data = orig_weight * (1 / boft_s)
base_layer.weight.data = (orig_weight * (1 / boft_s)).to(orig_dtype)

def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -804,6 +810,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
for active_adapter in adapter_names:
if active_adapter in self.boft_R.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
@@ -814,14 +821,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * boft_s
orig_weight = orig_weight.view(
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
)

self.base_layer.weight.data = orig_weight.contiguous()
self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)
else:
butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)

@@ -830,14 +837,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * boft_s
orig_weight = orig_weight.view(
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
)

self.base_layer.weight.data = orig_weight.contiguous()
self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)

self.merged_adapters.append(active_adapter)

@@ -850,26 +857,28 @@ def unmerge(self) -> None:
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if active_adapter in self.boft_R.keys():
butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)

orig_weight = self.get_base_layer().weight.data.clone()
orig_weight = base_layer.weight.data.clone()
orig_weight = orig_weight.view(
self.out_features,
self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0],
self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0],
)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight)
orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight.to(butterfly_oft_mat.dtype))
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * (1 / boft_s)
orig_weight = orig_weight.view(
self.out_features,
self.in_features,
self.get_base_layer().kernel_size[0],
self.get_base_layer().kernel_size[0],
base_layer.kernel_size[0],
base_layer.kernel_size[0],
)

self.get_base_layer().weight.data = orig_weight
base_layer.weight.data = orig_weight.to(orig_dtype)

def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -968,10 +977,12 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
scaled_rotated_weight = scaled_rotated_weight.view(
self.out_features, self.in_features, self.base_layer.kernel_size[0], self.base_layer.kernel_size[0]
)
x = self._cast_input_dtype(x, scaled_rotated_weight.dtype)
bias = self._cast_input_dtype(self.base_layer.bias, scaled_rotated_weight.dtype)
result = F.conv2d(
input=x,
weight=scaled_rotated_weight,
bias=self.base_layer.bias,
bias=bias,
padding=self.base_layer.padding[0],
stride=self.base_layer.stride[0],
)
22 changes: 17 additions & 5 deletions src/peft/tuners/bone/layer.py
@@ -35,6 +35,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
# flag to enable/disable casting of input to weight dtype during forward call
self.cast_input_dtype_enabled = True
self.kwargs = kwargs

base_layer = self.get_base_layer()
@@ -150,6 +152,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
for active_adapter in adapter_names:
if active_adapter in self.bone_block.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
@@ -166,14 +169,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

self.base_layer.weight.data = orig_weight
base_layer.weight.data = orig_weight.to(orig_dtype)
else:
if self.bone_fn == "bat":
delta_weight = self.get_delta_weight(active_adapter, self.base_layer.weight.data)
self.base_layer.weight.data += delta_weight
base_layer.weight.data += delta_weight.to(orig_dtype)
else:
delta_weight = self.get_delta_weight_bone(active_adapter, self.base_layer.weight.data)
self.base_layer.weight.data = delta_weight
base_layer.weight.data = delta_weight.to(orig_dtype)
self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:
@@ -183,16 +186,19 @@ def unmerge(self) -> None:
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return

while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if active_adapter in self.bone_block.keys():
orig_weight = self.get_base_layer().weight.data.clone()
if self.bone_fn == "bat":
delta_weight = self.get_delta_weight(active_adapter, orig_weight, re=True)
else:
delta_weight = self.get_delta_weight_bone(active_adapter, orig_weight, re=True)

self.get_base_layer().weight.data = delta_weight
base_layer.weight.data = delta_weight.to(orig_dtype)

def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
"""
@@ -213,12 +219,15 @@ def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tens

if cast_to_fp32:
weight_bone = weight_bone.float()
orig_weight = orig_weight.to(weight_bone.dtype)

r = weight_bone.size(-1)
if re:
o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
one = torch.eye(weight_bone.size(-1)).to(weight_bone.device)
# inverse must be in float32, after that the dtype can be adjusted if needed
inv_I_plus_b = torch.inverse(one + weight_bone)
inv_I_plus_b = inv_I_plus_b.to(weight_bone.dtype)
w = (o - weight_bone) @ inv_I_plus_b
output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
else:
@@ -318,7 +327,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
delta_weight = self.get_delta_weight(active_adapter, orig_weight)
orig_weight = orig_weight + delta_weight

result = F.linear(input=x, weight=orig_weight, bias=self.base_layer.bias)
x = self._cast_input_dtype(x, orig_weight.dtype)
bias = self._cast_input_dtype(self.base_layer.bias, orig_weight.dtype)
result = F.linear(input=x, weight=orig_weight, bias=bias)
else:
result = self.base_layer(x, *args, **kwargs)
for active_adapter in self.active_adapters:
Expand All @@ -329,6 +340,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
if x.size(-1) % r != 0:
padding_size = (r - x.size(-1) % r) % r
x = F.pad(x, (0, padding_size))
x = self._cast_input_dtype(x, bone.dtype)
result = result + torch.sum(x.reshape(*x.shape[:-1], x.size(-1) // r, r), dim=-2) @ bone

result = result.to(previous_dtype)
7 changes: 4 additions & 3 deletions src/peft/tuners/fourierft/layer.py
@@ -84,12 +84,13 @@ def reset_fourier_parameters(self, adapter_name):
nn.init.zeros_(self.fourierft_spectrum[adapter_name])

def get_delta_weight(self, adapter) -> torch.Tensor:
# careful: ifft2 does not work with float16 or bfloat16
spectrum = self.fourierft_spectrum[adapter]
indices = self.indices[adapter].to(spectrum.device)
dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device, dtype=spectrum.dtype)
dense_spectrum[indices[0, :], indices[1, :]] = spectrum
dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device)
dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter]
return delta_weight
return delta_weight.to(spectrum.dtype)


class FourierFTLinear(nn.Module, FourierFTLayer):