-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Fix: Multiple PEFT methods have issues with models loaded in float16 or bfloat16 #2433
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
BenjaminBossan
merged 17 commits into
huggingface:main
from
BenjaminBossan:fix-multiple-methods-model-dtype-issues
Apr 4, 2025
Merged
Changes from 9 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
7ba3aee
FIX Multiple issues with low precision base models
BenjaminBossan 7ad07be
Small fixes to OFT, LoRA MHA
BenjaminBossan 715f1a4
Fix bug in VeRA
BenjaminBossan da7123d
Skip float16/bfloat16 tests on MacOS
BenjaminBossan b55600f
Fix VeRA differently
BenjaminBossan 963d4dd
Small fix for bias in bone
BenjaminBossan 4ad67ff
Merge branch 'main' into fix-multiple-methods-model-dtype-issues
BenjaminBossan d3f916b
Merge branch 'main' into fix-multiple-methods-model-dtype-issues
BenjaminBossan 3484d4c
Fix tests to account for conv+groups model
BenjaminBossan 5345178
Document dtype handling in docs
BenjaminBossan 3732a54
Merge branch 'main' into fix-multiple-methods-model-dtype-issues
BenjaminBossan 48cff98
Fix errors introduced by merging
BenjaminBossan f5cfbf7
Also test safe merging
BenjaminBossan 51c39bb
Fix peft config tests
BenjaminBossan 075a18a
Merge branch 'main' into fix-multiple-methods-model-dtype-issues
BenjaminBossan b65f266
Swap order of arguments for BC
BenjaminBossan df4e086
Reviewer feedback: reuse variable
BenjaminBossan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,6 +35,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: | |
| # Mark the weight as unmerged | ||
| self._disable_adapters = False | ||
| self.merged_adapters = [] | ||
| # flag to enable/disable casting of input to weight dtype during forward call | ||
| self.cast_input_dtype_enabled = True | ||
| self.kwargs = kwargs | ||
|
|
||
| base_layer = self.get_base_layer() | ||
|
|
@@ -150,6 +152,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N | |
| for active_adapter in adapter_names: | ||
| if active_adapter in self.bone_block.keys(): | ||
| base_layer = self.get_base_layer() | ||
| orig_dtype = base_layer.weight.dtype | ||
| if safe_merge: | ||
| # Note that safe_merge will be slower than the normal merge | ||
| # because of the copy operation. | ||
|
|
@@ -166,14 +169,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N | |
| f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" | ||
| ) | ||
|
|
||
| self.base_layer.weight.data = orig_weight | ||
| self.base_layer.weight.data = orig_weight.to(orig_dtype) | ||
| else: | ||
| if self.bone_fn == "bat": | ||
| delta_weight = self.get_delta_weight(active_adapter, self.base_layer.weight.data) | ||
| self.base_layer.weight.data += delta_weight | ||
| self.base_layer.weight.data += delta_weight.to(orig_dtype) | ||
| else: | ||
| delta_weight = self.get_delta_weight_bone(active_adapter, self.base_layer.weight.data) | ||
| self.base_layer.weight.data = delta_weight | ||
| self.base_layer.weight.data = delta_weight.to(orig_dtype) | ||
|
||
| self.merged_adapters.append(active_adapter) | ||
|
|
||
| def unmerge(self) -> None: | ||
|
|
@@ -183,16 +186,19 @@ def unmerge(self) -> None: | |
| if not self.merged: | ||
| warnings.warn("Already unmerged. Nothing to do.") | ||
| return | ||
|
|
||
| while len(self.merged_adapters) > 0: | ||
| active_adapter = self.merged_adapters.pop() | ||
| base_layer = self.get_base_layer() | ||
| orig_dtype = base_layer.weight.dtype | ||
| if active_adapter in self.bone_block.keys(): | ||
| orig_weight = self.get_base_layer().weight.data.clone() | ||
| if self.bone_fn == "bat": | ||
| delta_weight = self.get_delta_weight(active_adapter, orig_weight, re=True) | ||
| else: | ||
| delta_weight = self.get_delta_weight_bone(active_adapter, orig_weight, re=True) | ||
|
|
||
| self.get_base_layer().weight.data = delta_weight | ||
| base_layer.weight.data = delta_weight.to(orig_dtype) | ||
|
|
||
| def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor: | ||
| """ | ||
|
|
@@ -213,12 +219,15 @@ def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tens | |
|
|
||
| if cast_to_fp32: | ||
| weight_bone = weight_bone.float() | ||
| orig_weight = orig_weight.to(weight_bone.dtype) | ||
|
|
||
| r = weight_bone.size(-1) | ||
| if re: | ||
| o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3) | ||
| one = torch.eye(weight_bone.size(-1)).to(weight_bone.device) | ||
| # inverse must be in float32, after that the dtype can be adjusted if needed | ||
| inv_I_plus_b = torch.inverse(one + weight_bone) | ||
| inv_I_plus_b = inv_I_plus_b.to(weight_bone.dtype) | ||
| w = (o - weight_bone) @ inv_I_plus_b | ||
| output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape) | ||
| else: | ||
|
|
@@ -318,7 +327,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: | |
| delta_weight = self.get_delta_weight(active_adapter, orig_weight) | ||
| orig_weight = orig_weight + delta_weight | ||
|
|
||
| result = F.linear(input=x, weight=orig_weight, bias=self.base_layer.bias) | ||
| x = self._cast_input_dtype(x, orig_weight.dtype) | ||
| bias = self._cast_input_dtype(self.base_layer.bias, orig_weight.dtype) | ||
| result = F.linear(input=x, weight=orig_weight, bias=bias) | ||
| else: | ||
| result = self.base_layer(x, *args, **kwargs) | ||
| for active_adapter in self.active_adapters: | ||
|
|
@@ -329,6 +340,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: | |
| if x.size(-1) % r != 0: | ||
| padding_size = (r - x.size(-1) % r) % r | ||
| x = F.pad(x, (0, padding_size)) | ||
| x = self._cast_input_dtype(x, bone.dtype) | ||
| result = result + torch.sum(x.reshape(*x.shape[:-1], x.size(-1) // r, r), dim=-2) @ bone | ||
|
|
||
| result = result.to(previous_dtype) | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.