From 28b7fd4636da920b605786d7863d2505f59aaa71 Mon Sep 17 00:00:00 2001
From: Lucas Malo Belanger
Date: Sun, 21 Sep 2025 15:35:29 -0400
Subject: [PATCH 1/9] Add conv2d support for fourierft

---
 src/peft/tuners/fourierft/layer.py | 142 +++++++++++++++++++++++++++++
 src/peft/tuners/fourierft/model.py |  14 ++-
 2 files changed, 151 insertions(+), 5 deletions(-)

diff --git a/src/peft/tuners/fourierft/layer.py b/src/peft/tuners/fourierft/layer.py
index a03a57f118..c86d0a6344 100644
--- a/src/peft/tuners/fourierft/layer.py
+++ b/src/peft/tuners/fourierft/layer.py
@@ -48,6 +48,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
             self.in_features, self.out_features = (
                 base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
             )
+        elif isinstance(base_layer, nn.Conv2d):
+            pass
         else:
             raise ValueError(f"Unsupported layer type {type(base_layer)}")
@@ -191,3 +193,143 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
     def __repr__(self) -> str:
         rep = super().__repr__()
         return "fourierft." + rep
+
+class FourierFTConv2D(nn.Module, FourierFTLayer):
+    # FourierFT implemented in a 2D convolutional layer
+    def __init__(
+        self,
+        base_layer,
+        adapter_name: str,
+        n_frequency: int = 1000,
+        scaling: float = 150.0,
+        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        init_weights: Union[bool, str] = False,
+        random_loc_seed: int = 777,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        FourierFTLayer.__init__(self, base_layer, **kwargs)
+        self.fan_in_fan_out = fan_in_fan_out
+        self._active_adapter = adapter_name
+        self.in_features = base_layer.in_channels
+        self.out_features = base_layer.out_channels
+        self.kW = base_layer.kernel_size[0]
+        self.kH = base_layer.kernel_size[1]
+        self.stride = base_layer.stride
+        self.padding = base_layer.padding
+        self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed)
+
+
+    def update_layer(
+        self, adapter_name, n_frequency, scaling, init_weights, random_loc_seed, inference_mode: bool = False, **kwargs
+    ):
+        if n_frequency <= 0:
+            raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}")
+        if n_frequency > self.in_features * self.out_features:
+            raise ValueError(
+                f"`n_frequency` should be less than or equal to the product of the input and output dimensions "
+                f"but the value passed is {n_frequency} and the product is {self.in_features * self.out_features}"
+            )
+        self.fourierft_n_frequency[adapter_name] = n_frequency
+        self.fourierft_random_loc_seed[adapter_name] = random_loc_seed
+        self.indices[adapter_name] = torch.randperm(
+            self.out_features * self.in_features,
+            generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]),
+        )[:n_frequency]
+        self.indices[adapter_name] = torch.stack(
+            [self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0
+        )
+        self.fourierft_scaling[adapter_name] = scaling
+        # Actual trainable parameters
+        self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency, self.kW, self.kH), requires_grad=True)
+
+        if init_weights:
+            self.reset_fourier_parameters(adapter_name)
+
+        self._move_adapter_to_device_of_base_layer(adapter_name)
+        self.set_adapter(self.active_adapters, inference_mode=inference_mode)
+
+    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
+        """
+        Merge the active adapter weights into the base
weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.fourierft_spectrum.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weights = base_layer.weight.data.clone() + orig_weights += self.get_delta_weight(active_adapter) + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weights + else: + base_layer.weight.data += self.get_delta_weight(active_adapter) + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.fourierft_spectrum.keys(): + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + + def get_delta_weight(self, adapter) -> torch.Tensor: + # careful: ifft2 does not work with float16 or bfloat16 + spectrum = self.fourierft_spectrum[adapter] + indices = self.indices[adapter].to(spectrum.device) + dense_spectrum = torch.zeros(self.out_features, self.in_features, self.kW, self.kH, device=spectrum.device) + dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float() + delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter] + return delta_weight + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.fourierft_spectrum.keys(): + continue + + delta_w = self.get_delta_weight(active_adapter) + x = x.to(delta_w.dtype) + y = F.conv2d(x, delta_w, stride=self.stride, padding=self.padding) + result += y + + result = result.to(previous_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "fourierft." 
+ rep \ No newline at end of file diff --git a/src/peft/tuners/fourierft/model.py b/src/peft/tuners/fourierft/model.py index 5347d90b17..4c45f362b4 100644 --- a/src/peft/tuners/fourierft/model.py +++ b/src/peft/tuners/fourierft/model.py @@ -19,13 +19,15 @@ import torch from transformers.pytorch_utils import Conv1D +from torch.nn import Conv2d from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer from peft.utils import ( TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING, ) -from .layer import FourierFTLayer, FourierFTLinear +from .config import FourierFTConfig +from .layer import FourierFTLayer, FourierFTLinear, FourierFTConv2D class FourierFTModel(BaseTuner): @@ -110,6 +112,7 @@ def _create_new_module(fourierft_config, adapter_name, target, **kwargs): "Setting fan_in_fan_out to False." ) kwargs["fan_in_fan_out"] = fourierft_config.fan_in_fan_out = False + new_module = FourierFTLinear(target, adapter_name, **kwargs) elif isinstance(target_base_layer, Conv1D): kwargs["is_target_conv_1d_layer"] = True if not kwargs["fan_in_fan_out"]: @@ -117,12 +120,13 @@ def _create_new_module(fourierft_config, adapter_name, target, **kwargs): "fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True." ) kwargs["fan_in_fan_out"] = fourierft_config.fan_in_fan_out = True + new_module = FourierFTLinear(target, adapter_name, **kwargs) + elif isinstance(target_base_layer, Conv2d): + new_module = FourierFTConv2D(target, adapter_name, **kwargs) else: raise ValueError( f"Target module {target} is not supported. Currently, only the following modules are supported: " - "`torch.nn.Linear`." + "`torch.nn.Linear`" + "`torch.nn.Conv2d`" ) - - new_module = FourierFTLinear(target, adapter_name, **kwargs) - return new_module From bf4dce5dd585ad56607885b945df84f4b4dd161a Mon Sep 17 00:00:00 2001 From: Lucas Malo Belanger Date: Mon, 22 Sep 2025 14:33:00 -0400 Subject: [PATCH 2/9] FourierFT: add alpha for dynamic n_frequency --- src/peft/tuners/fourierft/config.py | 9 +++++++++ src/peft/tuners/fourierft/layer.py | 23 +++++++++++++++++------ src/peft/tuners/fourierft/model.py | 2 ++ 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/peft/tuners/fourierft/config.py b/src/peft/tuners/fourierft/config.py index dbbb80d8e0..ba60af08d6 100644 --- a/src/peft/tuners/fourierft/config.py +++ b/src/peft/tuners/fourierft/config.py @@ -185,6 +185,15 @@ class FourierFTConfig(PeftConfig): }, ) + alpha: float = field( + default=None, + metadata={ + "help": ( + "The alpha value dynamically sets the n_frequency = int(alpha * out_features * in_features)" + ) + }, + ) + def __post_init__(self): super().__post_init__() self.peft_type = PeftType.FOURIERFT diff --git a/src/peft/tuners/fourierft/layer.py b/src/peft/tuners/fourierft/layer.py index c86d0a6344..935ed7e6b2 100644 --- a/src/peft/tuners/fourierft/layer.py +++ b/src/peft/tuners/fourierft/layer.py @@ -29,7 +29,7 @@ class FourierFTLayer(BaseTunerLayer): # All names of other parameters that may contain adapter-related parameters other_param_names = ("fourierft_n_frequency", "fourierft_scaling", "fourierft_random_loc_seed") - def __init__(self, base_layer: nn.Module, **kwargs) -> None: + def __init__(self, base_layer: nn.Module, alpha, **kwargs) -> None: self.base_layer = base_layer self.fourierft_n_frequency = {} self.fourierft_scaling = {} @@ -49,7 +49,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape ) elif 
isinstance(base_layer, nn.Conv2d): - pass + self.in_features = base_layer.in_channels + self.out_features = base_layer.out_channels else: raise ValueError(f"Unsupported layer type {type(base_layer)}") @@ -104,6 +105,7 @@ def __init__( base_layer, adapter_name: str, n_frequency: int = 1000, + alpha: float = None, scaling: float = 150.0, fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) init_weights: Union[bool, str] = False, @@ -111,7 +113,12 @@ def __init__( **kwargs, ) -> None: super().__init__() - FourierFTLayer.__init__(self, base_layer, **kwargs) + FourierFTLayer.__init__(self, base_layer, alpha, **kwargs) + + # apply alpha patch + if alpha: + n_frequency = int(alpha * self.in_features * self.out_features) + self.fan_in_fan_out = fan_in_fan_out self._active_adapter = adapter_name self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed) @@ -201,6 +208,7 @@ def __init__( base_layer, adapter_name: str, n_frequency: int = 1000, + alpha: float = None, scaling: float = 150.0, fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) init_weights: Union[bool, str] = False, @@ -208,11 +216,14 @@ def __init__( **kwargs, ) -> None: super().__init__() - FourierFTLayer.__init__(self, base_layer, **kwargs) + FourierFTLayer.__init__(self, base_layer, alpha, **kwargs) + + # apply alpha patch + if alpha: + n_frequency = int(alpha * self.in_features * self.out_features) + self.fan_in_fan_out = fan_in_fan_out self._active_adapter = adapter_name - self.in_features = base_layer.in_channels - self.out_features = base_layer.out_channels self.kW = base_layer.kernel_size[0] self.kH = base_layer.kernel_size[1] self.stride = base_layer.stride diff --git a/src/peft/tuners/fourierft/model.py b/src/peft/tuners/fourierft/model.py index 4c45f362b4..18bbd1dff0 100644 --- a/src/peft/tuners/fourierft/model.py +++ b/src/peft/tuners/fourierft/model.py @@ -73,10 +73,12 @@ def _create_and_replace( n_frequency = fourierft_config.n_frequency_pattern.get(target_name_key, fourierft_config.n_frequency) scaling = fourierft_config.scaling + alpha = fourierft_config.alpha random_loc_seed = fourierft_config.random_loc_seed bias = hasattr(target, "bias") and target.bias is not None kwargs = { "n_frequency": n_frequency, + "alpha": alpha, "scaling": scaling, "fan_in_fan_out": fourierft_config.fan_in_fan_out, "init_weights": fourierft_config.init_weights, From 3172bf09252fba072177524ebe9bb0576bc5baab Mon Sep 17 00:00:00 2001 From: Lucas Malo Belanger Date: Thu, 25 Sep 2025 16:56:48 -0400 Subject: [PATCH 3/9] Stick closer to the paper for conv2d and add the norm option in the config --- src/peft/tuners/fourierft/config.py | 11 +++++++++++ src/peft/tuners/fourierft/layer.py | 23 +++++++++++------------ src/peft/tuners/fourierft/model.py | 2 ++ 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/peft/tuners/fourierft/config.py b/src/peft/tuners/fourierft/config.py index ba60af08d6..1ca8d4a2df 100644 --- a/src/peft/tuners/fourierft/config.py +++ b/src/peft/tuners/fourierft/config.py @@ -174,6 +174,17 @@ class FourierFTConfig(PeftConfig): ) }, ) + + ifft2_norm: Optional[str] = field( + default_factory='backward', + metadata={ + "help": ( + "The normalization applied for the ifft2 operation." + "It has to be either `backward`, `forward` or `ortho`. See the pytorch documentation for the ifft2 function for more details" + "The default value is `backward`." 
+ ) + }, + ) init_weights: bool = field( default=False, metadata={ diff --git a/src/peft/tuners/fourierft/layer.py b/src/peft/tuners/fourierft/layer.py index 935ed7e6b2..c86f5d8549 100644 --- a/src/peft/tuners/fourierft/layer.py +++ b/src/peft/tuners/fourierft/layer.py @@ -94,7 +94,7 @@ def get_delta_weight(self, adapter) -> torch.Tensor: indices = self.indices[adapter].to(spectrum.device) dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device) dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float() - delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter] + delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.kwargs['ifft2_norm']).real * self.fourierft_scaling[adapter] return delta_weight.to(spectrum.dtype) @@ -217,10 +217,6 @@ def __init__( ) -> None: super().__init__() FourierFTLayer.__init__(self, base_layer, alpha, **kwargs) - - # apply alpha patch - if alpha: - n_frequency = int(alpha * self.in_features * self.out_features) self.fan_in_fan_out = fan_in_fan_out self._active_adapter = adapter_name @@ -228,6 +224,10 @@ def __init__( self.kH = base_layer.kernel_size[1] self.stride = base_layer.stride self.padding = base_layer.padding + + # apply alpha patch + if alpha: + n_frequency = int(alpha * self.in_features * self.out_features * self.kW * self.kH) self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed) @@ -244,15 +244,15 @@ def update_layer( self.fourierft_n_frequency[adapter_name] = n_frequency self.fourierft_random_loc_seed[adapter_name] = random_loc_seed self.indices[adapter_name] = torch.randperm( - self.out_features * self.in_features, + self.out_features * self.in_features * self.kW * self.kH, generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]), )[:n_frequency] self.indices[adapter_name] = torch.stack( - [self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0 + [self.indices[adapter_name] // (self.in_features * self.kW), self.indices[adapter_name] % (self.in_features * self.kW)], dim=0 ) self.fourierft_scaling[adapter_name] = scaling # Actual trainable parameters - self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency, self.kW, self.kH), requires_grad=True) + self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency), requires_grad=True) if init_weights: self.reset_fourier_parameters(adapter_name) @@ -310,13 +310,12 @@ def unmerge(self) -> None: self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) def get_delta_weight(self, adapter) -> torch.Tensor: - # careful: ifft2 does not work with float16 or bfloat16 spectrum = self.fourierft_spectrum[adapter] indices = self.indices[adapter].to(spectrum.device) - dense_spectrum = torch.zeros(self.out_features, self.in_features, self.kW, self.kH, device=spectrum.device) + dense_spectrum = torch.zeros(self.out_features*self.kH, self.in_features*self.kW, device=spectrum.device) dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float() - delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter] - return delta_weight + delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.kwargs['ifft2_norm']).real * self.fourierft_scaling[adapter] + return torch.reshape(delta_weight, (self.out_features, self.in_features, self.kW, self.kH)) def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: previous_dtype = x.dtype diff --git 
a/src/peft/tuners/fourierft/model.py b/src/peft/tuners/fourierft/model.py
index 18bbd1dff0..d19f9c2b8c 100644
--- a/src/peft/tuners/fourierft/model.py
+++ b/src/peft/tuners/fourierft/model.py
@@ -74,12 +74,14 @@ def _create_and_replace(
         n_frequency = fourierft_config.n_frequency_pattern.get(target_name_key, fourierft_config.n_frequency)
         scaling = fourierft_config.scaling
         alpha = fourierft_config.alpha
+        ifft2_norm = fourierft_config.ifft2_norm
         random_loc_seed = fourierft_config.random_loc_seed
         bias = hasattr(target, "bias") and target.bias is not None
         kwargs = {
             "n_frequency": n_frequency,
             "alpha": alpha,
             "scaling": scaling,
+            "ifft2_norm": ifft2_norm,
             "fan_in_fan_out": fourierft_config.fan_in_fan_out,
             "init_weights": fourierft_config.init_weights,
             "random_loc_seed": fourierft_config.random_loc_seed,

From 7b408d8ca078e8f4f6c08be741420e60e6b5ff2c Mon Sep 17 00:00:00 2001
From: Lucas Malo Belanger
Date: Thu, 25 Sep 2025 18:02:36 -0400
Subject: [PATCH 4/9] Address most of the comments in the review

---
 src/peft/tuners/fourierft/layer.py | 187 +++++++++++------------------
 src/peft/tuners/fourierft/model.py |   3 +-
 2 files changed, 68 insertions(+), 122 deletions(-)

diff --git a/src/peft/tuners/fourierft/layer.py b/src/peft/tuners/fourierft/layer.py
index c86f5d8549..4930084e7b 100644
--- a/src/peft/tuners/fourierft/layer.py
+++ b/src/peft/tuners/fourierft/layer.py
@@ -29,7 +29,7 @@ class FourierFTLayer(BaseTunerLayer):
     # All names of other parameters that may contain adapter-related parameters
     other_param_names = ("fourierft_n_frequency", "fourierft_scaling", "fourierft_random_loc_seed")

-    def __init__(self, base_layer: nn.Module, alpha, **kwargs) -> None:
+    def __init__(self, base_layer: nn.Module, **kwargs) -> None:
         self.base_layer = base_layer
         self.fourierft_n_frequency = {}
         self.fourierft_scaling = {}
@@ -59,20 +59,22 @@ def update_layer(
     ):
         if n_frequency <= 0:
             raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}")
-        if n_frequency > self.in_features * self.out_features:
+
+        if isinstance(self, FourierFTLinear):
+            max_freqs = self.in_features * self.out_features
+        else:
+            kW = self.base_layer.kernel_size[0]
+            kH = self.base_layer.kernel_size[1]
+            max_freqs = self.in_features * self.out_features * kW * kH
+
+        if n_frequency > max_freqs:
             raise ValueError(
                 f"`n_frequency` should be less than or equal to the product of the input and output dimensions "
-                f"but the value passed is {n_frequency} and the product is {self.in_features * self.out_features}"
+                f"but the value passed is {n_frequency} and the product is {max_freqs}"
             )
         self.fourierft_n_frequency[adapter_name] = n_frequency
         self.fourierft_random_loc_seed[adapter_name] = random_loc_seed
-        self.indices[adapter_name] = torch.randperm(
-            self.out_features * self.in_features,
-            generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]),
-        )[:n_frequency]
-        self.indices[adapter_name] = torch.stack(
-            [self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0
-        )
+        self.set_indices(adapter_name, n_frequency)
         self.fourierft_scaling[adapter_name] = scaling
         # Actual trainable parameters
         self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency), requires_grad=True)
@@ -96,33 +98,7 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
         dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
         delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.kwargs['ifft2_norm']).real *
self.fourierft_scaling[adapter] return delta_weight.to(spectrum.dtype) - - -class FourierFTLinear(nn.Module, FourierFTLayer): - # FourierFT implemented in a dense layer - def __init__( - self, - base_layer, - adapter_name: str, - n_frequency: int = 1000, - alpha: float = None, - scaling: float = 150.0, - fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) - init_weights: Union[bool, str] = False, - random_loc_seed: int = 777, - **kwargs, - ) -> None: - super().__init__() - FourierFTLayer.__init__(self, base_layer, alpha, **kwargs) - - # apply alpha patch - if alpha: - n_frequency = int(alpha * self.in_features * self.out_features) - - self.fan_in_fan_out = fan_in_fan_out - self._active_adapter = adapter_name - self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed) - + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: """ Merge the active adapter weights into the base weights @@ -172,6 +148,41 @@ def unmerge(self) -> None: if active_adapter in self.fourierft_spectrum.keys(): self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + def set_indices(self, adapter_name : str, n_frequency : int): + self.indices[adapter_name] = torch.randperm( + self.out_features * self.in_features, + generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]), + )[:n_frequency] + self.indices[adapter_name] = torch.stack( + [self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0 + ) + + +class FourierFTLinear(nn.Module, FourierFTLayer): + # FourierFT implemented in a dense layer + def __init__( + self, + base_layer, + adapter_name: str, + n_frequency: int = 1000, + alpha: Optional[float] = None, + scaling: float = 150.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + init_weights: Union[bool, str] = False, + random_loc_seed: int = 777, + **kwargs, + ) -> None: + super().__init__() + FourierFTLayer.__init__(self, base_layer, **kwargs) + + # apply alpha patch + if alpha: + n_frequency = int(alpha * self.in_features * self.out_features) + + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed) + def get_delta_weight(self, adapter) -> torch.Tensor: return super().get_delta_weight(adapter) @@ -201,6 +212,10 @@ def __repr__(self) -> str: rep = super().__repr__() return "fourierft." 
+ rep
+    def set_indices(self, adapter_name : str, n_frequency : int):
+        super().set_indices(adapter_name, n_frequency)
+
+
 class FourierFTConv2D(nn.Module, FourierFTLayer):
     # FourierFT implemented in a 2D convolutional layer
     def __init__(
@@ -208,7 +223,7 @@ def __init__(
         base_layer,
         adapter_name: str,
         n_frequency: int = 1000,
-        alpha: float = None,
+        alpha: Optional[float] = None,
         scaling: float = 150.0,
         fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
         init_weights: Union[bool, str] = False,
@@ -216,106 +231,38 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__()
-        FourierFTLayer.__init__(self, base_layer, alpha, **kwargs)
+        FourierFTLayer.__init__(self, base_layer, **kwargs)

         self.fan_in_fan_out = fan_in_fan_out
         self._active_adapter = adapter_name
-        self.kW = base_layer.kernel_size[0]
-        self.kH = base_layer.kernel_size[1]
-        self.stride = base_layer.stride
-        self.padding = base_layer.padding
+        kW = base_layer.kernel_size[0]
+        kH = base_layer.kernel_size[1]

         # apply alpha patch
         if alpha:
-            n_frequency = int(alpha * self.in_features * self.out_features * self.kW * self.kH)
+            n_frequency = int(alpha * self.in_features * self.out_features * kW * kH)
         self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed)
-
-    def update_layer(
-        self, adapter_name, n_frequency, scaling, init_weights, random_loc_seed, inference_mode: bool = False, **kwargs
-    ):
-        if n_frequency <= 0:
-            raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}")
-        if n_frequency > self.in_features * self.out_features:
-            raise ValueError(
-                f"`n_frequency` should be less than or equal to the product of the input and output dimensions "
-                f"but the value passed is {n_frequency} and the product is {self.in_features * self.out_features}"
-            )
-        self.fourierft_n_frequency[adapter_name] = n_frequency
-        self.fourierft_random_loc_seed[adapter_name] = random_loc_seed
+    def set_indices(self, adapter_name : str, n_frequency : int):
+        kW = self.base_layer.kernel_size[0]
+        kH = self.base_layer.kernel_size[1]
         self.indices[adapter_name] = torch.randperm(
-            self.out_features * self.in_features * self.kW * self.kH,
+            self.out_features * self.in_features * kW * kH,
             generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]),
         )[:n_frequency]
         self.indices[adapter_name] = torch.stack(
-            [self.indices[adapter_name] // (self.in_features * self.kW), self.indices[adapter_name] % (self.in_features * self.kW)], dim=0
+            [self.indices[adapter_name] // (self.in_features * kW), self.indices[adapter_name] % (self.in_features * kW)], dim=0
         )
-        self.fourierft_scaling[adapter_name] = scaling
-        # Actual trainable parameters
-        self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency), requires_grad=True)
-
-        if init_weights:
-            self.reset_fourier_parameters(adapter_name)
-
-        self._move_adapter_to_device_of_base_layer(adapter_name)
-        self.set_adapter(self.active_adapters, inference_mode=inference_mode)
-
-    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
-        """
-        Merge the active adapter weights into the base weights
-
-        Args:
-            safe_merge (`bool`, *optional*):
-                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
-                before merging the weights. This is useful if you want to check if the merge operation will produce
-                NaNs. Defaults to `False`.
- adapter_names (`List[str]`, *optional*): - The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults - to `None`. - """ - adapter_names = check_adapters_to_merge(self, adapter_names) - if not adapter_names: - # no adapter to merge - return - - for active_adapter in adapter_names: - if active_adapter in self.fourierft_spectrum.keys(): - base_layer = self.get_base_layer() - if safe_merge: - # Note that safe_merge will be slower than the normal merge - # because of the copy operation. - orig_weights = base_layer.weight.data.clone() - orig_weights += self.get_delta_weight(active_adapter) - - if not torch.isfinite(orig_weights).all(): - raise ValueError( - f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" - ) - - base_layer.weight.data = orig_weights - else: - base_layer.weight.data += self.get_delta_weight(active_adapter) - self.merged_adapters.append(active_adapter) - - def unmerge(self) -> None: - """ - This method unmerges all merged adapter layers from the base weights. - """ - if not self.merged: - warnings.warn("Already unmerged. Nothing to do.") - return - while len(self.merged_adapters) > 0: - active_adapter = self.merged_adapters.pop() - if active_adapter in self.fourierft_spectrum.keys(): - self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) def get_delta_weight(self, adapter) -> torch.Tensor: + kW = self.base_layer.kernel_size[0] + kH = self.base_layer.kernel_size[1] spectrum = self.fourierft_spectrum[adapter] indices = self.indices[adapter].to(spectrum.device) - dense_spectrum = torch.zeros(self.out_features*self.kH, self.in_features*self.kW, device=spectrum.device) + dense_spectrum = torch.zeros(self.out_features*kH, self.in_features*kW, device=spectrum.device) dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float() delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.kwargs['ifft2_norm']).real * self.fourierft_scaling[adapter] - return torch.reshape(delta_weight, (self.out_features, self.in_features, self.kW, self.kH)) + return torch.reshape(delta_weight, (self.out_features, self.in_features, kW, kH)) def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: previous_dtype = x.dtype @@ -334,7 +281,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: delta_w = self.get_delta_weight(active_adapter) x = x.to(delta_w.dtype) - y = F.conv2d(x, delta_w, stride=self.stride, padding=self.padding) + y = F.conv2d(x, delta_w, stride=self.base_layer.stride, padding=self.base_layer.padding) result += y result = result.to(previous_dtype) diff --git a/src/peft/tuners/fourierft/model.py b/src/peft/tuners/fourierft/model.py index d19f9c2b8c..c13da93ade 100644 --- a/src/peft/tuners/fourierft/model.py +++ b/src/peft/tuners/fourierft/model.py @@ -130,7 +130,6 @@ def _create_new_module(fourierft_config, adapter_name, target, **kwargs): else: raise ValueError( f"Target module {target} is not supported. 
Currently, only the following modules are supported: " - "`torch.nn.Linear`" - "`torch.nn.Conv2d`" + "`torch.nn.Linear`, `torch.nn.Conv2d`" ) return new_module From 89268779ca2c0326f47ff4d3964bf4c616e87e55 Mon Sep 17 00:00:00 2001 From: Lucas Malo Belanger Date: Thu, 25 Sep 2025 18:09:56 -0400 Subject: [PATCH 5/9] Add some tests for FourierFT conv2d layers --- tests/test_custom_models.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 33cded4116..f883a18959 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -663,6 +663,27 @@ "init_weights": True, }, ), + ( + "Conv2d 1 FourierFT", + "Conv2d", + FourierFTConfig, + { + "target_modules": ["conv2d"], + "n_frequency": 1000, + } + ), + ( + "Conv2d 2 FourierFT", + "Conv2d", + FourierFTConfig, + { + "target_modules": ["conv2d", "lin0"], + "alpha" : 0.01, + "init_weights": True, + "ifft2_norm": "ortho", + } + ), + ########## # VBLoRA # ########## From 9cbc72d7788b2074fccc95df6463c25b84356890 Mon Sep 17 00:00:00 2001 From: Lucas Malo Belanger Date: Sat, 27 Sep 2025 10:21:29 -0400 Subject: [PATCH 6/9] Add fourierft_ifft2_norm to the parameters and other fixes --- src/peft/tuners/fourierft/config.py | 10 ++++++++-- src/peft/tuners/fourierft/layer.py | 7 ++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/peft/tuners/fourierft/config.py b/src/peft/tuners/fourierft/config.py index 1ca8d4a2df..579cdf8d03 100644 --- a/src/peft/tuners/fourierft/config.py +++ b/src/peft/tuners/fourierft/config.py @@ -15,7 +15,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Optional, Union +from typing import Optional, Union, Literal from peft.config import PeftConfig from peft.utils import PeftType @@ -175,7 +175,7 @@ class FourierFTConfig(PeftConfig): }, ) - ifft2_norm: Optional[str] = field( + ifft2_norm: Optional[Literal["backward", "forward", "ortho"]] = field( default_factory='backward', metadata={ "help": ( @@ -224,3 +224,9 @@ def __post_init__(self): # check for layers_to_transform and layers_pattern if self.layers_pattern and not self.layers_to_transform: raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. 
") + + if (self.alpha is not None) and (self.n_frequency != 1000): + raise ValueError("Don't set both alpha and n_frequency, as alpha overrides ...") + + if (self.alpha is not None) and (self.n_frequency_pattern != {}): + raise ValueError("Don't set both alpha and n_frequency_pattern, as alpha overrides ...") \ No newline at end of file diff --git a/src/peft/tuners/fourierft/layer.py b/src/peft/tuners/fourierft/layer.py index 4930084e7b..ce38c34e9a 100644 --- a/src/peft/tuners/fourierft/layer.py +++ b/src/peft/tuners/fourierft/layer.py @@ -27,7 +27,7 @@ class FourierFTLayer(BaseTunerLayer): # All names of layers that may contain (trainable) adapter weights adapter_layer_names = ("fourierft_spectrum",) # All names of other parameters that may contain adapter-related parameters - other_param_names = ("fourierft_n_frequency", "fourierft_scaling", "fourierft_random_loc_seed") + other_param_names = ("fourierft_n_frequency", "fourierft_scaling", "fourierft_random_loc_seed", "fourierft_ifft2_norm") def __init__(self, base_layer: nn.Module, **kwargs) -> None: self.base_layer = base_layer @@ -39,6 +39,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: # Mark the weight as unmerged self._disable_adapters = False self.merged_adapters = [] + self.fourierft_ifft2_norm = kwargs['ifft2_norm'] self.kwargs = kwargs base_layer = self.get_base_layer() @@ -96,7 +97,7 @@ def get_delta_weight(self, adapter) -> torch.Tensor: indices = self.indices[adapter].to(spectrum.device) dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device) dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float() - delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.kwargs['ifft2_norm']).real * self.fourierft_scaling[adapter] + delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter] return delta_weight.to(spectrum.dtype) def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: @@ -261,7 +262,7 @@ def get_delta_weight(self, adapter) -> torch.Tensor: indices = self.indices[adapter].to(spectrum.device) dense_spectrum = torch.zeros(self.out_features*kH, self.in_features*kW, device=spectrum.device) dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float() - delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.kwargs['ifft2_norm']).real * self.fourierft_scaling[adapter] + delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter] return torch.reshape(delta_weight, (self.out_features, self.in_features, kW, kH)) def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: From 2b205f39c6f72bcc77e04aba4473d84c51f15b40 Mon Sep 17 00:00:00 2001 From: Lucas Malo Belanger Date: Sat, 27 Sep 2025 10:35:49 -0400 Subject: [PATCH 7/9] Make style --- src/peft/tuners/fourierft/config.py | 14 +++++------ src/peft/tuners/fourierft/layer.py | 39 +++++++++++++++++++---------- src/peft/tuners/fourierft/model.py | 5 ++-- tests/test_custom_models.py | 21 ++++++++-------- 4 files changed, 44 insertions(+), 35 deletions(-) diff --git a/src/peft/tuners/fourierft/config.py b/src/peft/tuners/fourierft/config.py index 579cdf8d03..f3d4b69229 100644 --- a/src/peft/tuners/fourierft/config.py +++ b/src/peft/tuners/fourierft/config.py @@ -15,7 +15,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Optional, Union, Literal +from typing import Literal, Optional, Union from peft.config 
import PeftConfig
 from peft.utils import PeftType
@@ -176,7 +176,7 @@ class FourierFTConfig(PeftConfig):
     )

     ifft2_norm: Optional[Literal["backward", "forward", "ortho"]] = field(
-        default_factory='backward',
+        default="backward",
         metadata={
             "help": (
                 "The normalization applied for the ifft2 operation."
@@ -199,9 +199,7 @@ class FourierFTConfig(PeftConfig):
     alpha: float = field(
         default=None,
         metadata={
-            "help": (
-                "The alpha value dynamically sets the n_frequency = int(alpha * out_features * in_features)"
-            )
+            "help": ("The alpha value dynamically sets the n_frequency = int(alpha * out_features * in_features)")
         },
     )
@@ -224,9 +222,9 @@ def __post_init__(self):
         # check for layers_to_transform and layers_pattern
         if self.layers_pattern and not self.layers_to_transform:
             raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
-        
+
         if (self.alpha is not None) and (self.n_frequency != 1000):
             raise ValueError("Don't set both alpha and n_frequency, as alpha overrides ...")
-        
+
         if (self.alpha is not None) and (self.n_frequency_pattern != {}):
-            raise ValueError("Don't set both alpha and n_frequency_pattern, as alpha overrides ...")
\ No newline at end of file
+            raise ValueError("Don't set both alpha and n_frequency_pattern, as alpha overrides ...")
diff --git a/src/peft/tuners/fourierft/layer.py b/src/peft/tuners/fourierft/layer.py
index ce38c34e9a..12a5e58161 100644
--- a/src/peft/tuners/fourierft/layer.py
+++ b/src/peft/tuners/fourierft/layer.py
@@ -27,7 +27,12 @@ class FourierFTLayer(BaseTunerLayer):
     # All names of layers that may contain (trainable) adapter weights
     adapter_layer_names = ("fourierft_spectrum",)
     # All names of other parameters that may contain adapter-related parameters
-    other_param_names = ("fourierft_n_frequency", "fourierft_scaling", "fourierft_random_loc_seed", "fourierft_ifft2_norm")
+    other_param_names = (
+        "fourierft_n_frequency",
+        "fourierft_scaling",
+        "fourierft_random_loc_seed",
+        "fourierft_ifft2_norm",
+    )

     def __init__(self, base_layer: nn.Module, **kwargs) -> None:
         self.base_layer = base_layer
@@ -39,7 +44,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
         # Mark the weight as unmerged
         self._disable_adapters = False
         self.merged_adapters = []
-        self.fourierft_ifft2_norm = kwargs['ifft2_norm']
+        self.fourierft_ifft2_norm = kwargs["ifft2_norm"]
         self.kwargs = kwargs

         base_layer = self.get_base_layer()
@@ -60,7 +65,7 @@ def update_layer(
     ):
         if n_frequency <= 0:
             raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}")
-        
+
         if isinstance(self, FourierFTLinear):
             max_freqs = self.in_features * self.out_features
         else:
@@ -97,9 +102,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
         indices = self.indices[adapter].to(spectrum.device)
         dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device)
         dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
-        delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter]
+        delta_weight = (
+            torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter]
+        )
         return delta_weight.to(spectrum.dtype)
-    
+
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
         """
         Merge the active adapter weights into the base weights
@@ -149,7 +156,7 @@ def unmerge(self) -> None:
         if active_adapter in self.fourierft_spectrum.keys():
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) - def set_indices(self, adapter_name : str, n_frequency : int): + def set_indices(self, adapter_name: str, n_frequency: int): self.indices[adapter_name] = torch.randperm( self.out_features * self.in_features, generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]), @@ -213,7 +220,7 @@ def __repr__(self) -> str: rep = super().__repr__() return "fourierft." + rep - def set_indices(self, adapter_name : str, n_frequency : int): + def set_indices(self, adapter_name: str, n_frequency: int): super().set_indices(adapter_name, n_frequency) @@ -233,7 +240,7 @@ def __init__( ) -> None: super().__init__() FourierFTLayer.__init__(self, base_layer, **kwargs) - + self.fan_in_fan_out = fan_in_fan_out self._active_adapter = adapter_name kW = base_layer.kernel_size[0] @@ -244,7 +251,7 @@ def __init__( n_frequency = int(alpha * self.in_features * self.out_features * kW * kH) self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed) - def set_indices(self, adapter_name : str, n_frequency : int): + def set_indices(self, adapter_name: str, n_frequency: int): kW = self.base_layer.kernel_size[0] kH = self.base_layer.kernel_size[1] self.indices[adapter_name] = torch.randperm( @@ -252,7 +259,11 @@ def set_indices(self, adapter_name : str, n_frequency : int): generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]), )[:n_frequency] self.indices[adapter_name] = torch.stack( - [self.indices[adapter_name] // (self.in_features * kW), self.indices[adapter_name] % (self.in_features * kW)], dim=0 + [ + self.indices[adapter_name] // (self.in_features * kW), + self.indices[adapter_name] % (self.in_features * kW), + ], + dim=0, ) def get_delta_weight(self, adapter) -> torch.Tensor: @@ -260,9 +271,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor: kH = self.base_layer.kernel_size[1] spectrum = self.fourierft_spectrum[adapter] indices = self.indices[adapter].to(spectrum.device) - dense_spectrum = torch.zeros(self.out_features*kH, self.in_features*kW, device=spectrum.device) + dense_spectrum = torch.zeros(self.out_features * kH, self.in_features * kW, device=spectrum.device) dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float() - delta_weight = torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter] + delta_weight = ( + torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter] + ) return torch.reshape(delta_weight, (self.out_features, self.in_features, kW, kH)) def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: @@ -290,4 +303,4 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: def __repr__(self) -> str: rep = super().__repr__() - return "fourierft." + rep \ No newline at end of file + return "fourierft." 
+ rep
diff --git a/src/peft/tuners/fourierft/model.py b/src/peft/tuners/fourierft/model.py
index c13da93ade..abf4226a90 100644
--- a/src/peft/tuners/fourierft/model.py
+++ b/src/peft/tuners/fourierft/model.py
@@ -18,16 +18,15 @@
 from itertools import chain

 import torch
-from transformers.pytorch_utils import Conv1D
 from torch.nn import Conv2d
+from transformers.pytorch_utils import Conv1D

 from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
 from peft.utils import (
     TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING,
 )

-from .config import FourierFTConfig
-from .layer import FourierFTLayer, FourierFTLinear, FourierFTConv2D
+from .layer import FourierFTConv2D, FourierFTLayer, FourierFTLinear


 class FourierFTModel(BaseTuner):
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index f883a18959..b4b5b2b019 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -664,26 +664,25 @@
         },
     ),
     (
-        "Conv2d 1 FourierFT", 
-        "Conv2d", 
+        "Conv2d 1 FourierFT",
+        "Conv2d",
         FourierFTConfig,
         {
             "target_modules": ["conv2d"],
             "n_frequency": 1000,
-        }
-    ), 
+        },
+    ),
     (
-        "Conv2d 2 FourierFT", 
-        "Conv2d", 
-        FourierFTConfig, 
+        "Conv2d 2 FourierFT",
+        "Conv2d",
+        FourierFTConfig,
         {
             "target_modules": ["conv2d", "lin0"],
-            "alpha" : 0.01,
+            "alpha": 0.01,
             "init_weights": True,
             "ifft2_norm": "ortho",
-        }
-    ),
-    
+        },
+    ),
     ##########
     # VBLoRA #
     ##########

From 3613607405f04a4956575b367145ea5731a431b2 Mon Sep 17 00:00:00 2001
From: Lucas Malo Belanger
Date: Sat, 18 Oct 2025 10:30:37 -0400
Subject: [PATCH 8/9] Add tests for FourierFT

---
 docs/source/package_reference/fourierft.md |  2 +-
 src/peft/tuners/fourierft/config.py        | 15 +++++----
 src/peft/tuners/fourierft/layer.py         |  3 --
 tests/test_initialization.py               | 39 ++++++++++++++++++++++
 4 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/docs/source/package_reference/fourierft.md b/docs/source/package_reference/fourierft.md
index 1d298a9042..641ad65441 100644
--- a/docs/source/package_reference/fourierft.md
+++ b/docs/source/package_reference/fourierft.md
@@ -20,7 +20,7 @@ rendered properly in your Markdown viewer.

 FourierFT currently has the following constraints:

-- Only `nn.Linear` layers are supported.
+- Only `nn.Linear` and `nn.Conv2d` layers are supported.
 - Quantized layers are not supported.

 If these constraints don't work for your use case, consider other methods instead.
diff --git a/src/peft/tuners/fourierft/config.py b/src/peft/tuners/fourierft/config.py
index f3d4b69229..0f4af2e117 100644
--- a/src/peft/tuners/fourierft/config.py
+++ b/src/peft/tuners/fourierft/config.py
@@ -179,9 +179,9 @@ class FourierFTConfig(PeftConfig):
         default="backward",
         metadata={
             "help": (
-                "The normalization applied for the ifft2 operation."
-                "It has to be either `backward`, `forward` or `ortho`. See the pytorch documentation for the ifft2 function for more details"
-                "The default value is `backward`."
+                "The normalization applied for the ifft2 operation. "
+                "It has to be either `backward`, `forward` or `ortho`. See the pytorch documentation for the ifft2 function for more details "
+                "(https://docs.pytorch.org/docs/stable/generated/torch.fft.ifft2.html). The default value is `backward`."
            )
         },
     )
@@ -199,7 +199,10 @@ class FourierFTConfig(PeftConfig):
     alpha: float = field(
         default=None,
         metadata={
-            "help": ("The alpha value dynamically sets the n_frequency = int(alpha * out_features * in_features)")
+            "help": (
+                "The alpha value dynamically sets the n_frequency = int(alpha * out_features * in_features). "
+                "If alpha is set, the n_frequency and n_frequency_pattern parameters should not be set."
+            )
         },
     )
@@ -224,7 +227,7 @@ def __post_init__(self):
             raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")

         if (self.alpha is not None) and (self.n_frequency != 1000):
-            raise ValueError("Don't set both alpha and n_frequency, as alpha overrides ...")
+            raise ValueError("Don't set both alpha and n_frequency, as alpha overrides n_frequency.")

         if (self.alpha is not None) and (self.n_frequency_pattern != {}):
-            raise ValueError("Don't set both alpha and n_frequency_pattern, as alpha overrides ...")
+            raise ValueError("Don't set both alpha and n_frequency_pattern, as alpha overrides n_frequency_pattern.")
diff --git a/src/peft/tuners/fourierft/layer.py b/src/peft/tuners/fourierft/layer.py
index 12a5e58161..7afe3e0244 100644
--- a/src/peft/tuners/fourierft/layer.py
+++ b/src/peft/tuners/fourierft/layer.py
@@ -220,9 +220,6 @@ def __repr__(self) -> str:
         rep = super().__repr__()
         return "fourierft." + rep

-    def set_indices(self, adapter_name: str, n_frequency: int):
-        super().set_indices(adapter_name, n_frequency)
-

 class FourierFTConv2D(nn.Module, FourierFTLayer):
     # FourierFT implemented in a 2D convolutional layer
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index f37e0c2cbe..0d9f780f15 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -39,6 +39,7 @@
     AdaLoraConfig,
     C3AConfig,
     EvaConfig,
+    FourierFTConfig,
     IA3Config,
     LoftQConfig,
     LoKrConfig,
@@ -4499,3 +4500,41 @@ def test_key_mapping_save_old_load_new_vblora(self, old_model, new_model, tmp_pa
     def test_key_mapping_save_new_load_old_vblora(self, old_model, new_model, tmp_path):
         # save the new model, load it into the old model, should work without issues (forwards compatibility)
         self.check_vblora_load_no_warning(new_model, old_model, tmp_path)
+
+
+class TestFourierFTInitialization:
+    torch_device = infer_device()
+
+    def get_model(self, bias=True):
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                # choose a large weight so that averages are close to expected values
+                self.linear = nn.Linear(1000, 1000, bias=bias)
+                self.embed = nn.Embedding(1000, 1000)
+                self.conv2d = nn.Conv2d(100, 100, 3, bias=bias)
+
+            def forward(self, x):
+                x_int = (100 * x).int()
+                x_4d = x.flatten().reshape(1, 100, 10, 10)
+                return self.linear(x), self.embed(x_int), self.conv2d(x_4d)
+
+        return MyModule().eval().to(self.torch_device)
+
+    def test_fourierft_set_alpha_and_n_frequency_raises(self):
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        msg = "Don't set both alpha and n_frequency, as alpha overrides n_frequency."
+        with pytest.raises(ValueError, match=msg):
+            config = FourierFTConfig(target_modules=["linear"], alpha=0.1, n_frequency=2000)
+            get_peft_model(model, config)
+
+    def test_fourierft_set_alpha_and_n_frequency_pattern_raises(self):
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        msg = "Don't set both alpha and n_frequency_pattern, as alpha overrides n_frequency_pattern."
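+        # `match` is a regex that pytest searches for in the raised error message, so this string
+        # has to stay in sync with the ValueError text raised in FourierFTConfig.__post_init__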
+        with pytest.raises(ValueError, match=msg):
+            config = FourierFTConfig(target_modules=["linear"], alpha=0.1, n_frequency_pattern={"linear": 2000})
+            get_peft_model(model, config)

From 2f2312614c3bd383c4bf8a0aeff7e88efb266a46 Mon Sep 17 00:00:00 2001
From: Lucas Malo Belanger
Date: Sat, 18 Oct 2025 10:34:04 -0400
Subject: [PATCH 9/9] Add alpha and ifft2_norm to the fourierft docstring

---
 src/peft/tuners/fourierft/config.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/peft/tuners/fourierft/config.py b/src/peft/tuners/fourierft/config.py
index 0f4af2e117..c30115e84b 100644
--- a/src/peft/tuners/fourierft/config.py
+++ b/src/peft/tuners/fourierft/config.py
@@ -78,6 +78,14 @@ class FourierFTConfig(PeftConfig):
         init_weights (`bool`):
             The initialization of the Fourier weights. Set this to False (the default) if the spectrum are initialized
             to a standard normal distribution. Set this to True if the spectrum are initialized to zeros.
+        alpha:
+            The alpha value dynamically sets the n_frequency = int(alpha * out_features * in_features). If alpha is
+            set, the n_frequency and n_frequency_pattern parameters should not be set.
+        ifft2_norm:
+            The normalization applied for the ifft2 operation. It has to be either `backward`, `forward` or `ortho`.
+            See the pytorch documentation for the ifft2 function for more details
+            (https://docs.pytorch.org/docs/stable/generated/torch.fft.ifft2.html). The default value is `backward`.
+
     """

     n_frequency: int = field(
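For quick reference, the sketch below shows how the Conv2d support and the new `alpha` and `ifft2_norm` options from this series fit together. It is a minimal, illustrative example only: the toy network, its module names (`conv2d`, `lin0`) and the tensor shapes are invented here, while `FourierFTConfig` and `get_peft_model` are the same public PEFT entry points exercised by the tests above.

    import torch
    from torch import nn

    from peft import FourierFTConfig, get_peft_model


    class SmallConvNet(nn.Module):
        # toy model with one Conv2d and one Linear target (names are illustrative)
        def __init__(self):
            super().__init__()
            self.conv2d = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.lin0 = nn.Linear(8 * 32 * 32, 10)

        def forward(self, x):
            # padding=1 keeps the 32x32 spatial size, so lin0's input width matches
            return self.lin0(self.conv2d(x).flatten(1))


    model = SmallConvNet()
    # alpha derives each layer's n_frequency from its own dimensions:
    # int(alpha * out_features * in_features) for Linear, times kW * kH for Conv2d.
    # n_frequency / n_frequency_pattern must be left unset, or __post_init__ raises.
    config = FourierFTConfig(target_modules=["conv2d", "lin0"], alpha=0.01, ifft2_norm="ortho")
    peft_model = get_peft_model(model, config)
    out = peft_model(torch.randn(2, 3, 32, 32))

Because alpha scales with each targeted layer's shape, a single config value gives every layer a proportionate spectrum size instead of one fixed n_frequency.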