Conversation

onesnep commented Sep 27, 2025

Description

The previous logic in resolve_dtype incorrectly configured the model's torch_dtype for Automatic Mixed Precision (AMP) training.

When fp16: true or bf16: true was set for mixed precision, the model was loaded directly into a half-precision format. This conflicts with the standard PyTorch AMP workflow, which expects the model to be loaded in FP32 to establish master weights before being managed by the autocast context and GradScaler.

This misconfiguration led to GradScaler failures with FP16 (expecting FP32 weights) and an inefficient, non-standard AMP implementation for BF16.

This commit adjusts the logic to prioritize the AMP flags: if fp16 or bf16 is enabled, torch_dtype is now resolved to torch.float32. The logic for the pure half-precision modes (float16, bfloat16) remains unchanged.
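
For illustration, the intended branching looks roughly like this (a minimal sketch, not the exact axolotl implementation; the cfg keys mirror the flags discussed above):

import torch

def resolve_dtype_sketch(cfg: dict) -> torch.dtype:
    # Mixed-precision flags take priority: keep the base model in FP32 so the
    # autocast context and GradScaler can manage compute dtype and master weights.
    if cfg.get("fp16") or cfg.get("bf16"):
        return torch.float32
    # "Pure" half-precision modes load the weights directly in half precision.
    if cfg.get("float16"):
        return torch.float16
    if cfg.get("bfloat16"):
        return torch.bfloat16
    return torch.float32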

Motivation and Context

My hardware (AMD MI100) has a 2x faster theoretical throughput for FP16 compared with BF16, so I was interested in trying FP16 mixed precision despite the reduction in stability.

I initially observed failures when attempting to use FP16 mixed-precision by toggling fp16: true. Doing so on an example config would result in an error like the following:

Traceback:
File "/nvme/training/axolotl/src/axolotl/cli/train.py", line 121, in <module>
    fire.Fire(do_cli)
  File "/nvme/training/axolotl/venv/lib/python3.12/site-packages/fire/core.py", line 135, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nvme/training/axolotl/venv/lib/python3.12/site-packages/fire/core.py", line 468, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/nvme/training/axolotl/venv/lib/python3.12/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/nvme/training/axolotl/src/axolotl/cli/train.py", line 88, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nvme/training/axolotl/src/axolotl/cli/train.py", line 45, in do_train
    model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nvme/training/axolotl/src/axolotl/train.py", line 583, in train
    execute_training(cfg, trainer, resume_from_checkpoint)
  File "/nvme/training/axolotl/src/axolotl/train.py", line 204, in execute_training
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/nvme/training/axolotl/venv/lib/python3.12/site-packages/transformers/trainer.py", line 2328, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/nvme/training/axolotl/venv/lib/python3.12/site-packages/transformers/trainer.py", line 2713, in _inner_training_loop
    _grad_norm = self.accelerator.clip_grad_norm_(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nvme/training/axolotl/venv/lib/python3.12/site-packages/accelerate/accelerator.py", line 2890, in clip_grad_norm_
    self.unscale_gradients()
  File "/nvme/training/axolotl/venv/lib/python3.12/site-packages/accelerate/accelerator.py", line 2828, in unscale_gradients
    self.scaler.unscale_(opt)
  File "/nvme/training/axolotl/venv/lib/python3.12/site-packages/torch/amp/grad_scaler.py", line 343, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
                                              ^^^^^^^^^^^^^^^^^^^^^
  File "/nvme/training/axolotl/venv/lib/python3.12/site-packages/torch/amp/grad_scaler.py", line 261, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")

This error was consistent across a number of different configuration changes (flash attention, xformers, gradient accumulation, sample packing, etc.).

BF16 AMP would run without error, as it avoided the gradient scaling pathway.

A simple torch reproducer with FP16 AMP worked without error.

Torch reproducer:
import torch
from accelerate import Accelerator

# 1. Dummy Model and Data
model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters())
dummy_input = torch.randn(4, 10).cuda()
dummy_labels = torch.randn(4, 10).cuda()

# 2. Setup the failing environment
# This is the key part - we are creating the Accelerator programmatically
accelerator = Accelerator(mixed_precision='fp16')
model, optimizer = accelerator.prepare(model, optimizer)

# 3. Run one training step
with accelerator.autocast():
    outputs = model(dummy_input)
    loss = torch.nn.functional.mse_loss(outputs, dummy_labels)

print(f"Loss: {loss.item()}")

# In the failing axolotl run, the error surfaced after the backward pass
accelerator.backward(loss)

# The clip_grad_norm_ call that triggers the error in the axolotl traceback,
# included here to exercise the same code path
if accelerator.sync_gradients:
    accelerator.clip_grad_norm_(model.parameters(), 1.0)

optimizer.step()
optimizer.zero_grad()

print("Step completed successfully.")

As did an HF Trainer reproducer, so I knew the issue was with axolotl.

HF Trainer reproducer:
import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from datasets import Dataset
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

MODEL_CONFIG_PATH = "../mew_init"
TOKENIZER_NAME = "NousResearch/Llama-2-7b-hf"

print("="*80)
print(f"Loading model config from: {MODEL_CONFIG_PATH}")
print("="*80)

config = AutoConfig.from_pretrained(MODEL_CONFIG_PATH)
model = AutoModelForCausalLM.from_config(config)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

NUM_SAMPLES = 100
SEQ_LENGTH = 512

# Create a list of random token sequences
input_ids_list = [torch.randint(0, config.vocab_size, (SEQ_LENGTH,)) for _ in range(NUM_SAMPLES)]

dummy_data = {
    "input_ids": input_ids_list,
    "labels": input_ids_list,
}

train_dataset = Dataset.from_dict(dummy_data)
print(f"Created a dummy dataset with {len(train_dataset)} samples.")
print("="*80)


training_args = TrainingArguments(
    fp16=True,
    output_dir="./reproducer_out",
    max_steps=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    logging_steps=1,
    optim="adamw_torch",
    report_to="none",
)

print("Instantiating Trainer with the following key arguments:")
print(f"  fp16={training_args.fp16}")
print(f"  per_device_train_batch_size={training_args.per_device_train_batch_size}")
print(f"  gradient_accumulation_steps={training_args.gradient_accumulation_steps}")
print(f"  optim={training_args.optim}")
print("="*80)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

try:
    print("Starting trainer.train()...")
    trainer.train()
    print("\n--- TEST SUCCEEDED (NO BUG) ---")

except ValueError as e:
    print("\n--- TEST FAILED (BUG REPRODUCED) ---")
    raise e
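
For contrast, a variant of the minimal torch reproducer that converts the model to FP16 before prepare, which is roughly what the old resolve_dtype produced for fp16: true, does hit the same error. This is a sketch; the .half() call stands in for loading the model with torch_dtype=torch.float16.

import torch
from accelerate import Accelerator

# Same toy setup, but the parameters start out in FP16, mirroring what the
# old dtype resolution produced for `fp16: true`.
model = torch.nn.Linear(10, 10).cuda().half()
optimizer = torch.optim.AdamW(model.parameters())
dummy_input = torch.randn(4, 10).cuda()
dummy_labels = torch.randn(4, 10).cuda()

accelerator = Accelerator(mixed_precision="fp16")
model, optimizer = accelerator.prepare(model, optimizer)

with accelerator.autocast():
    loss = torch.nn.functional.mse_loss(model(dummy_input), dummy_labels)

accelerator.backward(loss)

# GradScaler refuses to unscale FP16 gradients, so this raises
# ValueError: Attempting to unscale FP16 gradients.
if accelerator.sync_gradients:
    accelerator.clip_grad_norm_(model.parameters(), 1.0)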

How has this been tested?

This change was validated on an AMD MI100 GPU (ROCm backend), where the original code consistently failed with a ValueError: Attempting to unscale FP16 gradients.

After applying this patch, FP16 AMP training now runs successfully, with a significant performance improvement over the FP32 baseline, as measured with a custom nanoGPT-style config.

Flash Attention 2 (which only supports FP16/BF16) can be enabled and provides a further performance boost, implying that AMP is active and the attention passes run in lower precision as expected.

| Configuration | Throughput (tokens/sec) | Speedup vs. FP32 | VRAM (Max Allocated) |
| --- | --- | --- | --- |
| FP32 (Baseline) | 24,541 | 1.00x | 15.38 GiB |
| FP16 AMP (main) | - | - | - |
| FP16 AMP (PR) | 30,218 | 1.23x | 12.93 GiB |
| FP16 AMP + Flash Attn (PR) | 36,109 | 1.47x | 9.59 GiB |

The fix should be backend-agnostic.
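
As an extra sanity check that AMP behaves as expected (a generic sketch, unrelated to axolotl internals), the loaded weights should stay in FP32 while the autocast region computes in FP16:

import torch

layer = torch.nn.Linear(1024, 1024).cuda()            # weights loaded in FP32
assert next(layer.parameters()).dtype == torch.float32

x = torch.randn(8, 1024, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = layer(x)

# The matmul ran in half precision even though the weights are FP32.
assert y.dtype == torch.float16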

Example config to reproduce crash:
base_model: NousResearch/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/out

sequence_len: 512
sample_packing: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5

fp16: true
bf16: false
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: false

warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
  pad_token: <|end_of_text|>

max_steps: 30
include_tokens_per_second: true

Types of changes

Bugfix

Summary by CodeRabbit

  • New Features
    • Improved mixed-precision handling: base models now load in float32 by default when using fp16/bf16 for more stable behavior.
  • Bug Fixes
    • Reduced precision-related crashes and inconsistencies by refining dtype selection between float16, bfloat16, and float32.
  • Refactor
    • Streamlined logic to make dtype resolution more predictable in mixed-precision configurations.

coderabbitai bot (Contributor) commented Sep 27, 2025

Walkthrough

Adjusts resolve_dtype logic in src/axolotl/utils/config/__init__.py to default mixed-precision requests (fp16/bf16) to load the base model in torch.float32, with a specific branch setting torch_dtype to bf16, and narrows conditions that set torch_dtype to float16 by removing fp16 from that path.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Config dtype resolution: src/axolotl/utils/config/__init__.py | Reworked resolve_dtype branching: prioritizes fp16/bf16 to keep the base model in fp32; adds an explicit bf16 branch; removes fp16 from the float16-setting condition; adds a comment clarifying mixed-precision base dtype intent. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00%, which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title Check | ✅ Passed | The title succinctly highlights the core change of ensuring torch_dtype is resolved correctly for Automatic Mixed Precision (AMP), directly reflecting the PR's main objective. It correctly references the affected function (resolve_dtype) and the context (AMP), providing clarity. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |


NanoCode012 (Collaborator) commented Oct 1, 2025

Thanks for the PR and for figuring this out! I think a lot of people would prefer bf16 master weights to save VRAM.

Could an alternative solution be creating a new config, amp_master_fp32: bool (just a quick name, TBD), which, when set alongside fp16/bf16, loads the master weights in fp32?

We could then add a FAQ entry recommending this config for fp16 when hitting the GradScaler error.

djsaunde (Member) commented Oct 1, 2025

> Thanks for the PR and for figuring this out! I think a lot of people would prefer bf16 master weights to save VRAM.
>
> Could an alternative solution be creating a new config, amp_master_fp32: bool (just a quick name, TBD), which, when set alongside fp16/bf16, loads the master weights in fp32?
>
> We could then add a FAQ entry recommending this config for fp16 when hitting the GradScaler error.

I don't agree with this approach; I think we should respect the user's choice of torch_dtype (and have it default to bfloat16).

codecov bot commented Oct 1, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.


onesnep (Author) commented Oct 2, 2025

Thanks for the replies!

I think there might be some confusion stemming from the current config schema, which I also found challenging. Let me try to clarify my understanding of the issue.

As I understand it, mixed precision has a specific technical meaning: it uses fast half-precision (FP16/BF16) for compute while maintaining a high-precision FP32 master copy of the weights for the optimizer. This "mix" provides the speed of half-precision with the stability of FP32. When an engineer reads the axolotl doc page about enabling mixed precision, this is what they'll assume is happening when they toggle fp16: true.

This is distinct from "pure" half-precision training, where the model's weights, gradients, and optimizer states are all in BF16 (or FP16, which is generally unstable). This saves more VRAM but can sacrifice numerical stability.
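
For reference, the standard PyTorch FP16 AMP pattern being described looks roughly like this (a generic sketch, not axolotl code):

import torch

model = torch.nn.Linear(512, 512).cuda()             # FP32 master weights
optimizer = torch.optim.AdamW(model.parameters())
# torch.amp.GradScaler is the newer spelling; older releases use torch.cuda.amp.GradScaler()
scaler = torch.amp.GradScaler("cuda")

x = torch.randn(16, 512, device="cuda")
target = torch.randn(16, 512, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(x), target)  # linear layers run in FP16 here

scaler.scale(loss).backward()  # scale the loss to avoid FP16 gradient underflow
scaler.step(optimizer)         # unscales the FP32 grads, skips the step on inf/nan
scaler.update()
optimizer.zero_grad()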

In terms of how the configs appear to be intended to work in Axolotl, I understand it like this:

| argument | mixed precision | weight dtype | compute dtype |
| --- | --- | --- | --- |
| fp16 | yes | fp32 | fp16 |
| float16 | no | fp16 | fp16 |
| bf16 | yes | fp32 | bf16 |
| bfloat16 | no | bf16 | bf16 |

The issue my PR addresses is that the original logic conflated these two distinct modes. When a user specified fp16: true (intending to use mixed precision), the code forced the model to load directly into FP16 (and likewise for BF16). This violates the assumptions of Torch's GradScaler, causing the crash. It also effectively made both mixed-precision options redundant.

This is the current scenario on main:

| argument | mixed precision | weight dtype | compute dtype |
| --- | --- | --- | --- |
| fp16 | no* | fp16 | fp16 |
| float16 | no | fp16 | fp16 |
| bf16 | no** | bf16 | bf16 |
| bfloat16 | no | bf16 | bf16 |

* crashes
** unsure exactly what torch AMP does in this scenario; it may, e.g., instantiate duplicate bf16 model weights if it thinks it's using mixed precision.

> I think a lot of people would prefer bf16 master weights to save VRAM. Could an alternative solution be creating a new config, amp_master_fp32: bool

The existing config schema already provides a way to train in bf16 with bf16 master weights: through setting bfloat16. I suppose it wouldn't allow for a hypothetical FP16/BF16 mixed precision setup, but this would only benefit niche hardware (such as mine; it's uncommon for BF16 to be supported but slower than FP16) and I'm unsure if this would even be supported by Torch AMP. Generally, I don't think adding another config option specifically for FP32 model weights is necessary, as this is already implied by whether you are using mixed precision -- though the config name itself could be much clearer.

This whole discussion highlights that the current flags (bf16 vs. bfloat16, etc.) can be a source of confusion. I think it would make more sense to use a mixed_precision argument which can be set to fp16 or bf16 (corresponding to compute dtype) or disabled.
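
A hypothetical resolution for such a mixed_precision key might look like this (purely illustrative; mixed_precision is not an existing axolotl option, and the mapping just encodes the tables above):

import torch

def resolve_from_mixed_precision(
    mixed_precision: str | None,
    pure_dtype: torch.dtype = torch.bfloat16,
) -> tuple[torch.dtype, torch.dtype]:
    # Returns (weight dtype, compute dtype) for a hypothetical `mixed_precision` key.
    if mixed_precision == "fp16":
        return torch.float32, torch.float16   # AMP: FP32 master weights, FP16 compute
    if mixed_precision == "bf16":
        return torch.float32, torch.bfloat16  # AMP: FP32 master weights, BF16 compute
    # Disabled: "pure" precision, weights and compute share a dtype
    return pure_dtype, pure_dtype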

But I'm aware this might be beyond the scope of this specific PR, which just aims to make behaviour under the current schema stable and correct by aligning it with the standard PyTorch AMP implementation.

onesnep and others added 2 commits October 3, 2025 09:17

NanoCode012 (Collaborator) commented

OK, after some internal discussion, I'm good with this PR now. My next thought would be whether to convert the existing example YAMLs to use bfloat16 for backward compatibility?
