Merged. Changes from 9 commits.
src/peft/helpers.py (3 additions & 4 deletions)
@@ -22,6 +22,7 @@

from .peft_model import PeftConfig, PeftModel
from .tuners.lora import LoraLayer
from .tuners.tuners_utils import BaseTunerLayer


def update_forward_signature(model: PeftModel) -> None:
@@ -218,8 +219,6 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):
"""
Context manager disables input dtype casting to the dtype of the weight.

Currently specifically works for LoRA.

Parameters:
model (nn.Module):
The model containing PEFT modules whose input dtype casting is to be adjusted.
@@ -237,7 +236,7 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):

original_values = {}
for name, module in model.named_modules():
if not isinstance(module, LoraLayer):
if not isinstance(module, BaseTunerLayer):
continue
original_values[name] = module.cast_input_dtype_enabled
module.cast_input_dtype_enabled = False
@@ -246,7 +245,7 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):
yield
finally:
for name, module in model.named_modules():
if not isinstance(module, LoraLayer):
if not isinstance(module, BaseTunerLayer):
continue
if name in original_values:
module.cast_input_dtype_enabled = original_values[name]
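
For context, a minimal usage sketch of the context manager touched above (model and inputs are placeholder names, not part of this PR); after this change it toggles cast_input_dtype_enabled on every BaseTunerLayer subclass instead of only on LoraLayer:

from peft.helpers import disable_input_dtype_casting

# Inside the block, PEFT layers skip casting inputs to the weight dtype,
# so the caller is responsible for keeping input and weight dtypes aligned.
with disable_input_dtype_casting(model):
    output = model(**inputs)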
src/peft/tuners/boft/layer.py (29 additions & 18 deletions)
@@ -220,6 +220,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
# flag to enable/disable casting of input to weight dtype during forward call
self.cast_input_dtype_enabled = True
self.kwargs = kwargs

base_layer = self.get_base_layer()
@@ -503,13 +505,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
for active_adapter in adapter_names:
if active_adapter in self.boft_R.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weight = base_layer.weight.data.clone()
butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * boft_s

@@ -518,16 +521,16 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

self.base_layer.weight.data = orig_weight.contiguous()
self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)
else:
butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)
orig_weight = base_layer.weight.data.clone()
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * boft_s

self.base_layer.weight.data = orig_weight.contiguous()
self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)

self.merged_adapters.append(active_adapter)

@@ -538,17 +541,20 @@ def unmerge(self) -> None:
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return

while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if active_adapter in self.boft_R.keys():
butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)

orig_weight = self.get_base_layer().weight.data.clone()
orig_weight = base_layer.weight.data.clone()
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight)
orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight.to(butterfly_oft_mat.dtype))
orig_weight = torch.transpose(orig_weight, 0, 1)

self.get_base_layer().weight.data = orig_weight * (1 / boft_s)
base_layer.weight.data = (orig_weight * (1 / boft_s)).to(orig_dtype)

def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -804,6 +810,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
for active_adapter in adapter_names:
if active_adapter in self.boft_R.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
@@ -814,14 +821,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * boft_s
orig_weight = orig_weight.view(
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
)

self.base_layer.weight.data = orig_weight.contiguous()
self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)
else:
butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)

@@ -830,14 +837,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype))
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * boft_s
orig_weight = orig_weight.view(
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
)

self.base_layer.weight.data = orig_weight.contiguous()
self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)

self.merged_adapters.append(active_adapter)

@@ -850,26 +857,28 @@ def unmerge(self) -> None:
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if active_adapter in self.boft_R.keys():
butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)

orig_weight = self.get_base_layer().weight.data.clone()
orig_weight = base_layer.weight.data.clone()
orig_weight = orig_weight.view(
self.out_features,
self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0],
self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0],
)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight)
orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight.to(butterfly_oft_mat.dtype))
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * (1 / boft_s)
orig_weight = orig_weight.view(
self.out_features,
self.in_features,
self.get_base_layer().kernel_size[0],
self.get_base_layer().kernel_size[0],
base_layer.kernel_size[0],
base_layer.kernel_size[0],
)

self.get_base_layer().weight.data = orig_weight
self.get_base_layer().weight.data = orig_weight.to(orig_dtype)
Collaborator suggested change:
-                    self.get_base_layer().weight.data = orig_weight.to(orig_dtype)
+                    base_layer.weight.data = orig_weight.to(orig_dtype)


def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -968,10 +977,12 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
scaled_rotated_weight = scaled_rotated_weight.view(
self.out_features, self.in_features, self.base_layer.kernel_size[0], self.base_layer.kernel_size[0]
)
x = self._cast_input_dtype(x, scaled_rotated_weight.dtype)
bias = self._cast_input_dtype(self.base_layer.bias, scaled_rotated_weight.dtype)
result = F.conv2d(
input=x,
weight=scaled_rotated_weight,
bias=self.base_layer.bias,
bias=bias,
padding=self.base_layer.padding[0],
stride=self.base_layer.stride[0],
)
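The merge/unmerge changes above all follow the same pattern: remember the dtype of the base weight, run the matmul in the adapter's dtype, and cast the result back before writing it into weight.data. An illustrative sketch with stand-in tensors (not the actual BOFT buffers):

import torch

base_weight = torch.randn(16, 16, dtype=torch.bfloat16)  # stand-in for base_layer.weight.data
delta = torch.randn(16, 16, dtype=torch.float32)          # stand-in for the adapter's delta matrix

orig_dtype = base_weight.dtype
# compute in the adapter dtype; torch.mm would otherwise fail on mixed dtypes
merged = torch.mm(delta, base_weight.to(delta.dtype))
# cast back so the stored weight keeps its original dtype
base_weight = merged.contiguous().to(orig_dtype)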
src/peft/tuners/bone/layer.py (17 additions & 5 deletions)
@@ -35,6 +35,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
# flag to enable/disable casting of input to weight dtype during forward call
self.cast_input_dtype_enabled = True
self.kwargs = kwargs

base_layer = self.get_base_layer()
@@ -150,6 +152,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
for active_adapter in adapter_names:
if active_adapter in self.bone_block.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
@@ -166,14 +169,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

self.base_layer.weight.data = orig_weight
self.base_layer.weight.data = orig_weight.to(orig_dtype)
else:
if self.bone_fn == "bat":
delta_weight = self.get_delta_weight(active_adapter, self.base_layer.weight.data)
self.base_layer.weight.data += delta_weight
self.base_layer.weight.data += delta_weight.to(orig_dtype)
else:
delta_weight = self.get_delta_weight_bone(active_adapter, self.base_layer.weight.data)
self.base_layer.weight.data = delta_weight
self.base_layer.weight.data = delta_weight.to(orig_dtype)
Collaborator comment:
We can use base_layer here as well, right?

self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:
@@ -183,16 +186,19 @@ def unmerge(self) -> None:
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return

while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if active_adapter in self.bone_block.keys():
orig_weight = self.get_base_layer().weight.data.clone()
if self.bone_fn == "bat":
delta_weight = self.get_delta_weight(active_adapter, orig_weight, re=True)
else:
delta_weight = self.get_delta_weight_bone(active_adapter, orig_weight, re=True)

self.get_base_layer().weight.data = delta_weight
base_layer.weight.data = delta_weight.to(orig_dtype)

def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
"""
@@ -213,12 +219,15 @@ def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tens

if cast_to_fp32:
weight_bone = weight_bone.float()
orig_weight = orig_weight.to(weight_bone.dtype)

r = weight_bone.size(-1)
if re:
o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
one = torch.eye(weight_bone.size(-1)).to(weight_bone.device)
# inverse must be in float32, after that the dtype can be adjusted if needed
inv_I_plus_b = torch.inverse(one + weight_bone)
inv_I_plus_b = inv_I_plus_b.to(weight_bone.dtype)
w = (o - weight_bone) @ inv_I_plus_b
output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
else:
@@ -318,7 +327,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
delta_weight = self.get_delta_weight(active_adapter, orig_weight)
orig_weight = orig_weight + delta_weight

result = F.linear(input=x, weight=orig_weight, bias=self.base_layer.bias)
x = self._cast_input_dtype(x, orig_weight.dtype)
bias = self._cast_input_dtype(self.base_layer.bias, orig_weight.dtype)
result = F.linear(input=x, weight=orig_weight, bias=bias)
else:
result = self.base_layer(x, *args, **kwargs)
for active_adapter in self.active_adapters:
@@ -329,6 +340,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
if x.size(-1) % r != 0:
padding_size = (r - x.size(-1) % r) % r
x = F.pad(x, (0, padding_size))
x = self._cast_input_dtype(x, bone.dtype)
result = result + torch.sum(x.reshape(*x.shape[:-1], x.size(-1) // r, r), dim=-2) @ bone

result = result.to(previous_dtype)
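The forward changes in this file and in boft/layer.py route inputs through a _cast_input_dtype helper gated by the new cast_input_dtype_enabled flag. A hedged sketch of what such a helper could look like; the real method lives on BaseTunerLayer and may differ in detail:

from typing import Optional

import torch

def _cast_input_dtype(self, x: Optional[torch.Tensor], dtype: torch.dtype) -> Optional[torch.Tensor]:
    # passing None through is convenient for optional tensors such as a missing bias
    if x is None:
        return x
    # honor the flag toggled by disable_input_dtype_casting
    if not getattr(self, "cast_input_dtype_enabled", True) or x.dtype == dtype:
        return x
    return x.to(dtype)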
src/peft/tuners/fourierft/layer.py (4 additions & 3 deletions)
@@ -84,12 +84,13 @@ def reset_fourier_parameters(self, adapter_name):
nn.init.zeros_(self.fourierft_spectrum[adapter_name])

def get_delta_weight(self, adapter) -> torch.Tensor:
# careful: ifft2 does not work with float16 or bfloat16
spectrum = self.fourierft_spectrum[adapter]
indices = self.indices[adapter].to(spectrum.device)
dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device, dtype=spectrum.dtype)
dense_spectrum[indices[0, :], indices[1, :]] = spectrum
dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device)
dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter]
return delta_weight
return delta_weight.to(spectrum.dtype)


class FourierFTLinear(nn.Module, FourierFTLayer):
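The FourierFT change works around torch.fft.ifft2 not supporting float16/bfloat16 inputs: the dense spectrum is assembled in float32, and the resulting delta is cast back to the adapter dtype at the end. A stand-alone sketch of that round trip with dummy shapes:

import torch

spectrum = torch.randn(100, dtype=torch.bfloat16)  # stand-in for fourierft_spectrum[adapter]
indices = torch.randint(0, 64, (2, 100))            # stand-in frequency locations
scaling = 300.0                                     # stand-in for fourierft_scaling[adapter]

dense_spectrum = torch.zeros(64, 64)                # float32 by default
dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
delta_weight = torch.fft.ifft2(dense_spectrum).real * scaling
delta_weight = delta_weight.to(spectrum.dtype)      # back to the adapter dtype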