Don't keep adapter weights in fp32 #3143
Draft
winglian wants to merge 5 commits into main from lora_bf16
Changes from 1 commit
@@ -178,6 +178,38 @@ def get_state_dict(self, model, unwrap=True):
         return state_dict


+def cast_lora_module(module):
+    base_layer_dtype = module.base_layer.weight.dtype
+    # Linear4Bit will keep its bias term in fp32. If the weight dtype is bf16 we are not able to
+    # wrap this. Therefore we must ensure the bias has the same dtype as the weight.
+    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
+        if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
+            log_bias_dtype_mismatch = True
+            module.base_layer.bias.data = module.base_layer.bias.data.to(
+                module.base_layer.weight.dtype
+            )
+
+    for active_adapter in module.active_adapters:
+        if module.lora_A:
+            module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_A[active_adapter], "bias") and module.lora_A[active_adapter].bias is not None:
+                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
+        if module.lora_B:
+            module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_B[active_adapter], "bias") and module.lora_B[active_adapter].bias is not None:
+                module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
+        if module.lora_embedding_A:
+            module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_embedding_A[active_adapter], "bias") and module.lora_embedding_A[active_adapter].bias is not None:
+                module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
+        if module.lora_embedding_B:
+            module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_embedding_B[active_adapter], "bias") and module.lora_embedding_B[active_adapter].bias is not None:
+                module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
+        if module.lora_magnitude_vector:
+            module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_magnitude_vector[active_adapter], "bias") and module.lora_magnitude_vector[active_adapter].bias is not None:
+                module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)
+
+
 def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
     """Helper function to process LoRA modules for FSDP2."""

@@ -324,10 +356,11 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     if auto_wrap_policy is not None:
         for module in get_module_children_bottom_up(model)[:-1]:
             if is_peft_model and isinstance(module, LoraLayer):
-                module_log_bias_mismatch = _process_lora_module_for_fsdp(
-                    module, fsdp2_kwargs
-                )
-                log_bias_dtype_mismatch |= module_log_bias_mismatch
+                cast_lora_module(module)
+                # module_log_bias_mismatch = _process_lora_module_for_fsdp(
+                #     module, fsdp2_kwargs
+                # )
+                # log_bias_dtype_mismatch |= module_log_bias_mismatch
             if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
                 fully_shard(module, **fsdp2_kwargs)

Review comment on the added `cast_lora_module(module)` call: We could put this behind a config field - the only downside is that upstreaming it is going to be a pain.
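A minimal sketch of the config gate that comment suggests, assuming a hypothetical boolean field (here called `cast_lora_adapters_to_base_dtype`) is threaded through to the FSDP2 preparation code; the helper name and plumbing are illustrative and not part of this PR:

```python
# Hypothetical config gate around the new cast path; the flag name and helper
# are illustrative, only cast_lora_module/_process_lora_module_for_fsdp come from the diff.
def _prepare_lora_module(module, fsdp2_kwargs, cast_lora_adapters_to_base_dtype: bool) -> bool:
    """Choose between the new cast-to-base-dtype path and the previous FSDP2
    wrapping path, returning the bias-dtype-mismatch flag for logging."""
    if cast_lora_adapters_to_base_dtype:
        cast_lora_module(module)  # new behavior introduced in this PR
        return False  # cast_lora_module currently reports no flag
    # previous behavior kept behind the flag, easing a later upstream submission
    return _process_lora_module_for_fsdp(module, fsdp2_kwargs)
```

With such a gate, the call site in `fsdp2_prepare_model` could use `log_bias_dtype_mismatch |= _prepare_lora_module(module, fsdp2_kwargs, flag)` instead of calling `cast_lora_module(module)` directly.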
🛠️ Refactor suggestion
Cast adapters to compute dtype (not quantized weight), avoid breaking ParameterDict, and fix F841
- `base_layer.weight.dtype` can incorrectly downcast to int/quantized dtypes (e.g., 4-bit), breaking training.
- `module.lora_magnitude_vector[adapter] = ... .to(...)` replaces a `Parameter` with a `Tensor` in a `ParameterDict`.
- `log_bias_dtype_mismatch` is set but unused (F841).

Proposed fix: derive a safe `target_dtype` from `compute_dtype` (if available) or a floating `weight.dtype`; cast modules in-place; cast `ParameterDict` entries via `.data`; return a flag so callers can log.
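A rough sketch of what that proposal could look like. It assumes the base layer may expose a `compute_dtype` attribute (as bitsandbytes `Linear4bit` does), and that some adapter containers (e.g. `lora_magnitude_vector` in certain PEFT versions) are `ParameterDict`s whose entries must be cast via `.data`; this is an illustration of the review comment, not the collapsed committable suggestion itself:

```python
import torch


def cast_lora_module(module) -> bool:
    """Cast LoRA adapter weights to a safe floating-point target dtype.

    Returns True if a base-layer bias dtype mismatch was found and fixed,
    so the caller can log it once. Sketch of the review proposal only.
    """
    weight = module.base_layer.weight
    # Prefer an explicit compute dtype (e.g. bitsandbytes Linear4bit); otherwise fall
    # back to the weight dtype only when it is floating point, never an int/quantized dtype.
    target_dtype = getattr(module.base_layer, "compute_dtype", None)
    if target_dtype is None and torch.is_floating_point(weight):
        target_dtype = weight.dtype
    if target_dtype is None:
        return False  # nothing safe to cast to

    bias_mismatch = False
    bias = getattr(module.base_layer, "bias", None)
    if bias is not None and bias.dtype != target_dtype:
        bias_mismatch = True
        bias.data = bias.data.to(target_dtype)

    for adapter in module.active_adapters:
        # Linear-style adapters are nn.Module containers: cast in place via Module.to().
        for attr in ("lora_A", "lora_B"):
            layers = getattr(module, attr, None)
            if layers and adapter in layers:
                layers[adapter].to(target_dtype)
        # Embedding/DoRA entries may live in ParameterDicts: cast via .data so the
        # entries stay nn.Parameter instead of being replaced by plain tensors.
        for attr in ("lora_embedding_A", "lora_embedding_B", "lora_magnitude_vector"):
            params = getattr(module, attr, None)
            if params and adapter in params:
                params[adapter].data = params[adapter].data.to(target_dtype)
    return bias_mismatch
```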
🧰 Tools
🪛 Ruff (0.12.2)
187-187: Local variable `log_bias_dtype_mismatch` is assigned to but never used. Remove assignment to unused variable `log_bias_dtype_mismatch` (F841)
🪛 GitHub Actions: lint
[error] 187-187: F841 Local variable `log_bias_dtype_mismatch` is assigned to but never used.
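One way to clear the F841 finding, assuming `cast_lora_module` is changed to return the mismatch flag as sketched above, is to aggregate the flag at the call site in `fsdp2_prepare_model` and log it once after the loop; this fragment reuses the names from the diff, and the `logger.warning` call and message are illustrative:

```python
# Call-site fragment: consume the returned flag instead of discarding it (fixes F841).
log_bias_dtype_mismatch = False
for module in get_module_children_bottom_up(model)[:-1]:
    if is_peft_model and isinstance(module, LoraLayer):
        log_bias_dtype_mismatch |= cast_lora_module(module)
    if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
        fully_shard(module, **fsdp2_kwargs)

if log_bias_dtype_mismatch:
    logger.warning(
        "Some LoRA base layers had a bias dtype different from their weight dtype; "
        "biases were cast to match before FSDP2 wrapping."
    )
```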