Don't keep adapter weights in fp32 #3143
base: main
Important: Review skipped. Auto incremental reviews are disabled on this repository.
📝 Walkthrough
Adds cast_lora_module to enforce dtype consistency for LoRA components and updates fsdp2_prepare_model to use it instead of logging bias mismatches. The prior _process_lora_module_for_fsdp remains but is unused in this path. Changes are confined to FSDP2 LoRA handling in a single file.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks: ❌ failed (1 warning, 2 inconclusive)
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/monkeypatch/accelerate/fsdp2.py (1)
359-363: Wire up logging and drop dead code
Connect the returned mismatch flag to the existing aggregator and remove commented-out code.
- cast_lora_module(module)
- # module_log_bias_mismatch = _process_lora_module_for_fsdp(
- #     module, fsdp2_kwargs
- # )
- # log_bias_dtype_mismatch |= module_log_bias_mismatch
+ module_log_bias_mismatch = cast_lora_module(module)
+ log_bias_dtype_mismatch |= module_log_bias_mismatch
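The wiring suggested in this nitpick can be sketched in isolation. This is a minimal, hypothetical example: cast_module stands in for cast_lora_module, plain dictionaries stand in for LoRA modules, and the logger name is invented. It only illustrates OR-ing per-module flags into one aggregate and logging once at the end.

```python
import logging

LOG = logging.getLogger("fsdp2_sketch")  # hypothetical logger name


def cast_module(module: dict) -> bool:
    """Stand-in for cast_lora_module: align the bias dtype with the weight dtype.

    Returns True when a mismatch was found and corrected.
    """
    if module.get("bias_dtype") is None:
        return False
    if module["bias_dtype"] != module["weight_dtype"]:
        module["bias_dtype"] = module["weight_dtype"]
        return True
    return False


def prepare(modules: list) -> bool:
    """Cast every module and warn once if any bias dtype had to change."""
    log_bias_dtype_mismatch = False
    for module in modules:
        # Aggregate the per-module flag instead of discarding it.
        log_bias_dtype_mismatch |= cast_module(module)
    if log_bias_dtype_mismatch:
        LOG.warning("Bias dtype mismatch found; biases were cast to the weight dtype.")
    return log_bias_dtype_mismatch
```

Returning the flag from the cast helper keeps the warning in one place, so it is emitted once per model rather than once per module.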
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/axolotl/monkeypatch/accelerate/fsdp2.py(2 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/monkeypatch/accelerate/fsdp2.py
187-187: Local variable log_bias_dtype_mismatch is assigned to but never used
Remove assignment to unused variable log_bias_dtype_mismatch
(F841)
🪛 GitHub Actions: lint
src/axolotl/monkeypatch/accelerate/fsdp2.py
[error] 187-187: F841 Local variable log_bias_dtype_mismatch is assigned to but never used.
[error] 178-192: ruff-format: File reformatted by formatter (1 file reformatted). Please review and commit changes.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.8.0)
def cast_lora_module(module):
    base_layer_dtype = module.base_layer.weight.dtype
    # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
    # wrap this. Therefore we must ensure the bias has the same dtype as the weight
    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
        if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
            log_bias_dtype_mismatch = True
            module.base_layer.bias.data = module.base_layer.bias.data.to(
                module.base_layer.weight.dtype
            )

    for active_adapter in module.active_adapters:
        if module.lora_A:
            module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_B:
            module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
                module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_embedding_A:
            module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
                module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_embedding_B:
            module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
                module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_magnitude_vector:
            module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
                module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)
🛠️ Refactor suggestion
Cast adapters to compute dtype (not quantized weight), avoid breaking ParameterDict, and fix F841
- Risk: Casting LoRA modules to base_layer.weight.dtype can incorrectly downcast to int/quantized dtypes (e.g., 4-bit), breaking training.
- Bug: Assigning module.lora_magnitude_vector[adapter] = ... .to(...) replaces a Parameter with a Tensor in a ParameterDict.
- Lint: log_bias_dtype_mismatch is set but unused (F841).
Proposed fix: derive a safe target_dtype from compute_dtype (if available) or a floating weight.dtype; cast modules in-place; cast ParameterDict entries via .data; return a flag so callers can log.
-def cast_lora_module(module):
-    base_layer_dtype = module.base_layer.weight.dtype
-    # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
-    # wrap this. Therefore we must ensure the bias has the same dtype as the weight
-    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
-        if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
-            log_bias_dtype_mismatch = True
-            module.base_layer.bias.data = module.base_layer.bias.data.to(
-                module.base_layer.weight.dtype
-            )
-
-    for active_adapter in module.active_adapters:
-        if module.lora_A:
-            module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
-            if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
-                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
-        if module.lora_B:
-           module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
-           if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
-               module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
-        if module.lora_embedding_A:
-            module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
-            if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
-                module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
-        if module.lora_embedding_B:
-            module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
-            if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
-                module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
-        if module.lora_magnitude_vector:
-            module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
-            if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
-                module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)
+def cast_lora_module(module):
+    weight_dtype = module.base_layer.weight.dtype
+    # Linear4Bit keeps its bias term in fp32; if the weight dtype is bf16 we must match bias to weight.
+    # For adapter params, prefer the module's compute dtype when available; fall back to a floating weight dtype.
+    compute_dtype = getattr(module.base_layer, "compute_dtype", None)
+    float_dtypes = (torch.float16, torch.bfloat16, torch.float32)
+    target_dtype = (
+        compute_dtype
+        if compute_dtype in float_dtypes
+        else (weight_dtype if weight_dtype in float_dtypes else None)
+    )
+
+    bias_dtype_mismatch = False
+    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
+        if module.base_layer.bias.dtype != weight_dtype:
+            module.base_layer.bias.data = module.base_layer.bias.data.to(weight_dtype)
+            bias_dtype_mismatch = True
+
+    # If we don't have a safe floating target dtype (e.g., quantized weight with no compute dtype), stop here.
+    if target_dtype is None:
+        return bias_dtype_mismatch
+
+    for active_adapter in module.active_adapters:
+        if getattr(module, "lora_A", None) and active_adapter in module.lora_A:
+            module.lora_A[active_adapter].to(dtype=target_dtype)
+            if getattr(module.lora_A[active_adapter], "bias", None) is not None:
+                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(target_dtype)
+        if getattr(module, "lora_B", None) and active_adapter in module.lora_B:
+            module.lora_B[active_adapter].to(dtype=target_dtype)
+            if getattr(module.lora_B[active_adapter], "bias", None) is not None:
+                module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(target_dtype)
+        if getattr(module, "lora_embedding_A", None) and active_adapter in module.lora_embedding_A:
+            module.lora_embedding_A[active_adapter].to(dtype=target_dtype)
+            if getattr(module.lora_embedding_A[active_adapter], "bias", None) is not None:
+                module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(target_dtype)
+        if getattr(module, "lora_embedding_B", None) and active_adapter in module.lora_embedding_B:
+            module.lora_embedding_B[active_adapter].to(dtype=target_dtype)
+            if getattr(module.lora_embedding_B[active_adapter], "bias", None) is not None:
+                module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(target_dtype)
+        if getattr(module, "lora_magnitude_vector", None) and active_adapter in module.lora_magnitude_vector:
+            param = module.lora_magnitude_vector[active_adapter]
+            if isinstance(param, torch.nn.Parameter):
+                param.data = param.data.to(target_dtype)
+            else:
+                module.lora_magnitude_vector[active_adapter] = torch.nn.Parameter(
+                    param.to(target_dtype), requires_grad=getattr(param, "requires_grad", True)
+                )
+    return bias_dtype_mismatch
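The fallback used to choose target_dtype in the suggestion above can be isolated as a small helper. The sketch below is an illustration only: pick_target_dtype is a hypothetical name, and string stand-ins replace the torch dtypes so the logic can be run without PyTorch installed.

```python
# String stand-ins for torch.float16, torch.bfloat16 and torch.float32 (assumption for this sketch).
FLOAT_DTYPES = ("float16", "bfloat16", "float32")


def pick_target_dtype(compute_dtype, weight_dtype, float_dtypes=FLOAT_DTYPES):
    """Return a floating dtype that is safe to cast adapter weights to, or None.

    Preference order mirrors the suggestion: the layer's compute dtype first,
    then the weight dtype if it is floating, otherwise no cast at all.
    """
    if compute_dtype in float_dtypes:
        return compute_dtype
    if weight_dtype in float_dtypes:
        return weight_dtype
    return None


# A quantized layer (packed integer storage) with a bf16 compute dtype.
print(pick_target_dtype("bfloat16", "uint8"))
# A plain half-precision layer with no compute_dtype attribute.
print(pick_target_dtype(None, "float16"))
# A quantized layer with no compute dtype: skip the adapter cast.
print(pick_target_dtype(None, "uint8"))
```

Returning None for the last case matches the early return in the suggested function, which leaves the adapters untouched rather than casting them to a non-floating dtype.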
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/accelerate/fsdp2.py around lines 181-213, the
function casts LoRA submodules to module.base_layer.weight.dtype which may be a
quantized/int dtype and breaks training, it mutates ParameterDict entries by
replacing Parameters with Tensors, and sets log_bias_dtype_mismatch but never
uses it; change to compute a safe target_dtype: use module.compute_dtype if
present else pick a floating dtype version of base_layer.weight.dtype (e.g.,
promote int/quantized to float32/float16), perform in-place casting (use .data =
.data.to(target_dtype)) for any Parameter or ParameterDict entry to avoid
replacing Parameter objects, cast biases similarly via .data, remove the unused
log_bias_dtype_mismatch variable and instead return a boolean flag (e.g.,
cast_performed) from the function so callers can log when a change occurred.
Codecov Report: ❌ patch coverage check failed.
    module, fsdp2_kwargs
)
log_bias_dtype_mismatch |= module_log_bias_mismatch
cast_lora_module(module)
We could put this behind a config field - the only downside is that upstreaming it is going to be a pain
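One way the config-gated variant could look is sketched here. Everything in it is an assumption for illustration: FsdpLoraConfig and its cast_lora_to_base_dtype field are hypothetical and not an existing axolotl option, and cast_fn stands in for cast_lora_module.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class FsdpLoraConfig:
    """Hypothetical config holder; the field name is invented for this sketch."""

    cast_lora_to_base_dtype: bool = False


def maybe_cast(module: Any, cfg: FsdpLoraConfig, cast_fn: Callable[[Any], Any]) -> bool:
    """Run the cast only when the config opts in.

    Returns True when the cast ran and reported a bias dtype mismatch.
    """
    if not cfg.cast_lora_to_base_dtype:
        return False
    return bool(cast_fn(module))
```

Defaulting the flag to off would keep existing behavior unchanged for users who do not set it, at the cost of one more option to carry when upstreaming.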