📝 Walkthrough

These changes add Flash Attention 4 (FA4) support to axolotl. A documentation section describes FA4 requirements and installation methods. The patch manager was updated to auto-apply FA4 patches during pre-model load when enabled. A new monkeypatch module implements the FA4 patching logic for SM90+ hardware.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 2
🧹 Nitpick comments (2)
src/axolotl/loaders/patch_manager.py (1)
**231-238: No opt-out mechanism for FA4 auto-upgrade on SM90+ hardware.**

When `flash_attention: true` is set on SM90+ hardware with FA4 installed, FA4 is automatically used. Users who intentionally want FA2 or FA3 (e.g., for debugging, compatibility testing, or feature comparison) have no way to disable this behavior. Consider adding a config option like `flash_attention_version: "fa2" | "fa3" | "fa4" | "auto"` or `disable_fa4_upgrade: true` to give users control.

💡 Example config approach

```diff
 def _apply_flash_attn_4_patches(self):
     """Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
     if not self.cfg.flash_attention:
         return
+
+    # Allow users to explicitly disable FA4 auto-upgrade
+    if getattr(self.cfg, "disable_fa4_auto_upgrade", False):
+        return

     from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4

     patch_flash_attn_4()
```

This would require adding `disable_fa4_auto_upgrade: bool | None` to the config schema.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/loaders/patch_manager.py` around lines 231 - 238: the _apply_flash_attn_4_patches method auto-applies FA4 whenever cfg.flash_attention is true, offering no opt-out; update the config schema to add a toggle (e.g., flash_attention_version: "auto"|"fa2"|"fa3"|"fa4" or disable_fa4_auto_upgrade: bool) and change _apply_flash_attn_4_patches to check that new setting before calling patch_flash_attn_4(); specifically, read the new field from self.cfg (e.g., self.cfg.flash_attention_version or self.cfg.disable_fa4_auto_upgrade) and only call patch_flash_attn_4() when the config permits auto-upgrade (or when version == "fa4" or "auto"), otherwise skip so users can force FA2/FA3.

src/axolotl/monkeypatch/attention/flash_attn_4.py (1)
**20-23: Device capability check may not represent all GPUs in multi-GPU setups.**

`torch.cuda.get_device_capability()` without arguments returns the capability of the current CUDA device (typically device 0). In heterogeneous multi-GPU setups, other devices may have different capabilities. Consider checking all available devices or, at minimum, documenting this limitation:

♻️ Suggested approach

```diff
-major, _ = torch.cuda.get_device_capability()
-# Matches flash_attn/cute/interface.py: arch / 10 in [9, 10, 11]
-if major not in (9, 10, 11):
-    return
+# Check if any available GPU supports FA4 (SM90+)
+# Note: This enables FA4 if at least one GPU supports it
+device_count = torch.cuda.device_count()
+supported = any(
+    torch.cuda.get_device_capability(i)[0] in (9, 10, 11)
+    for i in range(device_count)
+)
+if not supported:
+    return
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/monkeypatch/attention/flash_attn_4.py` around lines 20 - 23, The current device-capability check uses torch.cuda.get_device_capability() with no device argument (captured into major) which only queries the current CUDA device and can miss other GPUs in heterogeneous multi‑GPU systems; update the check in the module that gates flash_attn_4 (the block using torch.cuda.get_device_capability(), major, and the if major not in (9,10,11) return) to iterate over all CUDA devices (torch.cuda.device_count()) and query each device's capability (torch.cuda.get_device_capability(device)) and proceed if any device major is in (9,10,11), or at minimum add a clear comment documenting the single‑device limitation if iterating is not desired.
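Extracted into a pure helper, the any-device check this suggestion describes might look like the following (a hypothetical sketch; in axolotl the capability pairs would come from `torch.cuda.get_device_capability(i)` for each device index, and `fa4_capable` is an illustrative name, not an existing function):

```python
def fa4_capable(capabilities):
    """Return True if any GPU reports an SM90+ compute capability.

    `capabilities` is a list of (major, minor) tuples, e.g. collected from
    torch.cuda.get_device_capability(i) for each device index i.
    Mirrors the flash_attn/cute/interface.py gate: major in [9, 10, 11].
    """
    return any(major in (9, 10, 11) for major, _minor in capabilities)
```

Keeping the predicate pure also makes the gating logic trivially unit-testable without GPUs.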
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 883d308f-f700-46f1-97d4-3c3e22e76f38
📒 Files selected for processing (3)

- docs/attention.qmd
- src/axolotl/loaders/patch_manager.py
- src/axolotl/monkeypatch/attention/flash_attn_4.py
```python
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_flash_attn_4_patches()
```
**FA4 patch may run unintentionally after xformers patch sets flash_attention.**

At line 157, `_apply_flash_attention_patches` sets `self.cfg.flash_attention = True` when xformers + sample_packing is configured. Since `_apply_flash_attn_4_patches` runs after this (line 101), it could attempt to enable FA4 even when the user intended to use xformers attention. The FA4 patch has its own guards (SM90+ check, `flash_attn.cute` availability), so this may be benign, but the interaction is subtle and could cause unexpected behavior.
💡 Suggested fix

Consider checking if xformers was the original intent:

```diff
 def _apply_flash_attn_4_patches(self):
     """Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
-    if not self.cfg.flash_attention:
+    # Skip if flash_attention was not explicitly enabled (e.g., set by xformers patch)
+    if not self.cfg.flash_attention or self.cfg.xformers_attention:
         return

     from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4

     patch_flash_attn_4()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/loaders/patch_manager.py` around lines 98 - 101, The FA4 patch
can run unintentionally because _apply_flash_attention_patches mutates
self.cfg.flash_attention before _apply_flash_attn_4_patches runs; record the
original intent (e.g., capture the initial self.cfg.flash_attention or a derived
flag like want_xformers = bool(self.cfg.xformers and self.cfg.sample_packing))
before calling any patch methods, then update _apply_flash_attn_4_patches to
early-return if the original intent was to use xformers (or if the captured flag
indicates xformers was requested), referencing the methods
_apply_flash_attention_patches and _apply_flash_attn_4_patches and the config
symbol self.cfg.flash_attention/self.cfg.xformers to locate where to add the
guard.
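The ordering hazard described above can be sketched with a minimal stand-in for the patch manager (hypothetical names; a plain dict stands in for the real config object, and `PatchOrderSketch` is illustrative, not part of axolotl):

```python
class PatchOrderSketch:
    """Minimal stand-in for PatchManager illustrating intent capture."""

    def __init__(self, cfg):
        self.cfg = cfg
        # Capture the user's original intent before any patch mutates cfg.
        self.want_xformers = bool(cfg.get("xformers") and cfg.get("sample_packing"))

    def apply_flash_attention_patches(self):
        # Mirrors the reported behavior: the xformers path flips flash_attention on.
        if self.want_xformers:
            self.cfg["flash_attention"] = True

    def should_apply_fa4(self):
        # Guard on the captured flag, not the (possibly mutated) cfg value.
        return bool(self.cfg.get("flash_attention")) and not self.want_xformers
```

With the flag captured in `__init__`, the FA4 decision is insensitive to the order in which the patch methods later mutate `cfg`.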
```python
def _patched_lazy_imports(
    implementation, attention_wrapper=None, allow_all_kernels=False
):
    from flash_attn.cute import flash_attn_func, flash_attn_varlen_func

    return (
        flash_attn_func,
        flash_attn_varlen_func,
        fa_utils._pad_input,
        fa_utils._unpad_input,
    )
```
**Ignoring `implementation` argument forces FA4 unconditionally.**

The patched function ignores all arguments including `implementation`, which the original `_lazy_imports` uses to select between backends (e.g., "flash_attention_2", "sdpa"). This means FA4 is used regardless of what transformers requests. While this appears intentional for the FA2→FA4 upgrade, it could cause issues if:

- Code explicitly requests a non-flash backend
- Future transformers versions pass new `implementation` values with specific semantics

Consider at minimum logging when a non-FA2 implementation is requested but overridden:
💡 Suggested defensive logging

```diff
 def _patched_lazy_imports(
-    implementation, attention_wrapper=None, allow_all_kernels=False
+    implementation, attention_wrapper=None, allow_all_kernels=False  # noqa: ARG001
 ):
+    if implementation and "flash" not in implementation.lower():
+        LOG.debug("FA4 patch overriding requested implementation: %s", implementation)
     from flash_attn.cute import flash_attn_func, flash_attn_varlen_func
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def _patched_lazy_imports(
    implementation, attention_wrapper=None, allow_all_kernels=False  # noqa: ARG001
):
    if implementation and "flash" not in implementation.lower():
        LOG.debug("FA4 patch overriding requested implementation: %s", implementation)
    from flash_attn.cute import flash_attn_func, flash_attn_varlen_func

    return (
        flash_attn_func,
        flash_attn_varlen_func,
        fa_utils._pad_input,
        fa_utils._unpad_input,
    )
```
🧰 Tools

🪛 Ruff (0.15.4)

- [warning] 31-31: Unused function argument: `implementation` (ARG001)
- [warning] 31-31: Unused function argument: `attention_wrapper` (ARG001)
- [warning] 31-31: Unused function argument: `allow_all_kernels` (ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/monkeypatch/attention/flash_attn_4.py` around lines 30 - 40,
_patched_lazy_imports currently ignores the implementation argument and always
returns FA4 symbols (flash_attn_func, flash_attn_varlen_func,
fa_utils._pad_input, fa_utils._unpad_input); update it to honor the
implementation parameter by checking its value and only forcing FA4 for
compatible values (e.g., the FA2 identifier you intend to override), otherwise
defer to the original behavior or return the appropriate backend, and add
defensive logging (via the module logger or fa_utils logger) to warn when a
requested non-FA2 implementation is being overridden so callers requesting a
different backend are visible in logs.
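One shape for honoring the `implementation` argument, as the prompt suggests, is to close over the original loader and only shadow flash requests (a sketch with hypothetical names; `make_patched_lazy_imports`, `original_lazy_imports`, and `fa4_symbols` stand in for the real transformers internals):

```python
import logging

LOG = logging.getLogger(__name__)


def make_patched_lazy_imports(original_lazy_imports, fa4_symbols):
    """Wrap the original _lazy_imports, forcing FA4 only for flash requests."""

    def _patched(implementation, attention_wrapper=None, allow_all_kernels=False):
        if implementation and "flash" not in implementation.lower():
            # A non-flash backend was explicitly requested: defer to the original.
            LOG.debug("FA4 patch deferring to requested implementation: %s", implementation)
            return original_lazy_imports(
                implementation, attention_wrapper, allow_all_kernels
            )
        # Flash (or unspecified) request: substitute the FA4 symbols.
        return fa4_symbols

    return _patched
```

Because the wrapper keeps a reference to the original function, "sdpa" and any future backend identifiers fall through untouched instead of silently receiving FA4.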
📖 Documentation Preview: https://69b136f44746169609049f27--resonant-treacle-0fd729.netlify.app (deployed on Netlify from commit ea6d243)
Codecov Report

❌ Patch coverage is

📢 Thoughts on this report? Let us know!
we manually install fa2 from wheels https://github.com/axolotl-ai-cloud/axolotl/blob/main/docker/Dockerfile-uv-base#L43-L57, so we should keep that in mind. is it best to figure out how to make it work with kernels? I was hoping we could also swap to use
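On the wheel-install point: one defensive option before patching is to probe for the FA4 interface module rather than assuming the installed `flash_attn` wheel ships it (a sketch; `fa4_available` is a hypothetical helper name, and `flash_attn.cute` is the module path the review comments reference):

```python
import importlib.util


def fa4_available():
    """Return True only if the installed flash_attn build ships the cute (FA4) interface."""
    if importlib.util.find_spec("flash_attn") is None:
        return False  # flash_attn not installed at all
    # An FA2-only wheel lacks the cute submodule, so this returns False for it.
    return importlib.util.find_spec("flash_attn.cute") is not None
```

Such a probe would let the patch degrade gracefully to FA2 on the existing Docker wheel installs instead of raising at import time.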
Description
Require:
Enable:
Please read included doc for model support capability.
Closes #3463
Motivation and Context
How has this been tested?
Improvements appear at larger contexts/models; <=8k context runs show similar results. These runs were with Qwen2.5 7B FFT. A larger model could demonstrate better perf.
FA3 & FA4 built from source at `1314ea24e3502f83fbb6a04c164bf965d643fe75`.

Hopper H100
Blackwell B200 (no FA3 support)


AI Usage Disclaimer
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit
Documentation
New Features