Commit a18e994

[ops] feat: add Quack GEMM backend for fused MoE & upgrade fa4 (#546)
1 parent 083873c commit a18e994

File tree: 15 files changed, +1303 −396 lines

docs/design/kernel_selection.md

Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
# Kernel Selection in VeOmni

VeOmni selects optimized kernel implementations for attention, cross-entropy
loss, Liger fused ops, and MoE at different points in the lifecycle. This
document describes every selection mechanism, when it fires, and how to
configure it.

## Quick Reference

| Kernel | Config field | Env var | Default | Selection time |
|--------|-------------|---------|---------|----------------|
| Attention | `attn_implementation` | — | `"flash_attention_2"` | Config `__post_init__` + `build_foundation_model` |
| Cross-entropy loss | — | `VEOMNI_USE_LIGER_KERNEL` | `"1"` | Import time |
| Liger fused ops (RMSNorm, RoPE, SwiGLU) | — | `VEOMNI_USE_LIGER_KERNEL` | `"1"` | Model registration (import time) |
| MoE implementation | `moe_implementation` | — | `None` | `build_foundation_model` |

All config fields live in `OpsImplementationConfig` (`veomni/arguments/arguments_types.py`),
accessible via `model.ops_implementation.*` in YAML.
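
The two config-driven fields above can be set together under `model.ops_implementation`. A minimal YAML sketch — the field names come from the table above, the values are just illustrative choices:

```yaml
model:
  ops_implementation:
    attn_implementation: flash_attention_2   # rewritten to the SP-aware variant at config time
    moe_implementation: fused                # Triton group-gemm backend
```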

---

## Lifecycle Overview

```
import veomni                              # (1) import time
 └─ apply_ops_patch()
     ├─ apply_veomni_attention_patch()     # register FA2/3/4 with SP
     ├─ apply_veomni_loss_patch()          # bind cross-entropy kernel
     └─ (MoE patch is NOT applied here)

MODELING_REGISTRY.register()               # (2) model class registration
 └─ gpu_patch files                        # Liger RMSNorm/RoPE/SwiGLU

OpsImplementationConfig.__post_init__()    # (3) config parse time
 └─ rewrite attn_implementation for SP

build_foundation_model(...)                # (4) model build time
 ├─ apply_veomni_fused_moe_patch(backend=) # bind MoE GEMM kernel
 ├─ config._moe_implementation = ...
 └─ model init + weight loading

model.forward()                            # (5) runtime
 ├─ attention: ALL_ATTENTION_FUNCTIONS[config._attn_implementation]
 ├─ loss: _cross_entropy(...)
 └─ MoE: fused_moe_forward(...) or eager loop
```

---

## 1. Attention

### Config

```yaml
model:
  ops_implementation:
    attn_implementation: flash_attention_2  # default
```

**Field:** `OpsImplementationConfig.attn_implementation`

### Available implementations

| Value | Kernel | Sequence Parallel | Requirements |
|-------|--------|:-:|---|
| `eager` | PyTorch | No | — |
| `sdpa` | `F.scaled_dot_product_attention` | No | — |
| `flash_attention_2` | Flash Attention v2 | Yes | `flash-attn` |
| `flash_attention_3` | Flash Attention v3 | Yes | `flash-attn-interface` |
| `flash_attention_4` | Flash Attention v4 | Yes | `flash-attn.cute` |
| `native-sparse` | Sparse attention | No | — |

When `MODELING_BACKEND=veomni` (the default), `__post_init__` automatically
rewrites `flash_attention_2/3/4` to VeOmni SP-aware variants
(`veomni_flash_attention_2_with_sp`, etc.), which wrap the underlying kernel
with DeepSpeed Ulysses sequence parallelism gather/scatter. This is why FA2/3/4
support SP — the rewrite is transparent to the user.

### Selection flow

1. **Config `__post_init__`** — `flash_attention_2` → `veomni_flash_attention_2_with_sp`
2. **`build_foundation_model`** — passed to HuggingFace `AutoModel.from_config(attn_implementation=...)`, stored as `config._attn_implementation`
3. **Import-time registration** — `apply_veomni_attention_patch()` registers the VeOmni names in `ALL_ATTENTION_FUNCTIONS`
4. **Forward** — Transformers dispatches to `flash_attention_forward()` via `ALL_ATTENTION_FUNCTIONS[config._attn_implementation]`
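
The steps above can be condensed into a small, self-contained sketch. The registry name `ALL_ATTENTION_FUNCTIONS` and the implementation names come from this doc; `register_attention`, `post_init_rewrite`, and the function bodies are illustrative stand-ins, not VeOmni's actual code:

```python
# A name -> function registry, standing in for Transformers'
# ALL_ATTENTION_FUNCTIONS mapping.
ALL_ATTENTION_FUNCTIONS = {}

def register_attention(name):
    """Register an attention forward under a string name (sketch)."""
    def wrap(fn):
        ALL_ATTENTION_FUNCTIONS[name] = fn
        return fn
    return wrap

@register_attention("veomni_flash_attention_2_with_sp")
def veomni_fa2_with_sp(*args, **kwargs):
    # Real impl: Ulysses all-to-all gather -> flash-attn kernel -> scatter.
    return "fa2+sp"

# The config-time rewrite: plain FA names become SP-aware VeOmni names.
_SP_REWRITE = {
    "flash_attention_2": "veomni_flash_attention_2_with_sp",
    "flash_attention_3": "veomni_flash_attention_3_with_sp",
    "flash_attention_4": "veomni_flash_attention_4_with_sp",
}

def post_init_rewrite(attn_implementation, backend="veomni"):
    # Mirrors OpsImplementationConfig.__post_init__: only the veomni
    # backend rewrites FA names; eager/sdpa pass through untouched.
    if backend == "veomni":
        return _SP_REWRITE.get(attn_implementation, attn_implementation)
    return attn_implementation

name = post_init_rewrite("flash_attention_2")
print(name)                              # veomni_flash_attention_2_with_sp
print(ALL_ATTENTION_FUNCTIONS[name]())   # fa2+sp
```

At forward time, dispatch is then just a dictionary lookup on the rewritten name stored in `config._attn_implementation`.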

### Key files

- Config: `veomni/arguments/arguments_types.py` — `OpsImplementationConfig`
- Registration: `veomni/ops/flash_attn/__init__.py` — `apply_veomni_attention_patch()`, `flash_attention_forward()`
- Plumbing: `veomni/models/auto.py` — `build_foundation_model(attn_implementation=...)`

---

## 2. Cross-Entropy Loss

### Config

No config field. Controlled by environment variables.

| Env var | Default | Values |
|---------|---------|--------|
| `VEOMNI_USE_LIGER_KERNEL` | `"1"` | `"0"` / `"1"` |
| `VEOMNI_ENABLE_CHUNK_LOSS` | `"0"` | `"0"` / `"1"` (NPU only) |

### Available implementations

| Implementation | When selected |
|---|---|
| `fused_liger_kernel_cross_entropy` | GPU + Liger installed + `VEOMNI_USE_LIGER_KERNEL=1` |
| `eager_cross_entropy` | GPU fallback, or NPU |
| `chunk_loss_function` | NPU + `VEOMNI_ENABLE_CHUNK_LOSS=1` |

### Selection flow

`apply_veomni_loss_patch()` runs at import time and sets the global
`_cross_entropy` function pointer:

1. NPU → `eager_cross_entropy` (+ optional `chunk_loss_function` for `LOSS_MAPPING`)
2. GPU + Liger + env `"1"` → `fused_liger_kernel_cross_entropy`
3. Fallback → `eager_cross_entropy`
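
The precedence above can be sketched as a tiny selector. The helper names (`is_npu`, `is_liger_kernel_available`, `get_env`) are assumptions standing in for VeOmni's real utilities, and the stub implementations exist only to make the sketch runnable:

```python
import os

def is_npu():
    return False  # assume a GPU host for this sketch

def is_liger_kernel_available():
    return False  # pretend liger_kernel is not installed

def get_env(name, default="1"):
    # Stand-in for VeOmni's env registry; real defaults live in
    # veomni/utils/env.py.
    return os.environ.get(name, default)

def eager_cross_entropy(*args, **kwargs): ...
def fused_liger_kernel_cross_entropy(*args, **kwargs): ...

def select_cross_entropy():
    if is_npu():
        return eager_cross_entropy               # (1) NPU is always eager
    if is_liger_kernel_available() and get_env("VEOMNI_USE_LIGER_KERNEL") == "1":
        return fused_liger_kernel_cross_entropy  # (2) GPU + Liger + env "1"
    return eager_cross_entropy                   # (3) fallback

_cross_entropy = select_cross_entropy()
print(_cross_entropy.__name__)  # eager_cross_entropy (Liger unavailable here)
```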

### Key files

- Selection: `veomni/ops/fused_cross_entropy/__init__.py` — `apply_veomni_loss_patch()`
- Eager impl: `veomni/ops/fused_cross_entropy/eager.py`
- Liger impl: `veomni/ops/fused_cross_entropy/liger_kernel.py`

---

## 3. Liger Fused Ops (RMSNorm, RoPE, SwiGLU MLP)

### Config

No config field. Same environment variable as cross-entropy.

| Env var | Default |
|---------|---------|
| `VEOMNI_USE_LIGER_KERNEL` | `"1"` |

### What gets patched

When `VEOMNI_USE_LIGER_KERNEL=1` and the `liger_kernel` package is installed,
each model's `gpu_patch.py` replaces HuggingFace module classes:

| Component | Original | Liger replacement |
|---|---|---|
| RMSNorm | `{Model}RMSNorm` | `LigerRMSNorm` |
| Rotary embedding | `apply_rotary_pos_emb` | `liger_rotary_pos_emb` |
| SwiGLU MLP | `{Model}MLP` | `LigerSwiGLUMLP` |

### Selection flow

Patching happens at model class registration time (import of the model
module). Each model's `gpu_patch.py` checks:

```python
if is_liger_kernel_available() and get_env("VEOMNI_USE_LIGER_KERNEL") == "1":
    hf_module.apply_rotary_pos_emb = liger_rotary_pos_emb
    hf_module.ModelRMSNorm = LigerRMSNorm
    hf_module.ModelMLP = LigerSwiGLUMLP
```

### Models with Liger support

Qwen2, Qwen3, Qwen3-MoE, Qwen2-VL, DeepSeek-V3, Llama, Seed-OSS.

### Key files

- `veomni/models/transformers/{model}/gpu_patch.py` (7 model-specific files)

---

## 4. MoE Kernel

MoE kernel selection is controlled by a single `moe_implementation` field:

```yaml
model:
  ops_implementation:
    moe_implementation: fused          # Triton group-gemm (default fused path)
    # moe_implementation: fused_quack  # Quack CUTLASS/CuTe kernels (SM90+)
    # moe_implementation: eager        # Reference PyTorch loop (very slow, debug only)
```

**Field:** `OpsImplementationConfig.moe_implementation`
**Default:** `None` (falls back to `"eager"` per model config)

| Value | Kernel | Hardware | EP support |
|-------|--------|----------|:----------:|
| `eager` | PyTorch expert loop | Any | No |
| `fused` | Triton group-gemm (`group_gemm_same_nk`) | SM70+ (V100+) | Yes |
| `fused_quack` | Quack CUTLASS/CuTe (`quack.gemm_interface.gemm`) | SM90+ (H100+) | No |
| *(NPU auto)* | NPU group-gemm | Ascend NPU | Yes |

Models only see `_moe_implementation` as `"eager"` or `"fused"` — the
`fused_quack` variant is mapped to `"fused"` on the config, with the kernel
backend selected separately via `apply_veomni_fused_moe_patch`.

On NPU devices, the backend parameter is ignored — the NPU kernel is always
selected.

### Selection flow

Unlike attention and loss, the MoE patch is **not** applied at import time.
It is applied inside `build_foundation_model()`:

```python
def build_foundation_model(..., moe_implementation="fused_quack"):
    config._moe_implementation = "fused"
    apply_veomni_fused_moe_patch(moe_implementation="fused_quack")
```

This deferred approach allows the config to drive kernel selection without
env vars.

### Usage

**Via config (YAML):**

```yaml
model:
  ops_implementation:
    moe_implementation: fused_quack
```

**Via `build_foundation_model` (standalone scripts):**

```python
model = build_foundation_model(
    config_path="...",
    moe_implementation="fused_quack",
)
```

**Direct patch (tests / benchmarks):**

```python
from veomni.ops.fused_moe import apply_veomni_fused_moe_patch
apply_veomni_fused_moe_patch(moe_implementation="fused_quack")
```
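
One way to picture the split between the config-visible value and the kernel backend is a small resolver. `resolve_moe_implementation` is a hypothetical helper written for this doc, not a VeOmni API; it only encodes the mapping rules stated above:

```python
def resolve_moe_implementation(moe_implementation):
    """Return (config_value, backend) for a user-facing setting (sketch)."""
    if moe_implementation is None or moe_implementation == "eager":
        # Reference loop: models run the eager expert loop, no patch needed.
        return "eager", None
    if moe_implementation in ("fused", "fused_quack"):
        # Models only ever see "fused" on the config; the Triton-vs-Quack
        # choice travels separately into
        # apply_veomni_fused_moe_patch(moe_implementation=...).
        return "fused", moe_implementation
    raise ValueError(f"unknown moe_implementation: {moe_implementation!r}")

print(resolve_moe_implementation("fused_quack"))  # ('fused', 'fused_quack')
print(resolve_moe_implementation(None))           # ('eager', None)
```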

### Key files

- Config: `veomni/arguments/arguments_types.py` — `OpsImplementationConfig`
- Dispatch: `veomni/ops/fused_moe/__init__.py` — `apply_veomni_fused_moe_patch()`
- Triton impl: `veomni/ops/fused_moe/group_gemm.py`
- Quack impl: `veomni/ops/fused_moe/quack_gemm.py`
- NPU impl: `veomni/ops/fused_moe/npu_group_gemm.py`
- Plumbing: `veomni/models/auto.py` — `build_foundation_model(moe_implementation=...)`

---

## Environment Variables Summary

| Env var | Default | Scope | Notes |
|---------|---------|-------|-------|
| `MODELING_BACKEND` | `"veomni"` | Global | `"veomni"` or `"hf"` — controls whether VeOmni ops patches are applied |
| `VEOMNI_USE_LIGER_KERNEL` | `"1"` | Global | Controls Liger kernels for RMSNorm/RoPE/SwiGLU + cross-entropy loss |
| `USE_GROUP_GEMM` | `"1"` | MoE | Gate for Triton group-gemm availability; set `"0"` to force fallback |
| `VEOMNI_ENABLE_CHUNK_LOSS` | `"0"` | NPU only | Enable chunked loss computation |

All env vars are registered in `veomni/utils/env.py` with defaults and can be
overridden by setting the corresponding shell environment variable.
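
For example, to disable the Liger kernels and the Triton group-gemm gate for a single run (the env var names come from the table above; how you launch training is up to your setup):

```shell
# Force the eager cross-entropy / unpatched module path for this shell.
export VEOMNI_USE_LIGER_KERNEL=0
# Disable the Triton group-gemm gate so fused MoE falls back.
export USE_GROUP_GEMM=0
echo "$VEOMNI_USE_LIGER_KERNEL $USE_GROUP_GEMM"   # 0 0
```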

docs/index.md

Lines changed: 8 additions & 0 deletions

@@ -47,6 +47,7 @@ hardware_support/get_started_npu.md
 :caption: Examples

 examples/qwen3.md
+examples/qwen3_5.md
 examples/qwen3_moe.md
 examples/qwen3_vl.md
 examples/qwen3_omni_moe.md
@@ -64,6 +65,13 @@ key_features/ulysses.md

 ```

+```{toctree}
+:maxdepth: 1
+:caption: Design
+
+design/kernel_selection.md
+```
+
 ```{toctree}
 :maxdepth: 1
 :caption: Transformers v5 Updates

docs/transformers_v5/veomni_flash_attention_kernel_adapter.md

Lines changed: 10 additions & 6 deletions

@@ -79,9 +79,13 @@ After `import veomni`:
   no-op for those two in practice, but is kept for safety.
 - FA4 (`veomni_flash_attention_4_with_sp`) has no such branch in `_lazy_imports` and
   always falls through to the hub-kernel path in Transformers v5. The adapter is the
-  **critical** component that makes FA4 usable. FA4 is not supported on Transformers v4.
-- FA4 requires the `flash-attn-cute` package (`flash_attn.cute`). To install Transformers v5
-  and FA4 together, run:
-  ```
-  uv sync --extra gpu --extra fa4 --extra transformers5-exp --no-group transformers-stable
-  ```
+  **critical** component that makes FA4 usable on v5.
+- On Transformers v4, FA4 is supported via the VeOmni SP variant
+  (`veomni_flash_attention_4_with_sp`). Instead of the string name, VeOmni passes
+  a `SimpleNamespace` object (from `_load_veomni_local_flash_kernel`) directly to
+  `_lazy_imports`, which v4 accepts in its kernels-fallback branch via `getattr()`.
+  The bare `flash_attention_4` name still requires Transformers v5; for Transformers v4,
+  use `attn_implementation="veomni_flash_attention_4_with_sp"`.
+- FA4 requires the `flash-attn-cute` package (`flash_attn.cute`). To install FA4:
+  - **Transformers v5**: `uv sync --extra gpu --extra fa4 --extra transformers5-exp --no-group transformers-stable`
+  - **Transformers v4**: `uv sync --extra gpu --extra fa4`

docs/usage/arguments.md

Lines changed: 1 addition & 1 deletion

@@ -122,7 +122,7 @@ Root config — assembles `model`, `data`, and `train`.
 | Field | Type | Default | Description |
 | --- | --- | --- | --- |
 | attn_implementation | `Optional[Literal["eager", "sdpa", "flash_attention_2", "flash_attention_3", "flash_attention_4", "native-sparse"]]` | `"flash_attention_2"` | Attention implementation to use. |
-| moe_implementation | `Optional[Literal["eager", "fused"]]` | `None` | MoE implementation to use. |
+| moe_implementation | `Optional[Literal["eager", "fused", "fused_quack"]]` | `None` | MoE implementation: `eager` (reference loop), `fused` (Triton), `fused_quack` (Quack CUTLASS, SM90+). |

 ### DataArguments

pyproject.toml

Lines changed: 4 additions & 8 deletions

@@ -93,6 +93,7 @@ gpu = [
     "torch-c-dlpack-ext",
     # For models with linear attention like Qwen 3.5
     "flash-linear-attention",
+    "quack-kernels==0.3.2",
 ]
 megatron = [
     "megatron-energon>=7.2.1"
@@ -102,7 +103,7 @@ trl = [
 ]

 fa4 = [
-    "flash-attn-cute",
+    "flash-attn-4==0.1.0",
     "nvidia-cutlass-dsl>=4.4.0"
 ]

@@ -210,11 +211,6 @@ conflicts = [
         { group = "transformers-stable" },
         { extra = "transformers5-exp" },
     ],
-    # FA4 only works for transformers v5
-    [
-        { group = "transformers-stable" },
-        { extra = "fa4" },
-    ],
 ]

 [tool.uv.sources]
@@ -244,8 +240,8 @@ flash-attn-3 = [
     { url = "https://github.com/windreamer/flash-attention3-wheels/releases/download/2026.01.12-6b9e0bf/flash_attn_3-3.0.0b1%2B20260112.cu129torch291cxx11abitrue.ea8f73-cp39-abi3-linux_x86_64.whl", marker = "extra == 'gpu'"},
 ]
 # FlashAttention 4 is developed under flash-attention/flash-attn/cute folder as a standalone python project.
-# Pinned to 02/20/2026 latest main commit.
-flash-attn-cute = { git = "https://github.com/Dao-AILab/flash-attention", subdirectory = "flash_attn/cute", rev = "6079a9bf4cfd7af8e7586afea6c49a97ebddf46e" }
+# Pinned to 03/10/2026 latest main commit.
+flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention", subdirectory = "flash_attn/cute", rev = "7fd16f28bffe71c9ab6b7eecc5dd14bf87c1dc9e" }

 # Download av wheel directly to avoid FFmpeg build dependency issues in CI.
 av = { url = "https://files.pythonhosted.org/packages/f8/9a/8ffabfcafb42154b4b3a67d63f9b69e68fa8c34cb39ddd5cb813dd049ed4/av-14.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", marker = "extra == 'audio' or extra == 'video'" }
