
feat: add FA4 #3481

Merged
winglian merged 8 commits into main from feat/fa4
Mar 16, 2026

Conversation

@NanoCode012
Collaborator

@NanoCode012 NanoCode012 commented Mar 9, 2026

Description

Require:

```bash
pip install flash-attn-4

# may require
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
```

Enable:

```yaml
flash_attention: true
```

Please read the included doc for model support details.

Closes #3463

Motivation and Context

How has this been tested?

Improvements appear with larger contexts and models; runs at <=8k context show similar results. These runs used Qwen2.5 7B FFT; a larger model could demonstrate better performance.

FA3 & FA4 built from source at commit 1314ea24e3502f83fbb6a04c164bf965d643fe75.

Hopper H100

[benchmark screenshot]

Blackwell B200 (no FA3 support)
[benchmark screenshots]


Summary by CodeRabbit

  • Documentation

    • Added Flash Attention 4 installation guide with pip and from-source setup instructions for Hopper and Blackwell GPUs.
  • New Features

    • Enabled Flash Attention 4 automatic optimization on SM90+ hardware when flash attention is configured.

@coderabbitai
Contributor

coderabbitai bot commented Mar 9, 2026

📝 Walkthrough

These changes add Flash Attention 4 (FA4) support to the axolotl system. A documentation section describes FA4 requirements and installation methods. The patch manager was updated to auto-apply FA4 patches during pre-model load when enabled. A new monkeypatch module implements the FA4 patching logic for SM90+ hardware.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Documentation — docs/attention.qmd | Added new "Flash Attention 4" subsection with GPU requirements (Hopper/Blackwell), pip installation instructions, and source build guidance. |
| Patch Infrastructure — src/axolotl/loaders/patch_manager.py | Added _apply_flash_attn_4_patches() method to conditionally apply FA4 patches when flash_attention is enabled; integrated into the pre-model load pipeline. |
| Flash Attention 4 Monkeypatch — src/axolotl/monkeypatch/attention/flash_attn_4.py | New module implementing patch_flash_attn_4(), which monkey-patches transformers' FA utilities with FA4 implementations for SM90+ hardware, gated by environment checks. |
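The patching approach the walkthrough describes — replacing transformers' lazy backend import with one that returns FA4 kernels — can be sketched with stand-ins. Everything below (the `fa_utils` namespace, the string return values) is illustrative only, not the actual axolotl or transformers code:

```python
import types

# Stand-in for transformers' flash-attention utilities module; in the
# real patch this would be the transformers module whose _lazy_imports
# selects the attention backend (an assumption based on the PR summary).
fa_utils = types.SimpleNamespace(
    _lazy_imports=lambda implementation: f"backend:{implementation}"
)

def patch_flash_attn_4():
    """Swap fa_utils._lazy_imports so FA4 symbols are returned instead."""
    original = fa_utils._lazy_imports

    def _patched_lazy_imports(implementation, *args, **kwargs):
        # The real patch returns flash_attn.cute's flash_attn_func et al.;
        # here we tag the result so the override is observable.
        return "backend:fa4"

    fa_utils._lazy_imports = _patched_lazy_imports
    return original  # handle kept so a caller could restore the original

patch_flash_attn_4()
print(fa_utils._lazy_imports("flash_attention_2"))  # backend:fa4
```

This mirrors the shape of the patch: capture the original hook, install a replacement, and keep a reference for restoration.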

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

scheduled_release

Suggested reviewers

  • salmanmohammadi
🚥 Pre-merge checks | ✅ 3 passed
  • Title check ✅ Passed — The title 'feat: add FA4' is concise and clearly identifies the main change: adding Flash Attention 4 support. It directly corresponds to the primary objective and the implemented feature across all modified files.
  • Docstring Coverage ✅ Passed — Docstring coverage is 83.33%, above the required 80.00% threshold.
  • Description Check ✅ Passed — Check skipped because CodeRabbit's high-level summary is enabled.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
src/axolotl/loaders/patch_manager.py (1)

231-238: No opt-out mechanism for FA4 auto-upgrade on SM90+ hardware.

When flash_attention: true is set on SM90+ hardware with FA4 installed, FA4 is automatically used. Users who intentionally want FA2 or FA3 (e.g., for debugging, compatibility testing, or feature comparison) have no way to disable this behavior.

Consider adding a config option like flash_attention_version: "fa2" | "fa3" | "fa4" | "auto" or disable_fa4_upgrade: true to give users control.

💡 Example config approach

```diff
 def _apply_flash_attn_4_patches(self):
     """Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
     if not self.cfg.flash_attention:
         return
+
+    # Allow users to explicitly disable FA4 auto-upgrade
+    if getattr(self.cfg, "disable_fa4_auto_upgrade", False):
+        return

     from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4

     patch_flash_attn_4()
```

This would require adding disable_fa4_auto_upgrade: bool | None to the config schema.
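A minimal sketch of the suggested toggle, using a dataclass as a stand-in for the config schema. The field name `flash_attention_version` and the helper `should_apply_fa4` are the reviewer's hypothetical suggestion, not existing axolotl options:

```python
from dataclasses import dataclass
from typing import Literal

@dataclass
class AttentionConfig:
    # Hypothetical fields illustrating the reviewer's proposal.
    flash_attention: bool = False
    flash_attention_version: Literal["auto", "fa2", "fa3", "fa4"] = "auto"

def should_apply_fa4(cfg: AttentionConfig) -> bool:
    """Only auto-upgrade to FA4 when the config permits it."""
    if not cfg.flash_attention:
        return False
    return cfg.flash_attention_version in ("auto", "fa4")

# A user forcing FA2 skips the FA4 patch entirely.
print(should_apply_fa4(
    AttentionConfig(flash_attention=True, flash_attention_version="fa2")
))  # False
```

The hardware guards inside `patch_flash_attn_4()` would still apply on top of this gate; the toggle only restores user intent.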

src/axolotl/monkeypatch/attention/flash_attn_4.py (1)

20-23: Device capability check may not represent all GPUs in multi-GPU setup.

torch.cuda.get_device_capability() without arguments returns the capability of the current CUDA device (typically device 0). In heterogeneous multi-GPU setups, other devices may have different capabilities.

Consider checking all available devices or at minimum documenting this limitation:

♻️ Suggested approach

```diff
-    major, _ = torch.cuda.get_device_capability()
-    # Matches flash_attn/cute/interface.py: arch / 10 in [9, 10, 11]
-    if major not in (9, 10, 11):
-        return
+    # Check if any available GPU supports FA4 (SM90+)
+    # Note: This enables FA4 if at least one GPU supports it
+    device_count = torch.cuda.device_count()
+    supported = any(
+        torch.cuda.get_device_capability(i)[0] in (9, 10, 11)
+        for i in range(device_count)
+    )
+    if not supported:
+        return
```

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 883d308f-f700-46f1-97d4-3c3e22e76f38

📥 Commits

Reviewing files that changed from the base of the PR and between 43b1c80 and 43d6fea.

📒 Files selected for processing (3)
  • docs/attention.qmd
  • src/axolotl/loaders/patch_manager.py
  • src/axolotl/monkeypatch/attention/flash_attn_4.py

Comment on lines 98 to 101 (src/axolotl/loaders/patch_manager.py):

```python
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_flash_attn_4_patches()
```

⚠️ Potential issue | 🟡 Minor

FA4 patch may run unintentionally after xformers patch sets flash_attention.

At line 157, _apply_flash_attention_patches sets self.cfg.flash_attention = True when xformers + sample_packing is configured. Since _apply_flash_attn_4_patches runs after this (line 101), it could attempt to enable FA4 even when the user intended to use xformers attention.

The FA4 patch has its own guards (SM90+ check, flash_attn.cute availability), so this may be benign, but the interaction is subtle and could cause unexpected behavior.

💡 Suggested fix

Consider checking if xformers was the original intent:

```diff
 def _apply_flash_attn_4_patches(self):
     """Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
-    if not self.cfg.flash_attention:
+    # Skip if flash_attention was not explicitly enabled (e.g., set by xformers patch)
+    if not self.cfg.flash_attention or self.cfg.xformers_attention:
         return

     from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4

     patch_flash_attn_4()
```

Comment on lines 30 to 40 (src/axolotl/monkeypatch/attention/flash_attn_4.py):

```python
def _patched_lazy_imports(
    implementation, attention_wrapper=None, allow_all_kernels=False
):
    from flash_attn.cute import flash_attn_func, flash_attn_varlen_func

    return (
        flash_attn_func,
        flash_attn_varlen_func,
        fa_utils._pad_input,
        fa_utils._unpad_input,
    )
```

⚠️ Potential issue | 🟡 Minor

Ignoring implementation argument forces FA4 unconditionally.

The patched function ignores all arguments including implementation, which the original _lazy_imports uses to select between backends (e.g., "flash_attention_2", "sdpa"). This means FA4 is used regardless of what transformers requests.

While this appears intentional for the FA2→FA4 upgrade, it could cause issues if:

  1. Code explicitly requests a non-flash backend
  2. Future transformers versions pass new implementation values with specific semantics

Consider at minimum logging when a non-FA2 implementation is requested but overridden:

💡 Suggested defensive logging

```diff
     def _patched_lazy_imports(
-        implementation, attention_wrapper=None, allow_all_kernels=False
+        implementation, attention_wrapper=None, allow_all_kernels=False  # noqa: ARG001
     ):
+        if implementation and "flash" not in implementation.lower():
+            LOG.debug("FA4 patch overriding requested implementation: %s", implementation)
         from flash_attn.cute import flash_attn_func, flash_attn_varlen_func
```
🧰 Tools
🪛 Ruff (0.15.4)
  • [warning] line 31: Unused function argument: implementation (ARG001)
  • [warning] line 31: Unused function argument: attention_wrapper (ARG001)
  • [warning] line 31: Unused function argument: allow_all_kernels (ARG001)

@github-actions
Contributor

github-actions bot commented Mar 9, 2026

📖 Documentation Preview: https://69b136f44746169609049f27--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit ea6d243

@codecov

codecov bot commented Mar 9, 2026

Codecov Report

❌ Patch coverage is 7.01754% with 53 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/axolotl/monkeypatch/attention/flash_attn_4.py | 0.00% | 51 Missing ⚠️ |
| src/axolotl/loaders/patch_manager.py | 66.66% | 2 Missing ⚠️ |


@winglian
Collaborator

winglian commented Mar 9, 2026

We manually install FA2 from wheels (https://github.com/axolotl-ai-cloud/axolotl/blob/main/docker/Dockerfile-uv-base#L43-L57), so we should keep that in mind. Is it best to figure out how to make it work with kernels? I was hoping we could also swap to use attn_implementation instead of all the various possible args we have now.

@winglian winglian merged commit 7da5f94 into main Mar 16, 2026
18 of 19 checks passed
@winglian winglian deleted the feat/fa4 branch March 16, 2026 04:13


Development

Successfully merging this pull request may close these issues.

[New feature] Flash attention 4

2 participants