41 changes: 37 additions & 4 deletions src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -178,6 +178,38 @@ def get_state_dict(self, model, unwrap=True):

return state_dict

def cast_lora_module(module):
    base_layer_dtype = module.base_layer.weight.dtype
    # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
    # wrap this. Therefore we must ensure the bias has the same dtype as the weight
    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
        if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
            log_bias_dtype_mismatch = True
            module.base_layer.bias.data = module.base_layer.bias.data.to(
                module.base_layer.weight.dtype
            )

    for active_adapter in module.active_adapters:
        if module.lora_A:
            module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_B:
            module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
                module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_embedding_A:
            module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
                module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_embedding_B:
            module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
                module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_magnitude_vector:
            module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
                module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)

Contributor:

🛠️ Refactor suggestion

⚠️ Potential issue

Cast adapters to compute dtype (not quantized weight), avoid breaking ParameterDict, and fix F841

  • Risk: Casting LoRA modules to base_layer.weight.dtype can incorrectly downcast them to int/quantized dtypes (e.g., 4-bit), breaking training.
  • Bug: Assigning module.lora_magnitude_vector[adapter] = ... .to(...) replaces a registered Parameter with a plain Tensor in a ParameterDict (see the sketch after this list).
  • Lint: log_bias_dtype_mismatch is set but never used (F841).
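
To make the ParameterDict point concrete, here is a minimal, illustrative-only sketch (assuming a recent PyTorch; the names are placeholders, not taken from this diff). It shows that `.to(dtype)` on a Parameter returns a plain Tensor, so dict-style reassignment swaps in a new object, while casting through `.data` keeps the registered Parameter:

```python
import torch
from torch import nn

pd = nn.ParameterDict({"default": nn.Parameter(torch.zeros(4, dtype=torch.float32))})
original = pd["default"]

with torch.no_grad():
    converted = original.to(torch.bfloat16)  # conversion returns a plain Tensor, not an nn.Parameter
print(type(converted))                       # <class 'torch.Tensor'>

pd["default"] = converted                    # re-registers a *new* object under the key
print(pd["default"] is original)             # False: outside references (optimizer, FSDP) still see the old object

original.data = original.data.to(torch.bfloat16)  # in-place cast: same Parameter object, new dtype
print(original.dtype)                             # torch.bfloat16
```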

Proposed fix: derive a safe target_dtype from compute_dtype (if available) or a floating weight.dtype; cast modules in-place; cast ParameterDict entries via .data; return a flag so callers can log.

-def cast_lora_module(module):
-    base_layer_dtype = module.base_layer.weight.dtype
-    # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
-    # wrap this. Therefore we must ensure the bias has the same dtype as the weight
-    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
-        if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
-            log_bias_dtype_mismatch = True
-            module.base_layer.bias.data = module.base_layer.bias.data.to(
-                module.base_layer.weight.dtype
-            )
-
-    for active_adapter in module.active_adapters:
-        if module.lora_A:
-            module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
-            if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
-                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
-        if module.lora_B:
-           module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
-           if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
-               module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
-        if module.lora_embedding_A:
-            module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
-            if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
-                module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
-        if module.lora_embedding_B:
-            module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
-            if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
-                module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
-        if module.lora_magnitude_vector:
-            module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
-            if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
-                module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)
+def cast_lora_module(module):
+    weight_dtype = module.base_layer.weight.dtype
+    # Linear4Bit keeps its bias term in fp32; if the weight dtype is bf16 we must match bias to weight.
+    # For adapter params, prefer the module's compute dtype when available; fall back to a floating weight dtype.
+    compute_dtype = getattr(module.base_layer, "compute_dtype", None)
+    float_dtypes = (torch.float16, torch.bfloat16, torch.float32)
+    target_dtype = (
+        compute_dtype
+        if compute_dtype in float_dtypes
+        else (weight_dtype if weight_dtype in float_dtypes else None)
+    )
+
+    bias_dtype_mismatch = False
+    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
+        if module.base_layer.bias.dtype != weight_dtype:
+            module.base_layer.bias.data = module.base_layer.bias.data.to(weight_dtype)
+            bias_dtype_mismatch = True
+
+    # If we don't have a safe floating target dtype (e.g., quantized weight with no compute dtype), stop here.
+    if target_dtype is None:
+        return bias_dtype_mismatch
+
+    for active_adapter in module.active_adapters:
+        if getattr(module, "lora_A", None) and active_adapter in module.lora_A:
+            module.lora_A[active_adapter].to(dtype=target_dtype)
+            if getattr(module.lora_A[active_adapter], "bias", None) is not None:
+                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(target_dtype)
+        if getattr(module, "lora_B", None) and active_adapter in module.lora_B:
+            module.lora_B[active_adapter].to(dtype=target_dtype)
+            if getattr(module.lora_B[active_adapter], "bias", None) is not None:
+                module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(target_dtype)
+        if getattr(module, "lora_embedding_A", None) and active_adapter in module.lora_embedding_A:
+            module.lora_embedding_A[active_adapter].to(dtype=target_dtype)
+            if getattr(module.lora_embedding_A[active_adapter], "bias", None) is not None:
+                module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(target_dtype)
+        if getattr(module, "lora_embedding_B", None) and active_adapter in module.lora_embedding_B:
+            module.lora_embedding_B[active_adapter].to(dtype=target_dtype)
+            if getattr(module.lora_embedding_B[active_adapter], "bias", None) is not None:
+                module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(target_dtype)
+        if getattr(module, "lora_magnitude_vector", None) and active_adapter in module.lora_magnitude_vector:
+            param = module.lora_magnitude_vector[active_adapter]
+            if isinstance(param, torch.nn.Parameter):
+                param.data = param.data.to(target_dtype)
+            else:
+                module.lora_magnitude_vector[active_adapter] = torch.nn.Parameter(
+                    param.to(target_dtype), requires_grad=getattr(param, "requires_grad", True)
+                )
+    return bias_dtype_mismatch
🧰 Tools
🪛 Ruff (0.12.2)

187-187: Local variable log_bias_dtype_mismatch is assigned to but never used

Remove assignment to unused variable log_bias_dtype_mismatch

(F841)

🪛 GitHub Actions: lint

[error] 187-187: F841 Local variable log_bias_dtype_mismatch is assigned to but never used.

🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/accelerate/fsdp2.py around lines 181-213, the function casts LoRA submodules to module.base_layer.weight.dtype, which may be a quantized/int dtype and breaks training; it also mutates ParameterDict entries by replacing Parameters with Tensors, and sets log_bias_dtype_mismatch without ever using it. Compute a safe target_dtype instead: use module.compute_dtype if present, otherwise pick a floating-point version of base_layer.weight.dtype (e.g., promote int/quantized dtypes to float32/float16). Perform in-place casting (.data = .data.to(target_dtype)) for any Parameter or ParameterDict entry so Parameter objects are never replaced, and cast biases the same way via .data. Remove the unused log_bias_dtype_mismatch variable and instead return a boolean flag (e.g., cast_performed) from the function so callers can log when a change occurred.
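
A hedged, caller-side sketch of the "return a flag so callers can log" idea; `lora_modules` is a placeholder iterable and the wrapper function name is invented, mirroring the `log_bias_dtype_mismatch |= ...` pattern this diff removes:

```python
import logging

logger = logging.getLogger(__name__)

def cast_all_lora_modules(lora_modules):
    """Apply cast_lora_module (as proposed above, returning a bool) and log once if any bias was recast."""
    bias_dtype_mismatch = False
    for lora_module in lora_modules:  # placeholder: the model's LoraLayer children
        bias_dtype_mismatch |= cast_lora_module(lora_module)
    if bias_dtype_mismatch:
        logger.warning(
            "Cast at least one base-layer bias to its weight dtype so FSDP2 can wrap the module."
        )
```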

def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
"""Helper function to process LoRA modules for FSDP2."""
@@ -324,10 +356,11 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
    if auto_wrap_policy is not None:
        for module in get_module_children_bottom_up(model)[:-1]:
            if is_peft_model and isinstance(module, LoraLayer):
                module_log_bias_mismatch = _process_lora_module_for_fsdp(
                    module, fsdp2_kwargs
                )
                log_bias_dtype_mismatch |= module_log_bias_mismatch
                cast_lora_module(module)
Contributor:

We could put this behind a config field - the only downside is that upstreaming it is going to be a pain
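
A hedged sketch of what such a gate might look like; the field name fsdp2_cast_lora_to_base_dtype and the cfg object are hypothetical placeholders, not existing axolotl options:

```python
# Hypothetical gating sketch (field name invented for illustration; not an existing axolotl option).
def should_cast_lora(cfg) -> bool:
    # Defaulting to True preserves the behavior introduced by this PR when the field is absent.
    return bool(getattr(cfg, "fsdp2_cast_lora_to_base_dtype", True))

# At the call site inside fsdp2_prepare_model (shown as a comment, matching the diff context below):
#     if is_peft_model and isinstance(module, LoraLayer) and should_cast_lora(cfg):
#         cast_lora_module(module)
```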

                # module_log_bias_mismatch = _process_lora_module_for_fsdp(
                #     module, fsdp2_kwargs
                # )
                # log_bias_dtype_mismatch |= module_log_bias_mismatch
            if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
                fully_shard(module, **fsdp2_kwargs)
