diff --git a/docs/source/developer_guides/troubleshooting.md b/docs/source/developer_guides/troubleshooting.md index 13c7de2238..d1aba43b6a 100644 --- a/docs/source/developer_guides/troubleshooting.md +++ b/docs/source/developer_guides/troubleshooting.md @@ -39,7 +39,9 @@ Installing PEFT from source is useful for keeping up with the latest development python -m pip install git+https://github.com/huggingface/peft ``` -## ValueError: Attempting to unscale FP16 gradients +## Dtype-related issues + +### ValueError: Attempting to unscale FP16 gradients This error probably occurred because the model was loaded with `torch_dtype=torch.float16` and then used in an automatic mixed precision (AMP) context, e.g. by setting `fp16=True` in the [`~transformers.Trainer`] class from 🤗 Transformers. The reason is that when using AMP, trainable weights should never use fp16. To make this work without loading the whole model in fp32, add the following to your code: @@ -75,6 +77,23 @@ Starting from PEFT verion v0.12.0, PEFT automatically promotes the dtype of adap +### Selecting the dtype of the adapter + +Most PEFT methods, like LoRA, work by adding trainable adapter weights. By default, those weights are stored in float32 dtype (fp32), i.e. at a relatively high precision. Therefore, even if the base model is loaded in float16 (fp16) or bfloat16 (bf16), the adapter weights are float32. When the adapter results are calculated during the forward pass, the input will typically be in the dtype of the base model, thus it will be upcast to float32 if necessary, then cast back to the original dtype. + +If you prefer to have the adapter weights in the lower precision of the base model, i.e. in float16 or bfloat16, you can pass `autocast_adapter_dtype=False` when creating the model ([`~get_peft_model`]) or loading the model ([`~PeftModel.from_pretrained`]). There are some advantages and disadvantages to this: + +Advantages of half precision adapter: +- computation slightly faster +- slightly less memory +- smaller file size of checkpoint (half the size) + +Disadvantages of half precision adapter: +- slightly worse loss +- higher risk of overflow or underflow + +Note that for most use cases, overall runtime and memory cost will be determined by the size of the base model and by the dataset, while the dtype of the PEFT adapter will only have a small impact. + ## Bad results from a loaded PEFT model There can be several reasons for getting a poor result from a loaded PEFT model which are listed below. If you're still unable to troubleshoot the problem, see if anyone else had a similar [issue](https://github.com/huggingface/peft/issues) on GitHub, and if you can't find any, open a new issue. diff --git a/src/peft/helpers.py b/src/peft/helpers.py index 225bc5003f..3246e74993 100644 --- a/src/peft/helpers.py +++ b/src/peft/helpers.py @@ -22,6 +22,7 @@ from .peft_model import PeftConfig, PeftModel from .tuners.lora import LoraLayer +from .tuners.tuners_utils import BaseTunerLayer def update_forward_signature(model: PeftModel) -> None: @@ -218,8 +219,6 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True): """ Context manager disables input dtype casting to the dtype of the weight. - Currently specifically works for LoRA. - Parameters: model (nn.Module): The model containing PEFT modules whose input dtype casting is to be adjusted. 
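For reference, here is a minimal sketch of how the options described above fit together: keeping the adapter weights in the base model's lower precision via `autocast_adapter_dtype=False`, and temporarily disabling the casting of inputs to the weight dtype with the `disable_input_dtype_casting` helper. The model id and settings below are illustrative only.

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from peft.helpers import disable_input_dtype_casting

# Load the base model in bfloat16; the model id is only an example.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.bfloat16)
config = LoraConfig(task_type="CAUSAL_LM")

# Keep the LoRA weights in bfloat16 instead of promoting them to float32.
model = get_peft_model(base_model, config, autocast_adapter_dtype=False)

# Optionally skip casting the input to the weight dtype during the forward call,
# e.g. when dtypes are already managed elsewhere.
with disable_input_dtype_casting(model):
    dummy_input = torch.arange(10).view(1, -1)
    output = model(input_ids=dummy_input)
```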
@@ -237,7 +236,7 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True): original_values = {} for name, module in model.named_modules(): - if not isinstance(module, LoraLayer): + if not isinstance(module, BaseTunerLayer): continue original_values[name] = module.cast_input_dtype_enabled module.cast_input_dtype_enabled = False @@ -246,7 +245,7 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True): yield finally: for name, module in model.named_modules(): - if not isinstance(module, LoraLayer): + if not isinstance(module, BaseTunerLayer): continue if name in original_values: module.cast_input_dtype_enabled = original_values[name] diff --git a/src/peft/tuners/boft/layer.py b/src/peft/tuners/boft/layer.py index 06f372659c..299695cdfa 100644 --- a/src/peft/tuners/boft/layer.py +++ b/src/peft/tuners/boft/layer.py @@ -220,6 +220,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: # Mark the weight as unmerged self._disable_adapters = False self.merged_adapters = [] + # flag to enable/disable casting of input to weight dtype during forward call + self.cast_input_dtype_enabled = True self.kwargs = kwargs base_layer = self.get_base_layer() @@ -503,13 +505,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N for active_adapter in adapter_names: if active_adapter in self.boft_R.keys(): base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if safe_merge: # Note that safe_merge will be slower than the normal merge # because of the copy operation. orig_weight = base_layer.weight.data.clone() butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) orig_weight = torch.transpose(orig_weight, 0, 1) - orig_weight = torch.mm(butterfly_oft_mat, orig_weight) + orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype)) orig_weight = torch.transpose(orig_weight, 0, 1) orig_weight = orig_weight * boft_s @@ -518,16 +521,16 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - self.base_layer.weight.data = orig_weight.contiguous() + self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype) else: butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) orig_weight = base_layer.weight.data.clone() orig_weight = torch.transpose(orig_weight, 0, 1) - orig_weight = torch.mm(butterfly_oft_mat, orig_weight) + orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype)) orig_weight = torch.transpose(orig_weight, 0, 1) orig_weight = orig_weight * boft_s - self.base_layer.weight.data = orig_weight.contiguous() + self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype) self.merged_adapters.append(active_adapter) @@ -538,17 +541,20 @@ def unmerge(self) -> None: if not self.merged: warnings.warn("Already unmerged. 
Nothing to do.") return + while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() + base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if active_adapter in self.boft_R.keys(): butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) - orig_weight = self.get_base_layer().weight.data.clone() + orig_weight = base_layer.weight.data.clone() orig_weight = torch.transpose(orig_weight, 0, 1) - orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight) + orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight.to(butterfly_oft_mat.dtype)) orig_weight = torch.transpose(orig_weight, 0, 1) - self.get_base_layer().weight.data = orig_weight * (1 / boft_s) + base_layer.weight.data = (orig_weight * (1 / boft_s)).to(orig_dtype) def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -804,6 +810,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N for active_adapter in adapter_names: if active_adapter in self.boft_R.keys(): base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if safe_merge: # Note that safe_merge will be slower than the normal merge # because of the copy operation. @@ -814,14 +821,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0] ) orig_weight = torch.transpose(orig_weight, 0, 1) - orig_weight = torch.mm(butterfly_oft_mat, orig_weight) + orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype)) orig_weight = torch.transpose(orig_weight, 0, 1) orig_weight = orig_weight * boft_s orig_weight = orig_weight.view( self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0] ) - self.base_layer.weight.data = orig_weight.contiguous() + self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype) else: butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) @@ -830,14 +837,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0] ) orig_weight = torch.transpose(orig_weight, 0, 1) - orig_weight = torch.mm(butterfly_oft_mat, orig_weight) + orig_weight = torch.mm(butterfly_oft_mat, orig_weight.to(butterfly_oft_mat.dtype)) orig_weight = torch.transpose(orig_weight, 0, 1) orig_weight = orig_weight * boft_s orig_weight = orig_weight.view( self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0] ) - self.base_layer.weight.data = orig_weight.contiguous() + self.base_layer.weight.data = orig_weight.contiguous().to(orig_dtype) self.merged_adapters.append(active_adapter) @@ -850,26 +857,28 @@ def unmerge(self) -> None: return while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() + base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if active_adapter in self.boft_R.keys(): butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) - orig_weight = self.get_base_layer().weight.data.clone() + orig_weight = base_layer.weight.data.clone() orig_weight = orig_weight.view( self.out_features, - self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0], + self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0], ) orig_weight = torch.transpose(orig_weight, 0, 1) - orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight) + 
orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight.to(butterfly_oft_mat.dtype)) orig_weight = torch.transpose(orig_weight, 0, 1) orig_weight = orig_weight * (1 / boft_s) orig_weight = orig_weight.view( self.out_features, self.in_features, - self.get_base_layer().kernel_size[0], - self.get_base_layer().kernel_size[0], + base_layer.kernel_size[0], + base_layer.kernel_size[0], ) - self.get_base_layer().weight.data = orig_weight + base_layer.weight.data = orig_weight.to(orig_dtype) def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -968,10 +977,12 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: scaled_rotated_weight = scaled_rotated_weight.view( self.out_features, self.in_features, self.base_layer.kernel_size[0], self.base_layer.kernel_size[0] ) + x = self._cast_input_dtype(x, scaled_rotated_weight.dtype) + bias = self._cast_input_dtype(self.base_layer.bias, scaled_rotated_weight.dtype) result = F.conv2d( input=x, weight=scaled_rotated_weight, - bias=self.base_layer.bias, + bias=bias, padding=self.base_layer.padding[0], stride=self.base_layer.stride[0], ) diff --git a/src/peft/tuners/bone/layer.py b/src/peft/tuners/bone/layer.py index 5c8b263710..e6d822d6cf 100644 --- a/src/peft/tuners/bone/layer.py +++ b/src/peft/tuners/bone/layer.py @@ -35,6 +35,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: # Mark the weight as unmerged self._disable_adapters = False self.merged_adapters = [] + # flag to enable/disable casting of input to weight dtype during forward call + self.cast_input_dtype_enabled = True self.kwargs = kwargs base_layer = self.get_base_layer() @@ -150,6 +152,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N for active_adapter in adapter_names: if active_adapter in self.bone_block.keys(): base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if safe_merge: # Note that safe_merge will be slower than the normal merge # because of the copy operation. @@ -166,14 +169,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - self.base_layer.weight.data = orig_weight + base_layer.weight.data = orig_weight.to(orig_dtype) else: if self.bone_fn == "bat": delta_weight = self.get_delta_weight(active_adapter, self.base_layer.weight.data) - self.base_layer.weight.data += delta_weight + base_layer.weight.data += delta_weight.to(orig_dtype) else: delta_weight = self.get_delta_weight_bone(active_adapter, self.base_layer.weight.data) - self.base_layer.weight.data = delta_weight + base_layer.weight.data = delta_weight.to(orig_dtype) self.merged_adapters.append(active_adapter) def unmerge(self) -> None: @@ -183,8 +186,11 @@ def unmerge(self) -> None: if not self.merged: warnings.warn("Already unmerged. 
Nothing to do.") return + while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() + base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if active_adapter in self.bone_block.keys(): orig_weight = self.get_base_layer().weight.data.clone() if self.bone_fn == "bat": @@ -192,7 +198,7 @@ def unmerge(self) -> None: else: delta_weight = self.get_delta_weight_bone(active_adapter, orig_weight, re=True) - self.get_base_layer().weight.data = delta_weight + base_layer.weight.data = delta_weight.to(orig_dtype) def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor: """ @@ -213,12 +219,15 @@ def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tens if cast_to_fp32: weight_bone = weight_bone.float() + orig_weight = orig_weight.to(weight_bone.dtype) r = weight_bone.size(-1) if re: o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3) one = torch.eye(weight_bone.size(-1)).to(weight_bone.device) + # inverse must be in float32, after that the dtype can be adjusted if needed inv_I_plus_b = torch.inverse(one + weight_bone) + inv_I_plus_b = inv_I_plus_b.to(weight_bone.dtype) w = (o - weight_bone) @ inv_I_plus_b output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape) else: @@ -318,7 +327,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: delta_weight = self.get_delta_weight(active_adapter, orig_weight) orig_weight = orig_weight + delta_weight - result = F.linear(input=x, weight=orig_weight, bias=self.base_layer.bias) + x = self._cast_input_dtype(x, orig_weight.dtype) + bias = self._cast_input_dtype(self.base_layer.bias, orig_weight.dtype) + result = F.linear(input=x, weight=orig_weight, bias=bias) else: result = self.base_layer(x, *args, **kwargs) for active_adapter in self.active_adapters: @@ -329,6 +340,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: if x.size(-1) % r != 0: padding_size = (r - x.size(-1) % r) % r x = F.pad(x, (0, padding_size)) + x = self._cast_input_dtype(x, bone.dtype) result = result + torch.sum(x.reshape(*x.shape[:-1], x.size(-1) // r, r), dim=-2) @ bone result = result.to(previous_dtype) diff --git a/src/peft/tuners/fourierft/layer.py b/src/peft/tuners/fourierft/layer.py index a8f615a989..0cff748a3f 100644 --- a/src/peft/tuners/fourierft/layer.py +++ b/src/peft/tuners/fourierft/layer.py @@ -84,12 +84,13 @@ def reset_fourier_parameters(self, adapter_name): nn.init.zeros_(self.fourierft_spectrum[adapter_name]) def get_delta_weight(self, adapter) -> torch.Tensor: + # careful: ifft2 does not work with float16 or bfloat16 spectrum = self.fourierft_spectrum[adapter] indices = self.indices[adapter].to(spectrum.device) - dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device, dtype=spectrum.dtype) - dense_spectrum[indices[0, :], indices[1, :]] = spectrum + dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device) + dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float() delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter] - return delta_weight + return delta_weight.to(spectrum.dtype) class FourierFTLinear(nn.Module, FourierFTLayer): diff --git a/src/peft/tuners/hra/layer.py b/src/peft/tuners/hra/layer.py index 2494510c60..9570133e38 100644 --- a/src/peft/tuners/hra/layer.py +++ b/src/peft/tuners/hra/layer.py @@ -37,6 +37,8 @@ def __init__(self, base_layer: nn.Module, 
**kwargs) -> None: # Mark the weight as unmerged self._disable_adapters = False self.merged_adapters = [] + # flag to enable/disable casting of input to weight dtype during forward call + self.cast_input_dtype_enabled = True self.kwargs = kwargs base_layer = self.get_base_layer() @@ -162,22 +164,24 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N for active_adapter in adapter_names: if active_adapter in self.hra_u.keys(): base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if safe_merge: # Note that safe_merge will be slower than the normal merge # because of the copy operation. orig_weight = base_layer.weight.data.clone() delta_weight = self.get_delta_weight(active_adapter) - orig_weight = torch.mm(orig_weight, delta_weight) + orig_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight) if not torch.isfinite(orig_weight).all(): raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - self.base_layer.weight.data = orig_weight + base_layer.weight.data = orig_weight.to(orig_dtype) else: delta_weight = self.get_delta_weight(active_adapter) - self.base_layer.weight.data = torch.mm(self.base_layer.weight.data, delta_weight) + new_weight = torch.mm(base_layer.weight.data.to(delta_weight.dtype), delta_weight) + base_layer.weight.data = new_weight.to(orig_dtype) self.merged_adapters.append(active_adapter) def unmerge(self) -> None: @@ -187,12 +191,16 @@ def unmerge(self) -> None: if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return + while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() + base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if active_adapter in self.hra_u.keys(): - orig_weight = self.get_base_layer().weight.data.clone() + orig_weight = base_layer.weight.data.clone() delta_weight = self.get_delta_weight(active_adapter, reverse=True) - self.get_base_layer().weight.data = torch.mm(orig_weight, delta_weight) + new_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight) + base_layer.weight.data = new_weight.to(orig_dtype) def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Tensor: rank = self.hra_r[adapter_name] @@ -240,13 +248,18 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: if active_adapter not in self.hra_u.keys(): continue delta_weight = self.get_delta_weight(active_adapter) - new_weight = torch.mm(new_weight, delta_weight) + new_weight = torch.mm(new_weight.to(delta_weight.dtype), delta_weight) - x = x.to(self.get_base_layer().weight.data.dtype) orig_weight = self.get_base_layer().weight.data + orig_weight = self._cast_input_dtype(orig_weight, new_weight.dtype) new_weight = torch.mm(orig_weight, new_weight) + bias = self._cast_input_dtype(self.base_layer.bias, new_weight.dtype) - result = F.linear(input=x, weight=new_weight, bias=self.base_layer.bias) + if self.cast_input_dtype_enabled: + x = self._cast_input_dtype(x, new_weight.dtype) + else: + x = x.to(self.get_base_layer().weight.data.dtype) + result = F.linear(input=x, weight=new_weight, bias=bias) result = result.to(previous_dtype) return result @@ -294,21 +307,22 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N for active_adapter in adapter_names: if active_adapter in self.hra_u.keys(): base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if safe_merge: # Note that safe_merge will be slower than the normal 
merge # because of the copy operation. orig_weight = base_layer.weight.data.clone() orig_weight = orig_weight.view( self.out_features, - self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], + self.in_features * base_layer.kernel_size[0] * self.base_layer.kernel_size[0], ) delta_weight = self.get_delta_weight(active_adapter) - orig_weight = torch.mm(orig_weight, delta_weight) + orig_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight) orig_weight = orig_weight.view( self.out_features, self.in_features, - self.base_layer.kernel_size[0], - self.base_layer.kernel_size[0], + base_layer.kernel_size[0], + base_layer.kernel_size[0], ) if not torch.isfinite(orig_weight).all(): @@ -316,7 +330,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - self.base_layer.weight.data = orig_weight + base_layer.weight.data = orig_weight.to(orig_dtype) else: orig_weight = base_layer.weight.data orig_weight = orig_weight.view( @@ -324,15 +338,15 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], ) delta_weight = self.get_delta_weight(active_adapter) - orig_weight = torch.mm(orig_weight, delta_weight) + orig_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight) orig_weight = orig_weight.view( self.out_features, self.in_features, - self.base_layer.kernel_size[0], - self.base_layer.kernel_size[0], + base_layer.kernel_size[0], + base_layer.kernel_size[0], ) - self.base_layer.weight.data = orig_weight + base_layer.weight.data = orig_weight.to(orig_dtype) self.merged_adapters.append(active_adapter) def unmerge(self) -> None: @@ -344,19 +358,21 @@ def unmerge(self) -> None: return while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() + base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if active_adapter in self.hra_u.keys(): - orig_weight = self.get_base_layer().weight.data.clone() + orig_weight = base_layer.weight.data.clone() orig_weight = orig_weight.view( self.out_features, - self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], + self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0], ) delta_weight = self.get_delta_weight(active_adapter, reverse=True) - orig_weight = torch.mm(orig_weight, delta_weight) + orig_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight) orig_weight = orig_weight.view( - self.out_features, self.in_features, self.base_layer.kernel_size[0], self.base_layer.kernel_size[0] + self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0] ) - self.get_base_layer().weight.data = orig_weight + base_layer.weight.data = orig_weight.to(orig_dtype) def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Tensor: rank = self.hra_r[adapter_name] @@ -401,21 +417,21 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: new_weight = torch.eye( self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], device=x.device, - dtype=previous_dtype, ) for active_adapter in self.active_adapters: if active_adapter not in self.hra_u.keys(): continue delta_weight = self.get_delta_weight(active_adapter) - new_weight = torch.mm(new_weight, delta_weight) - - x = x.to(self.base_layer.weight.data.dtype) + new_weight = 
torch.mm(new_weight.to(delta_weight.dtype), delta_weight) orig_weight = self.base_layer.weight.data orig_weight = orig_weight.view( self.out_features, self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], ) + orig_weight = self._cast_input_dtype(orig_weight, new_weight.dtype) + bias = self._cast_input_dtype(self.base_layer.bias, new_weight.dtype) + new_weight = torch.mm(orig_weight, new_weight) new_weight = new_weight.view( self.out_features, @@ -424,10 +440,14 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: self.base_layer.kernel_size[0], ) + if self.cast_input_dtype_enabled: + x = self._cast_input_dtype(x, new_weight.dtype) + else: + x = x.to(self.get_base_layer().weight.data.dtype) result = F.conv2d( input=x, weight=new_weight, - bias=self.base_layer.bias, + bias=bias, padding=self.base_layer.padding[0], stride=self.base_layer.stride[0], ) diff --git a/src/peft/tuners/ia3/layer.py b/src/peft/tuners/ia3/layer.py index e941089315..ba568a617f 100644 --- a/src/peft/tuners/ia3/layer.py +++ b/src/peft/tuners/ia3/layer.py @@ -234,6 +234,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N for active_adapter in adapter_names: if active_adapter in self.ia3_l.keys(): base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.data.dtype ia3_scaling = self.ia3_l[active_adapter].data if not self.is_feedforward: ia3_scaling = ia3_scaling.transpose(0, 1) @@ -246,13 +247,13 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - base_layer.weight.data = output_weight + base_layer.weight.data = output_weight.to(orig_dtype) else: - base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_scaling) + base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_scaling).to(orig_dtype) if not self.is_feedforward and (base_layer.bias is not None): scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) - base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data) + base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data).to(orig_dtype) self.merged_adapters.append(active_adapter) @@ -269,15 +270,17 @@ def unmerge(self) -> None: active_adapter = self.merged_adapters.pop() if active_adapter in self.ia3_l.keys(): base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.data.dtype # divide by (IA)^3 vector. 
Add tolerace to avoid division by zero ia3_scaling = self.ia3_l[active_adapter].data if not self.is_feedforward: ia3_scaling = ia3_scaling.transpose(0, 1) - base_layer.weight.data = torch.div(base_layer.weight.data, ia3_scaling + 1e-8) + base_layer.weight.data = torch.div(base_layer.weight.data, ia3_scaling + 1e-8).to(orig_dtype) if not self.is_feedforward and (base_layer.bias is not None): scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) - base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data) + orig_dtype = base_layer.bias.data.dtype + base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data).to(orig_dtype) def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: dtype = previous_dtype = x.dtype diff --git a/src/peft/tuners/ln_tuning/layer.py b/src/peft/tuners/ln_tuning/layer.py index 1c90d4b9e6..025526d734 100644 --- a/src/peft/tuners/ln_tuning/layer.py +++ b/src/peft/tuners/ln_tuning/layer.py @@ -60,7 +60,8 @@ def enable_adapters(self, enabled: bool) -> None: layer.requires_grad_(False) self._disable_adapters = True - def merge(self, adapter_names: Optional[List[str]] = None): + def merge(self, adapter_names: Optional[List[str]] = None, safe_merge: bool = False): + # note that there is no actual merging, so whether safe_merge is True or False is irrelevant adapter_names = check_adapters_to_merge(self, adapter_names) if not adapter_names: # no adapter to merge diff --git a/src/peft/tuners/ln_tuning/model.py b/src/peft/tuners/ln_tuning/model.py index 3e16a7fad4..a4cc716866 100644 --- a/src/peft/tuners/ln_tuning/model.py +++ b/src/peft/tuners/ln_tuning/model.py @@ -203,3 +203,9 @@ def merge_and_unload( self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None ) -> nn.Module: return self._unload_and_optionally_merge(merge=True) + + def _cast_adapter_dtype(self, adapter_name: str, autocast_adapter_dtype: bool = True) -> None: + # Note: LN Tuning does not add adapter layers, instead it creates copies of the original layer. For this reason, + # we need to skip adapter autocasting, otherwise we would change the dtype of copies of the original layer, + # resulting in dtype errors down the line. 
+ pass diff --git a/src/peft/tuners/loha/layer.py b/src/peft/tuners/loha/layer.py index 6d294af669..a22e840305 100644 --- a/src/peft/tuners/loha/layer.py +++ b/src/peft/tuners/loha/layer.py @@ -239,6 +239,7 @@ def _get_delta_activations( self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any ) -> torch.Tensor: delta_weight = self.get_delta_weight(adapter_name) + input = self._cast_input_dtype(input, delta_weight.dtype) # don't add bias here, because the bias is already included in the output of the base_layer return F.linear(input, delta_weight) @@ -274,6 +275,7 @@ def _get_delta_activations( self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any ) -> torch.Tensor: delta_weight = self.get_delta_weight(adapter_name) + input = self._cast_input_dtype(input, delta_weight.dtype) # don't add bias here, because the bias is already included in the output of the base_layer base_layer = self.get_base_layer() return F.conv2d( diff --git a/src/peft/tuners/lokr/layer.py b/src/peft/tuners/lokr/layer.py index 1c8cf1bbd9..a705a9963f 100644 --- a/src/peft/tuners/lokr/layer.py +++ b/src/peft/tuners/lokr/layer.py @@ -303,6 +303,7 @@ def _get_delta_activations( self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any ) -> torch.Tensor: delta_weight = self.get_delta_weight(adapter_name) + input = self._cast_input_dtype(input, delta_weight.dtype) # don't add bias here, because the bias is already included in the output of the base_layer return F.linear(input, delta_weight) @@ -340,6 +341,7 @@ def _get_delta_activations( self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any ) -> torch.Tensor: delta_weight = self.get_delta_weight(adapter_name) + input = self._cast_input_dtype(input, delta_weight.dtype) # don't add bias here, because the bias is already included in the output of the base_layer base_layer = self.get_base_layer() return F.conv2d( diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 300d837362..7ff73fe2d9 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -532,19 +532,6 @@ def _mixed_batch_forward( return result - def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: - """ - Whether to cast the dtype of the input to the forward method. - - Usually, we want to enable this to align the input dtype with the dtype of the weight, but by setting - layer.cast_input_dtype=False, this can be disabled if necessary. - - Enabling or disabling can be managed via the peft.helpers.disable_lora_input_dtype_casting context manager. - """ - if (not self.cast_input_dtype_enabled) or (x.dtype == dtype): - return x - return x.to(dtype=dtype) - # Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py # and modified to work with PyTorch FSDP @@ -623,9 +610,10 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N # Note that safe_merge will be slower than the normal merge # because of the copy operation. 
orig_weight = base_layer.weight.data.clone() + orig_dtype = orig_weight.dtype if active_adapter not in self.lora_variant: # vanilla LoRA delta_weight = self.get_delta_weight(active_adapter) - orig_weight += delta_weight + orig_weight += delta_weight.to(orig_dtype) else: orig_weight = self.lora_variant[active_adapter].merge_safe(self, active_adapter, orig_weight) @@ -642,7 +630,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - base_layer.bias.data = new_bias + base_layer.bias.data = new_bias.to(orig_dtype) else: if active_adapter not in self.lora_variant: # vanilla LoRA @@ -668,8 +656,9 @@ def unmerge(self) -> None: if active_adapter in self.lora_A.keys(): weight = self.get_base_layer().weight if active_adapter not in self.lora_variant: # vanilla LoRA + orig_dtype = weight.dtype delta_weight = self.get_delta_weight(active_adapter) - weight.data -= delta_weight + weight.data -= delta_weight.to(orig_dtype) else: unmerged = self.lora_variant[active_adapter].unmerge(self, active_adapter, weight) weight.data = unmerged @@ -870,12 +859,13 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N for active_adapter in adapter_names: if active_adapter in self.lora_embedding_A.keys(): base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if safe_merge: # Note that safe_merge will be slower than the normal merge # because of the copy operation. orig_weight = base_layer.weight.data.clone() if active_adapter not in self.lora_variant: # vanilla LoRA - orig_weight += self.get_delta_weight(active_adapter) + orig_weight += self.get_delta_weight(active_adapter).to(orig_dtype) else: orig_weight = self.lora_variant[active_adapter].merge_safe(self, active_adapter, orig_weight) @@ -887,7 +877,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N base_layer.weight.data = orig_weight else: if active_adapter not in self.lora_variant: # vanilla LoRA - base_layer.weight.data += self.get_delta_weight(active_adapter) + base_layer.weight.data += self.get_delta_weight(active_adapter).to(orig_dtype) else: self.lora_variant[active_adapter].merge_unsafe(self, active_adapter, base_layer.weight) self.merged_adapters.append(active_adapter) @@ -901,10 +891,11 @@ def unmerge(self) -> None: return while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() + orig_dtype = self.get_base_layer().weight.dtype if active_adapter in self.lora_embedding_A.keys(): weight = self.get_base_layer().weight if active_adapter not in self.lora_variant: # vanilla LoRA - weight.data -= self.get_delta_weight(active_adapter) + weight.data -= self.get_delta_weight(active_adapter).to(orig_dtype) else: unmerged = self.lora_variant[active_adapter].unmerge(self, active_adapter, weight) weight.data = unmerged @@ -1141,6 +1132,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N for active_adapter in adapter_names: if active_adapter in self.lora_A.keys(): base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if base_layer.groups > 1: # https://github.com/huggingface/peft/pull/2403 @@ -1152,7 +1144,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N orig_weight = base_layer.weight.data.clone() if active_adapter not in self.lora_variant: # vanilla LoRA delta_weight = self.get_delta_weight(active_adapter) - orig_weight += 
delta_weight + orig_weight += delta_weight.to(orig_dtype) else: orig_weight = self.lora_variant[active_adapter].merge_safe(self, active_adapter, orig_weight) @@ -1160,6 +1152,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) + base_layer.weight.data = orig_weight if self.lora_bias[active_adapter]: @@ -1168,12 +1161,12 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - base_layer.bias.data = new_bias + base_layer.bias.data = new_bias.to(orig_dtype) else: if active_adapter not in self.lora_variant: # vanilla LoRA delta_weight = self.get_delta_weight(active_adapter) - base_layer.weight.data += delta_weight + base_layer.weight.data += delta_weight.to(orig_dtype) else: self.lora_variant[active_adapter].merge_unsafe(self, active_adapter, base_layer.weight) @@ -1194,8 +1187,9 @@ def unmerge(self) -> None: if active_adapter in self.lora_A.keys(): weight = self.get_base_layer().weight if active_adapter not in self.lora_variant: # vanilla LoRA + orig_dtype = weight.dtype delta_weight = self.get_delta_weight(active_adapter) - weight.data -= delta_weight + weight.data -= delta_weight.to(orig_dtype) else: unmerged = self.lora_variant[active_adapter].unmerge(self, active_adapter, weight) weight.data = unmerged @@ -1494,11 +1488,12 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N for active_adapter in adapter_names: if active_adapter in self.lora_A.keys(): base_layer = self.get_base_layer() + orig_dtype = base_layer.out_proj.weight.dtype if safe_merge: # TODO: work with separate weights # merging in_proj (nn.Parameter) orig_weight_in = base_layer.in_proj_weight.data.detach().clone() - orig_weight_in += self.get_delta_weight(active_adapter) + orig_weight_in += self.get_delta_weight(active_adapter).to(orig_dtype) if not torch.isfinite(orig_weight_in).all(): raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" @@ -1506,7 +1501,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N # merging out_proj (subclass of nn.Linear) orig_weight_out = base_layer.out_proj.weight.data.detach().clone() - orig_weight_out += base_layer.out_proj.get_delta_weight(active_adapter) + orig_weight_out += base_layer.out_proj.get_delta_weight(active_adapter).to(orig_dtype) if not torch.isfinite(orig_weight_out).all(): raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" @@ -1523,7 +1518,8 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N else: # merging in_proj (nn.Parameter) # TODO: work with separate weights - weight_merged = base_layer.in_proj_weight.data.detach() + self.get_delta_weight(active_adapter) + delta_weight = self.get_delta_weight(active_adapter).to(orig_dtype) + weight_merged = base_layer.in_proj_weight.data.detach() + delta_weight # unregister parameter implicitly and overwrite using merged weights; gradients are computed after # forward and, thus, after unmerging (see forward()), therefore this is safe to do. 
@@ -1531,9 +1527,8 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N base_layer.in_proj_weight = weight_merged # merging out_proj (subclass of nn.Linear) - weight_merged = base_layer.out_proj.weight.data.detach() + base_layer.out_proj.get_delta_weight( - active_adapter - ) + delta_weight = base_layer.out_proj.get_delta_weight(active_adapter).to(orig_dtype) + weight_merged = base_layer.out_proj.weight.data.detach() + delta_weight del base_layer.out_proj.get_base_layer().weight base_layer.out_proj.get_base_layer().weight = weight_merged base_layer.out_proj.merge(adapter_names=[active_adapter]) @@ -1549,6 +1544,7 @@ def unmerge(self) -> None: # TODO work with separate weights base_layer = self.get_base_layer() + orig_dtype = base_layer.out_proj.base_layer.weight.dtype while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.lora_A.keys(): @@ -1556,14 +1552,14 @@ def unmerge(self) -> None: # requires_grad was False when the optimizer was initialized, but still let's try to be correct here. # in_proj - old_weight = base_layer.in_proj_weight.data - self.get_delta_weight(active_adapter) + delta_weight = self.get_delta_weight(active_adapter).to(orig_dtype) + old_weight = base_layer.in_proj_weight.data - delta_weight del base_layer.in_proj_weight base_layer.register_parameter("in_proj_weight", nn.Parameter(old_weight, requires_grad=False)) # out_proj - old_weight = base_layer.out_proj.base_layer.weight.data - base_layer.out_proj.get_delta_weight( - active_adapter - ) + delta_weight = base_layer.out_proj.get_delta_weight(active_adapter).to(orig_dtype) + old_weight = base_layer.out_proj.base_layer.weight.data - delta_weight del base_layer.out_proj.base_layer.weight base_layer.out_proj.base_layer.register_parameter( "weight", nn.Parameter(old_weight, requires_grad=False) diff --git a/src/peft/tuners/lora/variants.py b/src/peft/tuners/lora/variants.py index 8f2b9eaab2..de2c92faf0 100644 --- a/src/peft/tuners/lora/variants.py +++ b/src/peft/tuners/lora/variants.py @@ -58,6 +58,7 @@ def init(module: Linear, adapter_name: str, **kwargs: Any) -> None: @staticmethod def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: + orig_dtype = orig_weight.dtype delta_weight = module.get_delta_weight(active_adapter) # since delta_weight already includes scaling, set it to 1 here @@ -72,11 +73,13 @@ def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) - module._cache_store(f"{active_adapter}-weight_norm", weight_norm) dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm dora_factor = transpose(dora_factor.view(-1, 1), module.fan_in_fan_out) - orig_weight = dora_factor * (orig_weight + delta_weight) - return orig_weight + new_weight = dora_factor * (orig_weight + delta_weight) + new_weight = new_weight.to(orig_dtype) + return new_weight @staticmethod def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None: + orig_dtype = orig_weight.dtype delta_weight = module.get_delta_weight(active_adapter) weight_norm = ( module.lora_magnitude_vector[active_adapter] @@ -90,14 +93,18 @@ def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm dora_factor = transpose(dora_factor.view(-1, 1), module.fan_in_fan_out) new_weight = dora_factor * (orig_weight.data + delta_weight) + new_weight = new_weight.to(orig_dtype) 
orig_weight.data = new_weight @staticmethod def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: + orig_dtype = orig_weight.dtype delta_weight = module.get_delta_weight(active_adapter) weight_norm = module._cache_pop(f"{active_adapter}-weight_norm") dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm - return orig_weight.data / dora_factor.view(-1, 1) - delta_weight + new_weight = orig_weight.data / dora_factor.view(-1, 1) - delta_weight + new_weight = new_weight.to(orig_dtype) + return new_weight @staticmethod def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor: @@ -141,6 +148,7 @@ def init(module: Embedding, adapter_name: str, **kwargs: Any) -> None: @staticmethod def merge_safe(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: + orig_dtype = orig_weight.dtype delta_weight = module.get_delta_weight(active_adapter) # since delta_weight already includes scaling, set it to 1 here @@ -155,11 +163,13 @@ def merge_safe(module: Embedding, active_adapter: str, orig_weight: torch.Tensor module._cache_store(f"{active_adapter}-weight_norm", weight_norm) dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm dora_factor = dora_factor.view(1, -1) - orig_weight = dora_factor * (orig_weight + delta_weight) - return orig_weight + new_weight = dora_factor * (orig_weight + delta_weight) + new_weight = new_weight.to(orig_dtype) + return new_weight @staticmethod def merge_unsafe(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> None: + orig_dtype = orig_weight.dtype delta_weight = module.get_delta_weight(active_adapter) weight_norm = ( module.lora_magnitude_vector[active_adapter] @@ -173,14 +183,18 @@ def merge_unsafe(module: Embedding, active_adapter: str, orig_weight: torch.Tens dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm dora_factor = dora_factor.view(1, -1) new_weight = dora_factor * (orig_weight.data + delta_weight) + new_weight = new_weight.to(orig_dtype) orig_weight.data = new_weight @staticmethod def unmerge(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: + orig_dtype = orig_weight.dtype delta_weight = module.get_delta_weight(active_adapter) weight_norm = module._cache_pop(f"{active_adapter}-weight_norm") dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm - return orig_weight.data / dora_factor.view(1, -1) - delta_weight + new_weight = orig_weight.data / dora_factor.view(1, -1) - delta_weight + new_weight = new_weight.to(orig_dtype) + return new_weight @staticmethod def forward(module: Embedding, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor: @@ -215,6 +229,7 @@ def init_convd_variant(module: _ConvNd, adapter_name: str, dora_layer: nn.Module @staticmethod def merge_safe(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: + orig_dtype = orig_weight.dtype delta_weight = module.get_delta_weight(active_adapter) # since delta_weight already includes scaling, set it to 1 here @@ -226,11 +241,13 @@ def merge_safe(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor) # different value module._cache_store(f"{active_adapter}-weight_norm", weight_norm) dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm - orig_weight = dora_factor.view(*module._get_dora_factor_view()) * (orig_weight + delta_weight) - return orig_weight 
+ new_weight = dora_factor.view(*module._get_dora_factor_view()) * (orig_weight + delta_weight) + new_weight = new_weight.to(orig_dtype) + return new_weight @staticmethod def merge_unsafe(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor) -> None: + orig_dtype = orig_weight.dtype delta_weight = module.get_delta_weight(active_adapter) # since delta_weight already includes scaling, set it to 1 here weight_norm = ( @@ -242,14 +259,18 @@ def merge_unsafe(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor module._cache_store(f"{active_adapter}-weight_norm", weight_norm) dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm new_weight = dora_factor.view(*module._get_dora_factor_view()) * (orig_weight.data + delta_weight) + new_weight = new_weight.to(orig_dtype) orig_weight.data = new_weight @staticmethod def unmerge(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: + orig_dtype = orig_weight.dtype delta_weight = module.get_delta_weight(active_adapter) weight_norm = module._cache_pop(f"{active_adapter}-weight_norm") dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm - return orig_weight.data / dora_factor.view(*module._get_dora_factor_view()) - delta_weight + new_weight = orig_weight.data / dora_factor.view(*module._get_dora_factor_view()) - delta_weight + new_weight = new_weight.to(orig_dtype) + return new_weight @staticmethod def forward(module: _ConvNd, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor: diff --git a/src/peft/tuners/lycoris_utils.py b/src/peft/tuners/lycoris_utils.py index 99d42c1d7d..4df3bfb9d3 100644 --- a/src/peft/tuners/lycoris_utils.py +++ b/src/peft/tuners/lycoris_utils.py @@ -77,6 +77,8 @@ def __init__(self, base_layer: nn.Module) -> None: # Tuner info self._disable_adapters = False self.merged_adapters = [] + # flag to enable/disable casting of input to weight dtype during forward call + self.cast_input_dtype_enabled = True @property @abstractmethod diff --git a/src/peft/tuners/oft/layer.py b/src/peft/tuners/oft/layer.py index 7d58a8c023..f49de20e85 100644 --- a/src/peft/tuners/oft/layer.py +++ b/src/peft/tuners/oft/layer.py @@ -101,6 +101,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: # Mark the weight as unmerged self._disable_adapters = False self.merged_adapters = [] + # flag to enable/disable casting of input to weight dtype during forward call + self.cast_input_dtype_enabled = True self.kwargs = kwargs base_layer = self.get_base_layer() @@ -344,13 +346,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N for active_adapter in adapter_names: if active_adapter in self._available_adapters: base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if safe_merge: # Note that safe_merge will be slower than the normal merge # because of the copy operation. orig_weights = base_layer.weight.data oft_mat, oft_s = self.get_delta_weight(active_adapter) orig_weights = torch.transpose(orig_weights, 0, 1) - orig_weights = torch.mm(oft_mat, orig_weights) + orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype)) orig_weights = torch.transpose(orig_weights, 0, 1) orig_weights = orig_weights * oft_s @@ -359,16 +362,16 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N f"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken" ) - base_layer.weight.data = orig_weights.contiguous() + base_layer.weight.data = orig_weights.contiguous().to(orig_dtype) else: oft_mat, oft_s = self.get_delta_weight(active_adapter) orig_weights = base_layer.weight.data orig_weights = torch.transpose(orig_weights, 0, 1) - orig_weights = torch.mm(oft_mat, orig_weights) + orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype)) orig_weights = torch.transpose(orig_weights, 0, 1) orig_weights = orig_weights * oft_s - base_layer.weight.data = orig_weights.contiguous() + base_layer.weight.data = orig_weights.contiguous().to(orig_dtype) self.merged_adapters.append(active_adapter) @@ -379,6 +382,9 @@ def unmerge(self) -> None: if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return + + base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.oft_r.keys(): @@ -386,10 +392,10 @@ def unmerge(self) -> None: orig_weights = self.get_base_layer().weight.data orig_weights = torch.transpose(orig_weights, 0, 1) - orig_weights = torch.mm(oft_mat.t(), orig_weights) + orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype)) orig_weights = torch.transpose(orig_weights, 0, 1) - self.get_base_layer().weight.data = orig_weights * (1 / oft_s) + base_layer.weight.data = (orig_weights * (1 / oft_s)).to(orig_dtype) def get_delta_weight(self, adapter_name) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -425,8 +431,8 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: elif self.merged: result = self.base_layer(x, *args, **kwargs) else: - oft_rotation = torch.eye(self.in_features, device=x.device, dtype=previous_dtype) - oft_scale = torch.ones((int(self.out_features), 1), device=x.device, dtype=previous_dtype) + oft_rotation = torch.eye(self.in_features, device=x.device) + oft_scale = torch.ones((int(self.out_features), 1), device=x.device) for active_adapter in self.active_adapters: if active_adapter not in self.oft_r.keys(): @@ -454,15 +460,12 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: orig_weight = self.get_base_layer().weight.data orig_weight = torch.transpose(orig_weight, 0, 1) - oft_rotation = oft_rotation.to(previous_dtype) - orig_weight = orig_weight.to(previous_dtype) - rotated_weight = torch.mm(oft_rotation, orig_weight) + rotated_weight = torch.mm(oft_rotation, orig_weight.to(oft_rotation.dtype)) rotated_weight = torch.transpose(rotated_weight, 0, 1) - scaled_rotated_weight = rotated_weight * oft_scale - scaled_rotated_weight = scaled_rotated_weight.to(previous_dtype) - bias = self.get_base_layer().bias.to(previous_dtype) if self.get_base_layer().bias is not None else None + x = self._cast_input_dtype(x, scaled_rotated_weight.dtype) + bias = self._cast_input_dtype(self.get_base_layer().bias, scaled_rotated_weight.dtype) result = F.linear(input=x, weight=scaled_rotated_weight, bias=bias) result = result.to(previous_dtype) @@ -580,6 +583,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N for active_adapter in adapter_names: if active_adapter in self.oft_r.keys(): base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype if safe_merge: # Note that safe_merge will be slower than the normal merge # because of the copy operation. 
@@ -590,14 +594,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0] ) orig_weights = torch.transpose(orig_weights, 0, 1) - orig_weights = torch.mm(oft_mat, orig_weights) + orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype)) orig_weights = torch.transpose(orig_weights, 0, 1) orig_weights = orig_weights * oft_s orig_weights = orig_weights.view( self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0] ) - base_layer.weight.data = orig_weights.contiguous() + base_layer.weight.data = orig_weights.contiguous().to(orig_dtype) else: oft_mat, oft_s = self.get_delta_weight(active_adapter) @@ -606,14 +610,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0] ) orig_weights = torch.transpose(orig_weights, 0, 1) - orig_weights = torch.mm(oft_mat, orig_weights) + orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype)) orig_weights = torch.transpose(orig_weights, 0, 1) orig_weights = orig_weights * oft_s orig_weights = orig_weights.view( self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0] ) - base_layer.weight.data = orig_weights.contiguous() + base_layer.weight.data = orig_weights.contiguous().to(orig_dtype) self.merged_adapters.append(active_adapter) @@ -624,6 +628,9 @@ def unmerge(self) -> None: if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return + + base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.oft_r.keys(): @@ -635,7 +642,7 @@ def unmerge(self) -> None: self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0], ) orig_weights = torch.transpose(orig_weights, 0, 1) - orig_weights = torch.mm(oft_mat.t(), orig_weights) + orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype)) orig_weights = torch.transpose(orig_weights, 0, 1) orig_weights = orig_weights * (1 / oft_s) orig_weights = orig_weights.view( @@ -645,7 +652,7 @@ def unmerge(self) -> None: self.get_base_layer().kernel_size[0], ) - self.get_base_layer().weight.data = orig_weights + base_layer.weight.data = orig_weights.to(orig_dtype) def get_delta_weight(self, adapter_name) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -684,9 +691,8 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: oft_rotation = torch.eye( self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0], device=x.device, - dtype=previous_dtype, ) - oft_scale = torch.ones((int(self.out_features), 1), device=x.device, dtype=previous_dtype) + oft_scale = torch.ones((int(self.out_features), 1), device=x.device) for active_adapter in self.active_adapters: if active_adapter not in self.oft_r.keys(): @@ -710,8 +716,6 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: oft_rotation = oft_mat @ oft_rotation oft_scale = oft_s * oft_scale - x = x.to(self.get_base_layer().weight.data.dtype) - orig_weights = self.base_layer.weight.data orig_weights = orig_weights.view( self.out_features, @@ -731,10 +735,12 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: self.get_base_layer().kernel_size[0], 
self.get_base_layer().kernel_size[0], ) + x = self._cast_input_dtype(x, scaled_rotated_weight.dtype) + bias = self._cast_input_dtype(self.get_base_layer().bias, scaled_rotated_weight.dtype) result = F.conv2d( input=x, weight=scaled_rotated_weight, - bias=self.get_base_layer().bias, + bias=bias, padding=self.get_base_layer().padding[0], stride=self.get_base_layer().stride[0], ) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 770af21319..621eb2fc94 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -20,7 +20,7 @@ import warnings from abc import ABC, abstractmethod from contextlib import contextmanager, nullcontext -from typing import Any, Optional, Union +from typing import Any, Optional, Union, overload import torch from accelerate.hooks import AlignDevicesHook @@ -571,7 +571,7 @@ def inject_adapter( else: model.modules_to_save.update(set(peft_config.modules_to_save)) - def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None: + def merge_adapter(self, adapter_names: Optional[list[str]] = None, safe_merge: bool = False) -> None: """ This method merges the adapter layers into the base model. @@ -580,19 +580,25 @@ def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None: in memory, please call `merge_and_unload`. Args: + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If `None`, all active adapters will be merged. + Defaults to `None`. safe_merge (`bool`, *optional*): If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs before merging the weights. This is useful if you want to check if the merge operation will produce NaNs. Defaults to `False`. - adapter_names (`list[str]`, *optional*): - The list of adapter names that should be merged. If `None`, all active adapters will be merged. - Defaults to `None`. """ + # Note: The order of arguments here is: + # adapter_names, safe_merge + # For layer.merge, the order is: + # safe_merge, adapter_names + # This is not so nice but this method here started with only adapter_names, thus putting safe_merge first would + # be a backwards incompatible change. self._check_merge_allowed() for module in self.model.modules(): if isinstance(module, BaseTunerLayer): with onload_layer(module): - module.merge(adapter_names=adapter_names) + module.merge(adapter_names=adapter_names, safe_merge=safe_merge) def unmerge_adapter(self): """ @@ -871,6 +877,29 @@ def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optio else: adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device) + @overload + def _cast_input_dtype(self, x: None, dtype: torch.dtype) -> None: ... + + @overload + def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: ... + + def _cast_input_dtype(self, x, dtype: torch.dtype): + """ + Whether to cast the dtype of the input of the forward method. + + Usually, we want to enable this to align the input dtype with the dtype of the weight, but by setting + layer.cast_input_dtype=False, this can be disabled if necessary. + + Enabling or disabling can be managed via the peft.helpers.disable_lora_input_dtype_casting context manager. + """ + if x is None: # useful e.g. 
if x is the bias, which can be None + return None + + cast_input_dtype_enabled = getattr(self, "cast_input_dtype_enabled", True) + if (not cast_input_dtype_enabled) or (x.dtype == dtype): + return x + return x.to(dtype=dtype) + def _find_minimal_target_modules( target_modules: list[str] | set[str], other_module_names: list[str] | set[str] diff --git a/src/peft/tuners/vera/layer.py b/src/peft/tuners/vera/layer.py index 25789e8f59..ce97bf29c4 100644 --- a/src/peft/tuners/vera/layer.py +++ b/src/peft/tuners/vera/layer.py @@ -248,11 +248,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor: if cast_to_fp32: output_tensor = output_tensor.to(dtype=dtype) - # cast back the weights - # TODO: why? - self.vera_lambda_d[adapter].data = lambda_d.to(dtype) - self.vera_lambda_b[adapter].data = lambda_b.to(dtype) - return output_tensor def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: diff --git a/tests/test_config.py b/tests/test_config.py index 8bdf157f5a..179496b6f3 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -46,7 +46,7 @@ ) -PEFT_MODELS_TO_TEST = [("lewtun/tiny-random-OPTForCausalLM-delta", "v1")] +PEFT_MODELS_TO_TEST = [("peft-internal-testing/tiny-opt-lora-revision", "test")] # Config classes and their mandatory parameters ALL_CONFIG_CLASSES = ( diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 69358d018d..06ef6de608 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -16,6 +16,7 @@ # limitations under the License. import copy import os +import platform import re import shutil import tempfile @@ -688,9 +689,10 @@ def __init__(self, bias=True): self.drop = nn.Dropout(0.5) self.lin1 = nn.Linear(20, 2, bias=bias) self.sm = nn.LogSoftmax(dim=-1) + self.dtype = torch.float def forward(self, X): - X = X.float() + X = X.to(self.dtype) X = self.lin0(X) X = self.relu(X) X = self.drop(X) @@ -708,9 +710,10 @@ def __init__(self, bias=True): self.gru = nn.GRU(input_size=20, hidden_size=20, num_layers=1, batch_first=True, bias=bias) self.fc = nn.Linear(20, 2, bias=bias) self.sm = nn.LogSoftmax(dim=-1) + self.dtype = torch.float def forward(self, X): - X = X.float() + X = X.to(self.dtype) X = self.lin0(X) X = self.relu(X) X = self.drop(X) @@ -732,9 +735,10 @@ def __init__(self, bias=True): self.layernorm1 = nn.LayerNorm(20, 20) self.lin1 = nn.Linear(20, 2, bias=bias) self.sm = nn.LogSoftmax(dim=-1) + self.dtype = torch.float def forward(self, X): - X = X.float() + X = X.to(self.dtype) X = self.layernorm0(X) X = self.lin0(X) X = self.relu(X) @@ -753,9 +757,10 @@ def __init__(self, bias=True): self.drop = nn.Dropout(0.5) self.lin1 = nn.Linear(32, 2, bias=bias) self.sm = nn.LogSoftmax(dim=-1) + self.dtype = torch.float def forward(self, X): - X = X.float() + X = X.to(self.dtype) X = self.lin0(X) X = self.relu(X) X = self.drop(X) @@ -852,9 +857,11 @@ def __init__(self): self.flat = nn.Flatten() self.lin0 = nn.Linear(9, 2) self.sm = nn.LogSoftmax(dim=-1) + self.dtype = torch.float def forward(self, X): - X = X.float().reshape(-1, 1, 10) + X = X.to(self.dtype) + X = X.reshape(-1, 1, 10) X = self.conv1d(X) X = self.relu(X) X = self.flat(X) @@ -871,9 +878,11 @@ def __init__(self): self.flat = nn.Flatten() self.lin0 = nn.Linear(10, 2) self.sm = nn.LogSoftmax(dim=-1) + self.dtype = torch.float def forward(self, X): - X = X.float().reshape(-1, 5, 3, 3) + X = X.to(self.dtype) + X = X.reshape(-1, 5, 3, 3) X = self.conv2d(X) X = self.relu(X) X = self.flat(X) @@ -891,9 +900,10 @@ def __init__(self): self.flat = nn.Flatten() 
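For orientation, here is a minimal sketch (not part of the patch) of how the input dtype casting wired into `BaseTunerLayer._cast_input_dtype` above can be toggled from user code via the `peft.helpers.disable_input_dtype_casting` context manager; `peft_model` and `inputs` are placeholder names for an already created PeftModel and its usual forward arguments:

```python
from peft.helpers import disable_input_dtype_casting

# `peft_model` is assumed to be an existing PeftModel whose base weights are kept in
# half precision; `inputs` stands for its usual forward kwargs (both are placeholders).
with disable_input_dtype_casting(peft_model):
    # Inside the context, `cast_input_dtype_enabled` is set to False on every BaseTunerLayer,
    # so `_cast_input_dtype` returns the input unchanged instead of casting it to the weight dtype.
    output = peft_model(**inputs)
# After the context exits, the previous setting (True by default) is restored.
```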
self.lin1 = nn.Linear(32, 2) self.sm = nn.LogSoftmax(dim=-1) + self.dtype = torch.float def forward(self, X): - X = X.float() + X = X.to(self.dtype) X = self.lin0(X) X = self.relu(X) X = X.reshape(-1, 8, 3, 3) @@ -913,9 +923,11 @@ def __init__(self): self.flat = nn.Flatten() self.lin0 = nn.Linear(5, 2) self.sm = nn.LogSoftmax(dim=-1) + self.dtype = torch.float def forward(self, X): - X = X.float().reshape(-1, 5, 3, 3) + X = X.to(self.dtype) + X = X.reshape(-1, 5, 3, 3) X = self.conv2d(X) X = self.relu(X) X = self.flat(X) @@ -932,12 +944,14 @@ def __init__(self): self.flat = nn.Flatten() self.lin0 = nn.Linear(10, 2) self.sm = nn.LogSoftmax(dim=-1) + self.dtype = torch.float def forward(self, X): + X = X.to(self.dtype) # If necessary, convert from 2D image to 3D volume if X.dim() == 2: X = torch.stack([X] * 3, dim=-1) - X = X.float().reshape(-1, 5, 3, 3, 3) + X = X.reshape(-1, 5, 3, 3, 3) X = self.conv3d(X) X = self.relu(X) X = self.flat(X) @@ -952,9 +966,10 @@ def __init__(self): self.mha = nn.MultiheadAttention(10, 2) self.lin0 = nn.Linear(10, 2) self.sm = nn.LogSoftmax(dim=-1) + self.dtype = torch.float def forward(self, X): - X = X.float() + X = X.to(self.dtype) X, _ = self.mha(X, X, X) X = self.lin0(X) X = self.sm(X) @@ -1184,6 +1199,172 @@ def test_forward_output_finite(self, test_name, model_id, config_cls, config_kwa output = model(**X) assert torch.isfinite(output).all() + @parameterized.expand(TEST_CASES) + def test_forward_float16(self, test_name, model_id, config_cls, config_kwargs): + # The user manually sets the dtype of the base model to fp16 precision. This should not cause an error for the + # different PEFT methods. + try: + torch.zeros(1, dtype=torch.float16) + except Exception: + # skip this test if float16 is not supported on this machine + self.skipTest(reason="Test requires float16 support") + + # skip on MacOS + if platform.system() == "Darwin": + self.skipTest(reason="MacOS does not support multiple ops in float16") + + X = self.prepare_inputs_for_testing() + model = self.transformers_class.from_pretrained(model_id, torch_dtype=torch.float16).to(self.torch_device) + model.dtype = torch.float16 + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + model = get_peft_model(model, config) + model.eval() + + # check that none of this raises an error + model(**X) + + if model_id in ["Conv2dGroups"]: + # this model does not support merging + return + + model.merge_adapter(safe_merge=False) + model(**X) + model.unmerge_adapter() + model(**X) + model.merge_adapter(safe_merge=True) + model(**X) + model.unmerge_adapter() + model(**X) + model = model.merge_and_unload() + model(**X) + + @parameterized.expand(TEST_CASES) + def test_forward_bfloat16(self, test_name, model_id, config_cls, config_kwargs): + # The user manually sets the dtype of the base model to bf16 precision. This should not cause an error for the + # different PEFT methods. 
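The new forward tests above (and the bfloat16 variants that follow) all exercise the same user-facing pattern. As a rough sketch of that pattern, where the model id, the LoRA target modules, and `inputs` are placeholders rather than anything taken from this patch:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Placeholder base model and target modules; any model that loads in float16 would do.
# `inputs` stands for the tokenized forward kwargs of that model.
base = AutoModelForCausalLM.from_pretrained("some-base-model", torch_dtype=torch.float16)
config = LoraConfig(target_modules=["q_proj", "v_proj"])

# autocast_adapter_dtype=False keeps the adapter weights in float16 instead of promoting them to float32.
model = get_peft_model(base, config, autocast_adapter_dtype=False)

output = model(**inputs)              # plain half-precision forward
model.merge_adapter(safe_merge=True)  # new keyword, forwarded to layer.merge(safe_merge=...)
output = model(**inputs)
model.unmerge_adapter()               # base weights are restored in their original dtype
model = model.merge_and_unload()
```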
+        try:
+            torch.zeros(1, dtype=torch.bfloat16)
+        except Exception:
+            # skip this test if bfloat16 is not supported on this machine
+            self.skipTest(reason="Test requires bfloat16 support")
+
+        # skip on MacOS
+        if platform.system() == "Darwin":
+            self.skipTest(reason="MacOS does not support multiple ops in bfloat16")
+
+        X = self.prepare_inputs_for_testing()
+        model = self.transformers_class.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(self.torch_device)
+        model.dtype = torch.bfloat16
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        model = get_peft_model(model, config)
+        model.eval()
+
+        # check that none of this raises an error
+        model(**X)
+
+        if model_id in ["Conv2dGroups"]:
+            # this model does not support merging
+            return
+
+        model.merge_adapter(safe_merge=False)
+        model(**X)
+        model.unmerge_adapter()
+        model(**X)
+        model.merge_adapter(safe_merge=True)
+        model(**X)
+        model.unmerge_adapter()
+        model(**X)
+        model = model.merge_and_unload()
+        model(**X)
+
+    @parameterized.expand(TEST_CASES)
+    def test_forward_float16_no_autocast(self, test_name, model_id, config_cls, config_kwargs):
+        # Same as above, but do not autocast adapter weights to float32
+        try:
+            torch.zeros(1, dtype=torch.float16)
+        except Exception:
+            # skip this test if float16 is not supported on this machine
+            self.skipTest(reason="Test requires float16 support")
+
+        # skip on MacOS
+        if platform.system() == "Darwin":
+            self.skipTest(reason="MacOS does not support multiple ops in float16")
+
+        X = self.prepare_inputs_for_testing()
+        model = self.transformers_class.from_pretrained(model_id, torch_dtype=torch.float16).to(self.torch_device)
+        model.dtype = torch.float16
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        model = get_peft_model(model, config, autocast_adapter_dtype=False)
+        model.eval()
+
+        # check that none of this raises an error
+        model(**X)
+
+        if model_id in ["Conv2dGroups"]:
+            # this model does not support merging
+            return
+
+        model.merge_adapter(safe_merge=False)
+        model(**X)
+        model.unmerge_adapter()
+        model(**X)
+        model.merge_adapter(safe_merge=True)
+        model(**X)
+        model.unmerge_adapter()
+        model(**X)
+        model = model.merge_and_unload()
+        model(**X)
+
+    @parameterized.expand(TEST_CASES)
+    def test_forward_bfloat16_no_autocast(self, test_name, model_id, config_cls, config_kwargs):
+        # Same as above, but do not autocast adapter weights to float32
+        try:
+            torch.zeros(1, dtype=torch.bfloat16)
+        except Exception:
+            # skip this test if bfloat16 is not supported on this machine
+            self.skipTest(reason="Test requires bfloat16 support")
+
+        # skip on MacOS
+        if platform.system() == "Darwin":
+            self.skipTest(reason="MacOS does not support multiple ops in bfloat16")
+
+        X = self.prepare_inputs_for_testing()
+        model = self.transformers_class.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(self.torch_device)
+        model.dtype = torch.bfloat16
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        model = get_peft_model(model, config, autocast_adapter_dtype=False)
+        model.eval()
+
+        # check that none of this raises an error
+        model(**X)
+
+        if model_id in ["Conv2dGroups"]:
+            # this model does not support merging
+            return
+
+        model.merge_adapter(safe_merge=False)
+        model(**X)
+        model.unmerge_adapter()
+        model(**X)
+        model.merge_adapter(safe_merge=True)
+        model(**X)
+        model.unmerge_adapter()
+        model(**X)
+        model = model.merge_and_unload()
+        model(**X)
+
     @parameterized.expand(TEST_CASES)
def test_only_params_are_updated(self, test_name, model_id, config_cls, config_kwargs): # An explicit test that when using an adapter on a custom model, only the adapter parameters are updated during