# Kernel Selection in VeOmni

VeOmni selects optimized kernel implementations for attention, cross-entropy
loss, Liger fused ops, and MoE at different points in the lifecycle. This
document describes every selection mechanism, when it fires, and how to
configure it.

## Quick Reference

| Kernel | Config field | Env var | Default | Selection time |
|--------|-------------|---------|---------|----------------|
| Attention | `attn_implementation` | — | `"flash_attention_2"` | Config `__post_init__` + `build_foundation_model` |
| Cross-entropy loss | — | `VEOMNI_USE_LIGER_KERNEL` | `"1"` | Import time |
| Liger fused ops (RMSNorm, RoPE, SwiGLU) | — | `VEOMNI_USE_LIGER_KERNEL` | `"1"` | Model registration (import time) |
| MoE implementation | `moe_implementation` | — | `None` | `build_foundation_model` |

All config fields live in `OpsImplementationConfig` (`veomni/arguments/arguments_types.py`),
accessible via `model.ops_implementation.*` in YAML.
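
The two config-driven kernels can be set together in one `ops_implementation` block. A minimal YAML sketch using the documented default attention value and the fused MoE path:

```yaml
model:
  ops_implementation:
    attn_implementation: flash_attention_2
    moe_implementation: fused
```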

---

## Lifecycle Overview

```
import veomni                                  # (1) import time
 └─ apply_ops_patch()
     ├─ apply_veomni_attention_patch()         # register FA2/3/4 with SP
     ├─ apply_veomni_loss_patch()              # bind cross-entropy kernel
     └─ (MoE patch is NOT applied here)

MODELING_REGISTRY.register()                   # (2) model class registration
 └─ gpu_patch files                            # Liger RMSNorm/RoPE/SwiGLU

OpsImplementationConfig.__post_init__()        # (3) config parse time
 └─ rewrite attn_implementation for SP

build_foundation_model(...)                    # (4) model build time
 ├─ apply_veomni_fused_moe_patch(backend=)     # bind MoE GEMM kernel
 ├─ config._moe_implementation = ...
 └─ model init + weight loading

model.forward()                                # (5) runtime
 ├─ attention: ALL_ATTENTION_FUNCTIONS[config._attn_implementation]
 ├─ loss: _cross_entropy(...)
 └─ MoE: fused_moe_forward(...) or eager loop
```

---

## 1. Attention

### Config

```yaml
model:
  ops_implementation:
    attn_implementation: flash_attention_2  # default
```

**Field:** `OpsImplementationConfig.attn_implementation`

### Available implementations

| Value | Kernel | Sequence Parallel | Requirements |
|-------|--------|:-:|---|
| `eager` | PyTorch | No | — |
| `sdpa` | `F.scaled_dot_product_attention` | No | — |
| `flash_attention_2` | Flash Attention v2 | Yes | `flash-attn` |
| `flash_attention_3` | Flash Attention v3 | Yes | `flash-attn-interface` |
| `flash_attention_4` | Flash Attention v4 | Yes | `flash-attn.cute` |
| `native-sparse` | Sparse attention | No | — |

When `MODELING_BACKEND=veomni` (the default), `__post_init__` automatically
rewrites `flash_attention_2/3/4` to VeOmni SP-aware variants
(`veomni_flash_attention_2_with_sp`, etc.) that wrap the underlying kernel
with DeepSpeed Ulysses sequence-parallel gather/scatter. This is why FA2/3/4
support SP: the rewrite is transparent to the user.
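
The rewrite can be pictured as a `__post_init__` that maps bare FA names onto their SP-aware counterparts. The sketch below is a minimal illustration, not VeOmni's actual `OpsImplementationConfig`; the `_SP_REWRITES` table and the `_3`/`_4` variant names are assumptions extrapolated from the `_2` name above.

```python
from dataclasses import dataclass

# Hypothetical mapping table for illustration; the real rewrite lives in
# OpsImplementationConfig.__post_init__.
_SP_REWRITES = {
    "flash_attention_2": "veomni_flash_attention_2_with_sp",
    "flash_attention_3": "veomni_flash_attention_3_with_sp",
    "flash_attention_4": "veomni_flash_attention_4_with_sp",
}


@dataclass
class OpsConfigSketch:
    attn_implementation: str = "flash_attention_2"
    modeling_backend: str = "veomni"

    def __post_init__(self):
        # Only the VeOmni backend swaps in the SP-aware wrappers;
        # "eager", "sdpa", etc. pass through untouched.
        if self.modeling_backend == "veomni":
            self.attn_implementation = _SP_REWRITES.get(
                self.attn_implementation, self.attn_implementation
            )
```

Because the mapping falls back to the original name, non-FA values and the `hf` backend are left exactly as the user wrote them.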

### Selection flow

1. **Import-time registration** — `apply_veomni_attention_patch()` registers the VeOmni names in `ALL_ATTENTION_FUNCTIONS`
2. **Config `__post_init__`** — rewrites `flash_attention_2` → `veomni_flash_attention_2_with_sp`
3. **`build_foundation_model`** — passes the name to HuggingFace `AutoModel.from_config(attn_implementation=...)`, which stores it as `config._attn_implementation`
4. **Forward** — Transformers dispatches to `flash_attention_forward()` via `ALL_ATTENTION_FUNCTIONS[config._attn_implementation]`
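
The dispatch in the final step follows the standard Transformers pattern: implementations are registered by name in a dict, and the forward pass looks up `config._attn_implementation`. A self-contained sketch of that pattern (the registry, kernels, and dict-based config here are stand-ins, not the real `ALL_ATTENTION_FUNCTIONS` entries):

```python
# Stand-in registry mirroring how ALL_ATTENTION_FUNCTIONS maps
# implementation names to callables.
ATTENTION_REGISTRY = {}


def register_attention(name):
    def wrap(fn):
        ATTENTION_REGISTRY[name] = fn
        return fn
    return wrap


@register_attention("eager")
def eager_attention(q, k, v):
    return "eager-result"


@register_attention("veomni_flash_attention_2_with_sp")
def veomni_fa2_sp(q, k, v):
    # The real wrapper adds Ulysses SP gather/scatter around flash-attn.
    return "fa2-sp-result"


def attention_forward(config, q, k, v):
    # Mirrors the lookup by config._attn_implementation at forward time.
    return ATTENTION_REGISTRY[config["_attn_implementation"]](q, k, v)
```

The key property is that registration and selection are decoupled: new kernels only need a registry entry, and the model never imports them directly.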

### Key files

- Config: `veomni/arguments/arguments_types.py` — `OpsImplementationConfig`
- Registration: `veomni/ops/flash_attn/__init__.py` — `apply_veomni_attention_patch()`, `flash_attention_forward()`
- Plumbing: `veomni/models/auto.py` — `build_foundation_model(attn_implementation=...)`

---

## 2. Cross-Entropy Loss

### Config

No config field. Controlled by environment variables.

| Env var | Default | Values |
|---------|---------|--------|
| `VEOMNI_USE_LIGER_KERNEL` | `"1"` | `"0"` / `"1"` |
| `VEOMNI_ENABLE_CHUNK_LOSS` | `"0"` | `"0"` / `"1"` (NPU only) |

### Available implementations

| Implementation | When selected |
|---|---|
| `fused_liger_kernel_cross_entropy` | GPU + Liger installed + `VEOMNI_USE_LIGER_KERNEL=1` |
| `eager_cross_entropy` | GPU fallback, or NPU |
| `chunk_loss_function` | NPU + `VEOMNI_ENABLE_CHUNK_LOSS=1` |

### Selection flow

`apply_veomni_loss_patch()` runs at import time and sets the global
`_cross_entropy` function pointer:

1. NPU → `eager_cross_entropy` (plus optional `chunk_loss_function` for `LOSS_MAPPING`)
2. GPU + Liger + env `"1"` → `fused_liger_kernel_cross_entropy`
3. Fallback → `eager_cross_entropy`
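
The three-way priority above can be sketched as a pure function over the device type, Liger availability, and the env value. The helper below is illustrative only, not VeOmni's actual `apply_veomni_loss_patch`:

```python
import os


def eager_cross_entropy(*args):
    """Stand-in for the eager PyTorch implementation."""


def fused_liger_kernel_cross_entropy(*args):
    """Stand-in for the Liger fused implementation."""


def select_cross_entropy(device_type, liger_available, env=os.environ):
    # Mirror the documented priority: NPU first, then Liger, then eager.
    if device_type == "npu":
        return eager_cross_entropy
    if liger_available and env.get("VEOMNI_USE_LIGER_KERNEL", "1") == "1":
        return fused_liger_kernel_cross_entropy
    return eager_cross_entropy
```

Passing the environment as a mapping keeps the selection testable; the default of `"1"` matches the documented env-var default.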

### Key files

- Selection: `veomni/ops/fused_cross_entropy/__init__.py` — `apply_veomni_loss_patch()`
- Eager impl: `veomni/ops/fused_cross_entropy/eager.py`
- Liger impl: `veomni/ops/fused_cross_entropy/liger_kernel.py`

---

## 3. Liger Fused Ops (RMSNorm, RoPE, SwiGLU MLP)

### Config

No config field. Same environment variable as cross-entropy.

| Env var | Default |
|---------|---------|
| `VEOMNI_USE_LIGER_KERNEL` | `"1"` |

### What gets patched

When `VEOMNI_USE_LIGER_KERNEL=1` and the `liger_kernel` package is installed,
each model's `gpu_patch.py` replaces HuggingFace module classes:

| Component | Original | Liger replacement |
|---|---|---|
| RMSNorm | `{Model}RMSNorm` | `LigerRMSNorm` |
| Rotary embedding | `apply_rotary_pos_emb` | `liger_rotary_pos_emb` |
| SwiGLU MLP | `{Model}MLP` | `LigerSwiGLUMLP` |

### Selection flow

Patching happens at model class registration time (import of the model
module). Each model's `gpu_patch.py` checks:

```python
if is_liger_kernel_available() and get_env("VEOMNI_USE_LIGER_KERNEL") == "1":
    hf_module.apply_rotary_pos_emb = liger_rotary_pos_emb
    hf_module.ModelRMSNorm = LigerRMSNorm
    hf_module.ModelMLP = LigerSwiGLUMLP
```
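
The mechanism is plain attribute rebinding on the imported modeling module: anything constructed after the patch picks up the Liger class. A stand-alone illustration with a dummy module namespace (the classes and `apply_patch` helper here are stand-ins, not Liger's or VeOmni's real code):

```python
import types

# Dummy stand-in for an imported HF modeling module.
hf_module = types.SimpleNamespace()


class ModelRMSNorm:            # original HF class
    backend = "eager"


class LigerRMSNorm:            # fused replacement
    backend = "liger"


hf_module.ModelRMSNorm = ModelRMSNorm


def apply_patch(liger_available, env_value="1"):
    # Same guard as the gpu_patch.py files: package present + env flag on.
    if liger_available and env_value == "1":
        hf_module.ModelRMSNorm = LigerRMSNorm


apply_patch(liger_available=True)
```

Note the ordering constraint this implies: the patch must run before model construction, which is why it lives at registration (import) time.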

### Models with Liger support

Qwen2, Qwen3, Qwen3-MoE, Qwen2-VL, DeepSeek-V3, Llama, Seed-OSS.

### Key files

- `veomni/models/transformers/{model}/gpu_patch.py` (7 model-specific files)

---

## 4. MoE Kernel

MoE kernel selection is controlled by a single `moe_implementation` field:

```yaml
model:
  ops_implementation:
    moe_implementation: fused          # Triton group-gemm (default fused path)
    # moe_implementation: fused_quack  # Quack CUTLASS/CuTe kernels (SM90+)
    # moe_implementation: eager        # reference PyTorch loop (very slow, debug only)
```

**Field:** `OpsImplementationConfig.moe_implementation`

**Default:** `None` (falls back to `"eager"` per model config)

| Value | Kernel | Hardware | EP support |
|-------|--------|----------|:----------:|
| `eager` | PyTorch expert loop | Any | No |
| `fused` | Triton group-gemm (`group_gemm_same_nk`) | SM70+ (V100+) | Yes |
| `fused_quack` | Quack CUTLASS/CuTe (`quack.gemm_interface.gemm`) | SM90+ (H100+) | No |
| *(NPU auto)* | NPU group-gemm | Ascend NPU | Yes |

Models only ever see `_moe_implementation` as `"eager"` or `"fused"`: the
`fused_quack` variant is mapped to `"fused"` on the config, and the kernel
backend is selected separately via `apply_veomni_fused_moe_patch`.

On NPU devices, the backend parameter is ignored; the NPU kernel is always
selected.
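
The split between the config-visible value and the kernel backend can be sketched as a small resolver returning the pair (config value, backend). `resolve_moe` is a hypothetical helper written for this document, not a VeOmni function:

```python
def resolve_moe(moe_implementation, device_type="cuda"):
    """Return (value stored on config._moe_implementation, kernel backend)."""
    if moe_implementation in (None, "eager"):
        # Default None falls back to the eager expert loop.
        return "eager", None
    if device_type == "npu":
        # NPU ignores the requested backend; its group-gemm always wins.
        return "fused", "npu"
    if moe_implementation == "fused_quack":
        # Models still see "fused"; the Quack backend is chosen separately.
        return "fused", "fused_quack"
    return "fused", "fused"
```

This makes the documented invariant explicit: the first element is always `"eager"` or `"fused"`, never `"fused_quack"`.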

### Selection flow

Unlike attention and loss, the MoE patch is **not** applied at import time.
It is applied inside `build_foundation_model()`:

```python
def build_foundation_model(..., moe_implementation="fused_quack"):
    config._moe_implementation = "fused"  # models only see "eager"/"fused"
    apply_veomni_fused_moe_patch(moe_implementation="fused_quack")
```

This deferred approach lets the config drive kernel selection without
env vars.

### Usage

**Via config (YAML):**

```yaml
model:
  ops_implementation:
    moe_implementation: fused_quack
```

**Via `build_foundation_model` (standalone scripts):**

```python
model = build_foundation_model(
    config_path="...",
    moe_implementation="fused_quack",
)
```

**Direct patch (tests / benchmarks):**

```python
from veomni.ops.fused_moe import apply_veomni_fused_moe_patch

apply_veomni_fused_moe_patch(moe_implementation="fused_quack")
```

### Key files

- Config: `veomni/arguments/arguments_types.py` — `OpsImplementationConfig`
- Dispatch: `veomni/ops/fused_moe/__init__.py` — `apply_veomni_fused_moe_patch()`
- Triton impl: `veomni/ops/fused_moe/group_gemm.py`
- Quack impl: `veomni/ops/fused_moe/quack_gemm.py`
- NPU impl: `veomni/ops/fused_moe/npu_group_gemm.py`
- Plumbing: `veomni/models/auto.py` — `build_foundation_model(moe_implementation=...)`

---

## Environment Variables Summary

| Env var | Default | Scope | Notes |
|---------|---------|-------|-------|
| `MODELING_BACKEND` | `"veomni"` | Global | `"veomni"` or `"hf"` — controls whether VeOmni ops patches are applied |
| `VEOMNI_USE_LIGER_KERNEL` | `"1"` | Global | Controls Liger kernels for RMSNorm/RoPE/SwiGLU + cross-entropy loss |
| `USE_GROUP_GEMM` | `"1"` | MoE | Gate for Triton group-gemm availability; set `"0"` to force fallback |
| `VEOMNI_ENABLE_CHUNK_LOSS` | `"0"` | NPU only | Enable chunked loss computation |

All env vars are registered in `veomni/utils/env.py` with defaults and can be
overridden by setting the corresponding shell environment variable.
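
A registry of env vars with defaults, read lazily from the shell environment, can be sketched like this. The registry below is illustrative (the names and defaults come from the table above; the real definitions live in `veomni/utils/env.py`, and this `get_env` signature is an assumption):

```python
import os

# Defaults as documented above; the names are real, the registry is a sketch.
_ENV_DEFAULTS = {
    "MODELING_BACKEND": "veomni",
    "VEOMNI_USE_LIGER_KERNEL": "1",
    "USE_GROUP_GEMM": "1",
    "VEOMNI_ENABLE_CHUNK_LOSS": "0",
}


def get_env(name, environ=os.environ):
    """Shell value wins; otherwise fall back to the registered default."""
    if name not in _ENV_DEFAULTS:
        raise KeyError(f"unregistered env var: {name}")
    return environ.get(name, _ENV_DEFAULTS[name])
```

Centralizing defaults this way means every call site (e.g. the `gpu_patch.py` guard) sees the same fallback value without repeating string literals.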