Commit 6a8baf8

feat: add sonicmoe (#3411)
* feat: add sonicmoe
* feat: add torch compile for routing
* feat: add routing smoke test
* feat: add qwen3_5_moe, qwen3_vl_moe, qwen3_omni_moe
* fix: disable mlp kernel for sonicmoe too
* feat: update to sonicmoe release
* chore: update import following new sonicmoe changes
* feat: update handling for blackwell
* feat: add sonicmoe e2e test
* fix: installation for updated sonicmoe
* fix: git commit
* fix: ignore py req and fix metadata
* fix: increase min hidden size to match sonicmoe kernel min
* fix: attempt properly interleave and handle unpatch mid-test
* chore: refactor teardown better
* chore: refactor to re-use rearrange
* fix: add idempotency guard
* fix: address comments on CI memory and interleave
* fix: tests grad, param doublewrapped
1 parent 1eaf4d7 commit 6a8baf8

File tree

12 files changed: +1698 −42 lines


src/axolotl/integrations/kernels/README.md

Lines changed: 37 additions & 5 deletions

````diff
@@ -10,7 +10,7 @@ class ExpertsInterface(GeneralInterface):
 }
 ```
 
-In our custom integration, we add support for **ScatterMoE**, which is even more efficient and faster than `grouped_mm`.
+In our custom integration, we add support for **ScatterMoE** and **SonicMoE**, which are more efficient and faster than `grouped_mm`.
 
 ## Usage
 
@@ -21,23 +21,55 @@ plugins:
   - axolotl.integrations.kernels.KernelsPlugin
 
 use_kernels: true
+
+# Choose one (mutually exclusive):
 use_scattermoe: true
+# OR
+use_sonicmoe: true
+```
+
+**Important:** Setting `experts_implementation` is incompatible with custom kernel options.
+
+### SonicMoE installation
+
+**Prerequisites:**
+- NVIDIA Hopper (H100, H200) or Blackwell (B200, GB200) GPU
+- CUDA 12.9+ (13.0+ for B300)
+- PyTorch 2.7+ (2.9.1 recommended)
+- For B300: Triton 3.6.0
+
+```bash
+pip install --ignore-requires-python --no-deps "sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git@116e2df0a41874f77fa0ad269ce7df3f0cfcb956" && pip install nvidia-cutlass-dsl==4.4.0 quack-kernels==0.2.5
 ```
 
-**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
+See the [SonicMoE installation guide](https://github.com/Dao-AILab/sonic-moe?tab=readme-ov-file#-installation) for the latest prerequisite details.
+
+**Note:** Blackwell support is in upstream beta. On Blackwell GPUs, Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels.
 
 ## How It Works
 
 The `KernelsPlugin` runs before model loading and:
 
-1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
+### ScatterMoE
+1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).
 2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
 
-This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
+### SonicMoE
+1. Resolves the model's MoE block class(es) from `constants.py`.
+2. Patches the forward method with SonicMoE's optimized kernels and registers a weight converter for the interleaved gate/up projection format.
+3. Supports both softmax -> topk and sigmoid -> topk routing strategies.
+
+Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.
+
+#### Supported Models
+
+See `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.).
 
 ## Limitations
 
-ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
+ScatterMoE uses softmax -> topk routing, so results may differ from the baseline for some model architectures (GPT-OSS, etc.). It is currently incompatible with `GLM_MOE_DSA` (GLM 5) and `GLM4_MOE_LITE` (GLM 4.7 Flash).
+
+SonicMoE supports both softmax -> topk and sigmoid -> topk routing, covering a wider range of architectures.
 
 ScatterMoE does not work for GLM4.7 Flash (glm4_moe_lite) atm.
````
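The README's routing distinction (softmax -> topk for ScatterMoE, softmax or sigmoid -> topk for SonicMoE) comes down to how per-expert scores are computed before top-k selection. A toy pure-Python sketch for intuition — `route` is a hypothetical helper for illustration, not part of either kernel:

```python
import math


def route(logits: list[float], top_k: int, strategy: str = "softmax") -> list[tuple[int, float]]:
    """Toy router: score each expert, then keep the top_k (index, weight) pairs.

    strategy="softmax": scores are normalized across all experts before selection
    strategy="sigmoid": each expert is scored independently (DeepSeek-V3-style gating)
    """
    if strategy == "softmax":
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]  # subtract max for numerical stability
        total = sum(exps)
        scores = [e / total for e in exps]
    elif strategy == "sigmoid":
        scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    else:
        raise ValueError(f"unknown routing strategy: {strategy}")

    # top-k after scoring, highest score first
    ranked = sorted(enumerate(scores), key=lambda p: p[1], reverse=True)
    return ranked[:top_k]


# Example: 4 experts, pick 2 (experts 1 and 3 have the largest logits)
print(route([0.1, 2.0, -1.0, 1.5], top_k=2, strategy="softmax"))
```

Note the selection order is the same under both strategies (both scoring functions are monotonic in the logit); what differs is the magnitude of the routing weights applied to the chosen experts.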

src/axolotl/integrations/kernels/args.py

Lines changed: 15 additions & 4 deletions

```diff
@@ -6,7 +6,18 @@
 
 
 class KernelsArgs(BaseModel):
-    use_scattermoe: bool | None = True
+    use_scattermoe: bool | None = None
+    use_sonicmoe: bool | None = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_mutually_exclusive(cls, data):
+        if data.get("use_scattermoe") and data.get("use_sonicmoe"):
+            raise ValueError(
+                "Cannot use both ScatterMoE and SonicMoE simultaneously. "
+                "Please set only one of `use_scattermoe` or `use_sonicmoe` to true."
+            )
+        return data
 
     @model_validator(mode="before")
     @classmethod
@@ -36,11 +47,11 @@ def check_experts_implementation(cls, data):
 
     @model_validator(mode="before")
     @classmethod
-    def disable_mlp_kernel_scattermoe(cls, data):
-        if data.get("use_scattermoe") is True:
+    def disable_mlp_kernel(cls, data):
+        if data.get("use_scattermoe") is True or data.get("use_sonicmoe") is True:
             if data.get("lora_mlp_kernel") is True:
                 LOG.warning(
-                    "Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
+                    "Disabling lora_mlp_kernel when using custom MoE kernels due to compatibility issues."
                 )
                 data["lora_mlp_kernel"] = False
             data["mlp_kernel"] = False
```
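The `check_mutually_exclusive` validator above runs with `mode="before"`, i.e. against the raw config dict before field coercion, so unset flags arrive as missing keys or `None` and are falsy. A minimal dependency-free sketch of that same check as a standalone function (pydantic plumbing omitted):

```python
def check_mutually_exclusive(data: dict) -> dict:
    """Mirror of the before-mode validator: reject configs enabling both MoE kernels."""
    if data.get("use_scattermoe") and data.get("use_sonicmoe"):
        raise ValueError(
            "Cannot use both ScatterMoE and SonicMoE simultaneously. "
            "Please set only one of `use_scattermoe` or `use_sonicmoe` to true."
        )
    return data


# Single-kernel and unset configs pass through unchanged
assert check_mutually_exclusive({"use_sonicmoe": True}) == {"use_sonicmoe": True}
assert check_mutually_exclusive({}) == {}
```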
src/axolotl/integrations/kernels/constants.py (new file)

Lines changed: 68 additions & 0 deletions

```python
"""
Supported MoE block mappings for kernel integrations.

Maps model_type to the SparseMoeBlock class name(s) in transformers.
Used by both ScatterMoE and SonicMoE kernel paths.

Values can be a single class name (str) or a list of class names for models
with multiple MoE block types (e.g. qwen3_omni_moe has Thinker + Talker).
"""

import importlib

SPARSE_MOE_BLOCK = {
    # softmax -> topk routing
    "qwen2_moe": "Qwen2MoeSparseMoeBlock",
    "qwen3_moe": "Qwen3MoeSparseMoeBlock",
    "qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
    "qwen3_next": "Qwen3NextSparseMoeBlock",
    "qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
    # qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
    "qwen3_omni_moe": [
        "Qwen3OmniMoeThinkerTextSparseMoeBlock",
        "Qwen3OmniMoeTalkerTextSparseMoeBlock",
    ],
    "olmoe": "OlmoeSparseMoeBlock",
    "mixtral": "MixtralSparseMoeBlock",
    "minimax": "MiniMaxSparseMoeBlock",
    # sigmoid -> topk routing (with group-based expert selection)
    "glm_moe_dsa": "GlmMoeDsaMoE",
    "deepseek_v3": "DeepseekV3MoE",
    "glm4_moe": "Glm4MoeMoE",
    "glm4_moe_lite": "Glm4MoeLiteMoE",
    "glm4v_moe": "Glm4vMoeTextMoE",
    # sigmoid -> topk routing (no group selection)
    "minimax_m2": "MiniMaxM2SparseMoeBlock",
    # Models below need custom routing (not yet implemented):
    # "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock",  # softmax->topk, e_score_correction_bias between softmax and topk
    # "deepseek_v2": "DeepseekV2Moe",  # softmax->topk, group_limited_greedy, different attr names (num_group)
    # "hunyuan_v1_moe": "HunYuanMoEV1Moe",  # softmax->topk, gate.wg (not gate.weight), scatter routing
    # "gpt_oss": "GptOssMLP",  # topk->softmax, transposed layout [E,H,2*I], custom GLU, expert biases
}


def resolve_moe_block_classes(model_type: str):
    """Resolve all MoE block classes from transformers for the given model type.

    Returns a list of classes (one for most models, multiple for models with
    distinct MoE block types like qwen3_omni_moe).
    """
    entry = SPARSE_MOE_BLOCK.get(model_type)
    if entry is None:
        raise ValueError(
            f"Unsupported MoE model type '{model_type}'. "
            f"Supported types: {list(SPARSE_MOE_BLOCK.keys())}"
        )

    cls_names = entry if isinstance(entry, list) else [entry]
    module_path = f"transformers.models.{model_type}.modeling_{model_type}"
    module = importlib.import_module(module_path)

    classes = []
    for cls_name in cls_names:
        moe_cls = getattr(module, cls_name, None)
        if moe_cls is None:
            raise ValueError(f"Could not find class '{cls_name}' in '{module_path}'")
        classes.append(moe_cls)

    return classes
```
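`resolve_moe_block_classes` normalizes a str-or-list mapping entry into a list of class names before importing anything. That normalization can be exercised without transformers installed; a small sketch against a stub table (`resolve_names` and the stub entries are illustrative, not the module's API):

```python
# Stub subset of the SPARSE_MOE_BLOCK table, for demonstration only
SPARSE_MOE_BLOCK_STUB = {
    "mixtral": "MixtralSparseMoeBlock",
    "qwen3_omni_moe": [
        "Qwen3OmniMoeThinkerTextSparseMoeBlock",
        "Qwen3OmniMoeTalkerTextSparseMoeBlock",
    ],
}


def resolve_names(model_type: str) -> list[str]:
    """Same str-or-list normalization as resolve_moe_block_classes, minus the import step."""
    entry = SPARSE_MOE_BLOCK_STUB.get(model_type)
    if entry is None:
        raise ValueError(f"Unsupported MoE model type '{model_type}'.")
    return entry if isinstance(entry, list) else [entry]


# Single-block models yield a one-element list; multi-block models keep all entries
assert resolve_names("mixtral") == ["MixtralSparseMoeBlock"]
assert len(resolve_names("qwen3_omni_moe")) == 2
```

Returning a list in both cases lets the caller (e.g. `_kernelize_model` in plugin.py) loop uniformly instead of special-casing multi-block models.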

src/axolotl/integrations/kernels/plugin.py

Lines changed: 81 additions & 26 deletions

```diff
@@ -1,14 +1,59 @@
+import importlib
+import os
 from pathlib import Path
 
-from kernels import (
-    LocalLayerRepository,
-    Mode,
-    register_kernel_mapping,
-    replace_kernel_forward_from_hub,
-)
+import torch
 
 from axolotl.integrations.base import BasePlugin
-from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+def _check_sonicmoe_gpu_compat():
+    """Validate GPU compute capability for SonicMoE and configure env.
+
+    Supported: Hopper (sm_90), Blackwell (sm_100 - sm_103).
+    B300 (sm_103) additionally requires Triton 3.6.0.
+    """
+    if not torch.cuda.is_available():
+        return
+
+    cc = torch.cuda.get_device_capability()
+
+    if cc < (9, 0):
+        raise RuntimeError(
+            f"SonicMoE requires Hopper (sm_90) or Blackwell (sm_100+) GPU, "
+            f"but detected sm_{cc[0]}{cc[1]}."
+        )
+
+    if cc > (10, 3):
+        raise RuntimeError(
+            f"SonicMoE does not yet support sm_{cc[0]}{cc[1]}. "
+            f"Supported: Hopper (sm_90) and Blackwell (sm_100 - sm_103)."
+        )
+
+    # Blackwell (sm_100+): enable QuACK GEMM kernels
+    if cc >= (10, 0):
+        os.environ.setdefault("USE_QUACK_GEMM", "1")
+        LOG.info(
+            f"Blackwell GPU (sm_{cc[0]}{cc[1]}) detected, enabling USE_QUACK_GEMM=1"
+        )
+
+    # B300 (sm_103): requires Triton 3.6.0
+    if cc == (10, 3):
+        triton_spec = importlib.util.find_spec("triton")
+        if triton_spec is None:
+            raise RuntimeError(
+                "B300 (sm_103) requires Triton 3.6.0, but Triton is not installed."
+            )
+        import triton
+
+        triton_version = tuple(int(x) for x in triton.__version__.split(".")[:2])
+        if triton_version != (3, 6):
+            raise RuntimeError(
+                f"B300 (sm_103) requires Triton 3.6.x, but found {triton.__version__}."
+            )
 
 
 class KernelsPlugin(BasePlugin):
@@ -19,8 +64,32 @@ def pre_model_load(self, cfg):
         if cfg.use_scattermoe:
             self._register_kernels()
             self._kernelize_model(cfg.model_config_type)
+        elif cfg.use_sonicmoe:
+            if not importlib.util.find_spec("sonicmoe"):
+                raise RuntimeError(
+                    "SonicMoE is not installed. See installation instructions at "
+                    "https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/kernels/README.md#sonicmoe-installation"
+                )
+
+            _check_sonicmoe_gpu_compat()
+
+            from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
+
+            LOG.info(
+                f"Applying SonicMoE patches for model type: {cfg.model_config_type}"
+            )
+            patch_sonicmoe(
+                cfg.model_config_type,
+                torch_compile=bool(getattr(cfg, "torch_compile", False)),
+            )
 
     def _register_kernels(self):
+        from kernels import (
+            LocalLayerRepository,
+            Mode,
+            register_kernel_mapping,
+        )
+
         plugin_root = Path(__file__).parent
         register_kernel_mapping(
             {
@@ -42,25 +111,11 @@ def _register_kernels(self):
         )
 
     def _kernelize_model(self, model_type: str):
-        if model_type == "olmoe":
-            from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
+        from kernels import replace_kernel_forward_from_hub
+
+        from axolotl.integrations.kernels.constants import resolve_moe_block_classes
 
+        for model_moe_cls in resolve_moe_block_classes(model_type):
             replace_kernel_forward_from_hub(
-                OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
+                model_moe_cls, "HFScatterMoEParallelExperts"
             )
-        else:
-            try:
-                model_moe_cls = get_model_moe_block(model_type)
-                replace_kernel_forward_from_hub(
-                    model_moe_cls, "HFScatterMoEParallelExperts"
-                )
-            except Exception as err:
-                raise ValueError(f"Unsupported model type: {model_type}") from err
-
-
-def get_model_moe_block(model_type: str):
-    module_path = f"transformers.models.{model_type}.modeling_{model_type}"
-    model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
-    module = __import__(module_path, fromlist=[f"{model_cls_prefix}SparseMoeBlock"])
-    model_cls = getattr(module, f"{model_cls_prefix}SparseMoeBlock")
-    return model_cls
```
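The capability gate in `_check_sonicmoe_gpu_compat` reduces to tuple comparisons on the `(major, minor)` pair returned by `torch.cuda.get_device_capability()`. A GPU-free sketch of the same bounds — `classify_compute_capability` is an illustrative name, not the plugin's API:

```python
def classify_compute_capability(cc: tuple[int, int]) -> str:
    """Apply the same sm_XX bounds as the plugin's GPU check.

    Tuples compare lexicographically, so (9, 0) <= cc <= (10, 3) expresses
    "Hopper through Blackwell B300" directly.
    """
    if cc < (9, 0):
        raise RuntimeError(f"SonicMoE requires sm_90+, detected sm_{cc[0]}{cc[1]}.")
    if cc > (10, 3):
        raise RuntimeError(f"sm_{cc[0]}{cc[1]} not yet supported (max sm_103).")
    if cc >= (10, 0):
        # The plugin additionally sets USE_QUACK_GEMM=1 on this path
        return "blackwell"
    return "hopper"


assert classify_compute_capability((9, 0)) == "hopper"      # H100/H200
assert classify_compute_capability((10, 0)) == "blackwell"  # B200
```

Isolating the comparisons like this is roughly what the commit's "routing smoke test" needs: the bounds can be unit-tested on any machine, while the env-var and Triton-version side effects stay in the plugin.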
src/axolotl/integrations/kernels/sonicmoe/__init__.py (new file)

Lines changed: 3 additions & 0 deletions

```python
from .patch import patch_sonicmoe

__all__ = ["patch_sonicmoe"]
```
