
Conversation

@winglian winglian commented Sep 9, 2025

Description

Motivation and Context

How has this been tested?

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features
    • Automatic dtype alignment for LoRA components during FSDP2 model preparation, minimizing manual adjustments.
  • Bug Fixes
    • Resolves mixed-precision mismatches between base weights/biases and LoRA adapters that could cause errors or instability.
    • Reduces noisy bias-related warnings by applying inline casting.
  • Refactor
    • Streamlined LoRA handling during FSDP2 preparation to enforce dtype consistency directly, improving reliability and setup smoothness.

coderabbitai bot commented Sep 9, 2025

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Adds cast_lora_module to enforce dtype consistency for LoRA components and updates fsdp2_prepare_model to use it instead of logging bias mismatches. The prior _process_lora_module_for_fsdp remains but is unused in this path. Changes are confined to FSDP2 LoRA handling in a single file.
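For intuition, the following is a minimal, self-contained sketch of the dtype-alignment idea; the toy module and names are illustrative only and are not the axolotl implementation.

import torch
import torch.nn as nn

# Toy stand-in for a PEFT LoRA layer: a bf16 base layer whose bias and adapters start in fp32.
class ToyLoraLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_layer = nn.Linear(8, 8, bias=True).to(torch.bfloat16)
        self.base_layer.bias.data = self.base_layer.bias.data.to(torch.float32)  # mimic an fp32 bias
        self.lora_A = nn.Linear(8, 4, bias=False)  # fp32 by default
        self.lora_B = nn.Linear(4, 8, bias=False)  # fp32 by default

module = ToyLoraLinear()
target_dtype = module.base_layer.weight.dtype

# Align adapters and bias with the base layer so FSDP2 wraps a module with uniform dtypes.
module.lora_A.to(target_dtype)
module.lora_B.to(target_dtype)
module.base_layer.bias.data = module.base_layer.bias.data.to(target_dtype)

assert all(p.dtype == target_dtype for p in module.parameters())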

Changes

  • FSDP2 LoRA dtype casting (src/axolotl/monkeypatch/accelerate/fsdp2.py): Added cast_lora_module to cast LoRA params/biases to the base layer dtype; modified fsdp2_prepare_model to call it for LoRA modules; commented out the previous log-based mismatch handling; retained _process_lora_module_for_fsdp, now unused in this flow.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Pre-merge checks (1 warning, 2 inconclusive)

❌ Failed checks (1 warning, 2 inconclusive)
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 75.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
  • Title Check — ❓ Inconclusive: The current title “yodont keep adpater weights in fp32” contains typographical errors and does not clearly summarize the main change (casting LoRA adapter weights to the base layer's dtype), so its intent is hard to grasp at a glance. Revise the title to a clear, single-sentence summary of the primary change, for example “Add cast_lora_module to cast LoRA adapter weights to base dtype.”
  • Description Check — ❓ Inconclusive: The pull request description contains only the default template placeholders, with no details of the changes, motivation, or testing. Populate the description with a concise summary of what was changed, why it is needed, and how it was tested so reviewers have the necessary context.
✨ Finishing touches
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch lora_bf16

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
src/axolotl/monkeypatch/accelerate/fsdp2.py (1)

359-363: Wire up logging and drop dead code

Connect the returned mismatch flag to the existing aggregator and remove commented-out code.

-                cast_lora_module(module)
-                # module_log_bias_mismatch = _process_lora_module_for_fsdp(
-                #     module, fsdp2_kwargs
-                # )
-                # log_bias_dtype_mismatch |= module_log_bias_mismatch
+                module_log_bias_mismatch = cast_lora_module(module)
+                log_bias_dtype_mismatch |= module_log_bias_mismatch
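For context, here is a sketch of how the returned flag could feed the existing aggregator and drive a single warning after the module loop; model, LoraLayer, and LOG are assumed names, not necessarily those used in fsdp2.py.

# Hypothetical caller-side pattern; names are illustrative.
log_bias_dtype_mismatch = False
for module in model.modules():
    if isinstance(module, LoraLayer):  # assumed PEFT LoRA layer type check
        log_bias_dtype_mismatch |= cast_lora_module(module)

if log_bias_dtype_mismatch:
    LOG.warning("Cast fp32 bias/adapter tensors to the base layer dtype before FSDP2 wrapping.")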
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9640338 and 6daed7d.

📒 Files selected for processing (1)
  • src/axolotl/monkeypatch/accelerate/fsdp2.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/monkeypatch/accelerate/fsdp2.py

187-187: Local variable log_bias_dtype_mismatch is assigned to but never used

Remove assignment to unused variable log_bias_dtype_mismatch

(F841)

🪛 GitHub Actions: lint
src/axolotl/monkeypatch/accelerate/fsdp2.py

[error] 187-187: F841 Local variable log_bias_dtype_mismatch is assigned to but never used.


[error] 178-192: ruff-format: File reformatted by formatter (1 file reformatted). Please review and commit changes.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.8.0)

Comment on lines 181 to 213
def cast_lora_module(module):
    base_layer_dtype = module.base_layer.weight.dtype
    # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
    # wrap this. Therefore we must ensure the bias has the same dtype as the weight
    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
        if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
            log_bias_dtype_mismatch = True
            module.base_layer.bias.data = module.base_layer.bias.data.to(
                module.base_layer.weight.dtype
            )

    for active_adapter in module.active_adapters:
        if module.lora_A:
            module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_B:
            module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
                module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_embedding_A:
            module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
                module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_embedding_B:
            module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
                module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_magnitude_vector:
            module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
                module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)

🛠️ Refactor suggestion

⚠️ Potential issue

Cast adapters to compute dtype (not quantized weight), avoid breaking ParameterDict, and fix F841

  • Risk: Casting LoRA modules to base_layer.weight.dtype can incorrectly downcast to int/quantized dtypes (e.g., 4-bit), breaking training.
  • Bug: Assigning module.lora_magnitude_vector[adapter] = ... .to(...) replaces a Parameter with a Tensor in a ParameterDict.
  • Lint: log_bias_dtype_mismatch is set but unused (F841).

Proposed fix: derive a safe target_dtype from compute_dtype (if available) or a floating weight.dtype; cast modules in-place; cast ParameterDict entries via .data; return a flag so callers can log.

-def cast_lora_module(module):
-    base_layer_dtype = module.base_layer.weight.dtype
-    # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
-    # wrap this. Therefore we must ensure the bias has the same dtype as the weight
-    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
-        if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
-            log_bias_dtype_mismatch = True
-            module.base_layer.bias.data = module.base_layer.bias.data.to(
-                module.base_layer.weight.dtype
-            )
-
-    for active_adapter in module.active_adapters:
-        if module.lora_A:
-            module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
-            if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
-                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
-        if module.lora_B:
-           module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
-           if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
-               module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
-        if module.lora_embedding_A:
-            module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
-            if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
-                module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
-        if module.lora_embedding_B:
-            module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
-            if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
-                module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
-        if module.lora_magnitude_vector:
-            module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
-            if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
-                module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)
+def cast_lora_module(module):
+    weight_dtype = module.base_layer.weight.dtype
+    # Linear4Bit keeps its bias term in fp32; if the weight dtype is bf16 we must match bias to weight.
+    # For adapter params, prefer the module's compute dtype when available; fall back to a floating weight dtype.
+    compute_dtype = getattr(module.base_layer, "compute_dtype", None)
+    float_dtypes = (torch.float16, torch.bfloat16, torch.float32)
+    target_dtype = (
+        compute_dtype
+        if compute_dtype in float_dtypes
+        else (weight_dtype if weight_dtype in float_dtypes else None)
+    )
+
+    bias_dtype_mismatch = False
+    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
+        if module.base_layer.bias.dtype != weight_dtype:
+            module.base_layer.bias.data = module.base_layer.bias.data.to(weight_dtype)
+            bias_dtype_mismatch = True
+
+    # If we don't have a safe floating target dtype (e.g., quantized weight with no compute dtype), stop here.
+    if target_dtype is None:
+        return bias_dtype_mismatch
+
+    for active_adapter in module.active_adapters:
+        if getattr(module, "lora_A", None) and active_adapter in module.lora_A:
+            module.lora_A[active_adapter].to(dtype=target_dtype)
+            if getattr(module.lora_A[active_adapter], "bias", None) is not None:
+                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(target_dtype)
+        if getattr(module, "lora_B", None) and active_adapter in module.lora_B:
+            module.lora_B[active_adapter].to(dtype=target_dtype)
+            if getattr(module.lora_B[active_adapter], "bias", None) is not None:
+                module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(target_dtype)
+        if getattr(module, "lora_embedding_A", None) and active_adapter in module.lora_embedding_A:
+            module.lora_embedding_A[active_adapter].to(dtype=target_dtype)
+            if getattr(module.lora_embedding_A[active_adapter], "bias", None) is not None:
+                module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(target_dtype)
+        if getattr(module, "lora_embedding_B", None) and active_adapter in module.lora_embedding_B:
+            module.lora_embedding_B[active_adapter].to(dtype=target_dtype)
+            if getattr(module.lora_embedding_B[active_adapter], "bias", None) is not None:
+                module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(target_dtype)
+        if getattr(module, "lora_magnitude_vector", None) and active_adapter in module.lora_magnitude_vector:
+            param = module.lora_magnitude_vector[active_adapter]
+            if isinstance(param, torch.nn.Parameter):
+                param.data = param.data.to(target_dtype)
+            else:
+                module.lora_magnitude_vector[active_adapter] = torch.nn.Parameter(
+                    param.to(target_dtype), requires_grad=getattr(param, "requires_grad", True)
+                )
+    return bias_dtype_mismatch
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/accelerate/fsdp2.py around lines 181-213, the function casts LoRA submodules to module.base_layer.weight.dtype, which may be a quantized/int dtype and would break training; it also mutates ParameterDict entries by replacing Parameters with Tensors, and it sets log_bias_dtype_mismatch without ever using it. Change it to compute a safe target_dtype: use module.compute_dtype if present, otherwise pick a floating-point version of base_layer.weight.dtype (e.g., promote int/quantized to float32/float16). Perform in-place casting (.data = .data.to(target_dtype)) on any Parameter or ParameterDict entry to avoid replacing Parameter objects, cast biases the same way via .data, remove the unused log_bias_dtype_mismatch variable, and instead return a boolean flag (e.g., cast_performed) so callers can log when a change occurred.
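As a standalone illustration of the in-place casting pattern the prompt describes (plain PyTorch, not repository code): assigning to .data keeps the same Parameter object registered, while .to() returns a new tensor and leaves the original untouched.

import torch
import torch.nn as nn

p = nn.Parameter(torch.zeros(4, dtype=torch.float32))

q = p.to(torch.bfloat16)            # returns a new tensor; p itself is unchanged
assert p.dtype == torch.float32 and q.dtype == torch.bfloat16

p.data = p.data.to(torch.bfloat16)  # rebinds storage on the same Parameter object
assert p.dtype == torch.bfloat16 and isinstance(p, nn.Parameter)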

codecov bot commented Sep 9, 2025

Codecov Report

❌ Patch coverage is 0% with 30 lines in your changes missing coverage. Please review.

  • src/axolotl/monkeypatch/accelerate/fsdp2.py — patch coverage 0.00%, 30 lines missing ⚠️

📢 Thoughts on this report? Let us know!

                    module, fsdp2_kwargs
                )
                log_bias_dtype_mismatch |= module_log_bias_mismatch
                cast_lora_module(module)

We could put this behind a config field - the only downside is that upstreaming it is going to be a pain
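A rough sketch of what such a gate could look like; the field name below is invented for illustration and is not an existing axolotl config option.

# Hypothetical config gate; "fsdp2_cast_lora_to_base_dtype" is an invented field name.
if getattr(cfg, "fsdp2_cast_lora_to_base_dtype", True):
    cast_lora_module(module)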

@SalmanMohammadi SalmanMohammadi marked this pull request as draft September 12, 2025 15:30