MoE Implementation
gpt-oss uses a Mixture-of-Experts (MoE) architecture to reduce the number of parameters that are active when processing each token. The MoE configurations of the two gpt-oss models are:
| Model | Layers | Total Params | Active Params Per Token | Total Experts | Active Experts Per Token | Context Length |
|---|---|---|---|---|---|---|
| gpt-oss-120b | 36 | 117B | 5.1B | 128 | 4 | 128k |
| gpt-oss-20b | 24 | 21B | 3.6B | 32 | 4 | 128k |
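As a rough sketch of what "active experts per token" means, top-k routing for the gpt-oss-20b configuration (32 experts, 4 active) might look like the following in NumPy. This is illustrative only: the real router is a learned linear layer, and names such as `router_logits` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
num_experts, top_k = 32, 4          # gpt-oss-20b configuration
router_logits = rng.normal(size=(1, num_experts))  # scores for one token

# Select the top-k experts by router score
expert_indices = np.argsort(router_logits, axis=-1)[:, -top_k:]          # (1, 4)
top_logits = np.take_along_axis(router_logits, expert_indices, axis=-1)  # (1, 4)

# Normalize the selected scores into routing weights (softmax over the top-k)
expert_weights = np.exp(top_logits - top_logits.max(axis=-1, keepdims=True))
expert_weights /= expert_weights.sum(axis=-1, keepdims=True)
```

Only the 4 selected experts' weights participate in the MLP computation for that token, which is why only 3.6B of the 21B parameters are active per token.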
Reference: Introducing gpt-oss (OpenAI News - Product)
During early implementation, we encountered `buf != nil is False` errors caused by memory limitations. This led us to implement custom TIR kernels for the MoE operations instead of relying on high-level operators (`nn.op.take` and `relax.op.einsum`). Rather than gathering the tokens routed to each expert, the kernels compute directly, mirroring the high-level operators used in the gpt-oss reference. This is because we prioritized homogeneity with the gpt-oss reference. (c.f. Design Philosophy)
To read the TIR kernels below:
- Check `T.grid(...)` → outer loop dimensions (output shape)
- Find `T.serial(...)` → inner loop (reduction axis)
- Look at `sum_buffer[0] += ...` → core computation
`run_einsum` performs matrix multiplication between the input tensor and the selected expert parameters. It covers both MLP projections of the gpt-oss reference:

- MLP #1:
  ```python
  # MLP #1
  t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
  ```
- MLP #2:
  ```python
  # MLP #2
  t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias
  ```
- Input shapes:
  - `input_tensor`: `(seq_len, in_features)` or `(seq_len, experts_per_token, in_features)`
  - `expert_indices`: `(seq_len, experts_per_token)`
  - `weight_tensor`: `(num_all_experts, out_features, in_features)`
  - `bias_tensor`: `(num_all_experts, out_features)`
- Output shape: `(seq_len, experts_per_token, out_features)`
- Operation: for each selected expert, computes `sum(input * weight)` over `in_features`
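For reference, the 2-D input case of this operation can be written as a plain NumPy loop. This is a verification sketch of the semantics, not the TIR kernel; the function name `einsum_reference` and the toy shapes are illustrative.

```python
import numpy as np

def einsum_reference(input_tensor, expert_indices, weight_tensor, bias_tensor):
    """NumPy reference for the per-expert matmul (2-D input case)."""
    seq_len, experts_per_token = expert_indices.shape
    num_all_experts, out_features, in_features = weight_tensor.shape
    out = np.zeros((seq_len, experts_per_token, out_features), dtype=np.float32)
    for s in range(seq_len):
        for r in range(experts_per_token):
            e = expert_indices[s, r]
            # sum(input * weight) over in_features, plus the expert's bias
            out[s, r] = weight_tensor[e] @ input_tensor[s] + bias_tensor[e]
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 8)).astype(np.float32)          # (seq_len, in_features)
idx = rng.integers(0, 5, size=(3, 2)).astype(np.int32)  # (seq_len, experts_per_token)
w = rng.normal(size=(5, 6, 8)).astype(np.float32)       # (num_all_experts, out_features, in_features)
b = rng.normal(size=(5, 6)).astype(np.float32)          # (num_all_experts, out_features)
ref = einsum_reference(x, idx, w, b)                    # (3, 2, 6)
```

Gathering `w[idx]` and calling `np.einsum("seck,sk->sec", ...)` gives the same result; the TIR kernel avoids that gather by indexing `weight` through `e_indices` inside the loop.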
Code
```python
@staticmethod
def _run_einsum(
    input_tensor: Tensor,
    expert_indices: Tensor,
    weight_tensor: Tensor,
    bias_tensor: Tensor
) -> Tensor:
    threads_per_block = 256
    # here, len(input_tensor.shape) is 2 or 3
    # (seq_len, 2880); normed tensor in 1st operation
    # (seq_len, top_k, 2880)
    seq_len = input_tensor.shape[0]
    in_features = input_tensor.shape[-1]
    input_shape = input_tensor.shape
    input_is_3d = len(input_shape) == 3
    dtype = input_tensor.dtype
    _, experts_per_token = expert_indices.shape  # (seq_len, experts_per_token)
    num_all_experts, out_features, _ = weight_tensor.shape  # (32, 5760, 2880) | (32, 2880, 2880)

    @T.prim_func(private=True)
    def _einsum_matrix(
        x_handle: T.handle,
        e_indices: T.Buffer((seq_len, experts_per_token), "int32"),
        weight: T.Buffer((num_all_experts, out_features, in_features), dtype),
        bias: T.Buffer((num_all_experts, out_features), dtype),
        out_handle: T.handle
    ):
        T.func_attr({"op_pattern": 4, "tir.noalias": True})
        x = T.match_buffer(x_handle, input_shape, dtype)
        out = T.match_buffer(
            out_handle,
            (seq_len, experts_per_token, out_features),
            "float32"
        )
        for s, r, c in T.grid(seq_len, experts_per_token, out_features):
            # s, r, c for seq_idx, expert_rank, out_channel, resp.
            with T.block("compute_matmul"):
                # S for spatial axis in the argument `kinds`
                seq_idx, expert_rank, out_idx = T.axis.remap("SSS", [s, r, c])
                # e_idx = e_indices[seq_idx, expert_rank]
                sum_buffer = T.alloc_buffer((1,), dtype="float32", scope="local")
                sum_buffer[0] = T.cast(bias[e_indices[seq_idx, expert_rank], out_idx], "float32")
                for in_idx in T.serial(in_features):
                    if input_is_3d:
                        sum_buffer[0] += (
                            T.cast(x[seq_idx, expert_rank, in_idx], "float32") *
                            T.cast(weight[e_indices[seq_idx, expert_rank], out_idx, in_idx], "float32")
                        )
                    else:
                        sum_buffer[0] += (
                            T.cast(x[seq_idx, in_idx], "float32") *
                            T.cast(weight[e_indices[seq_idx, expert_rank], out_idx, in_idx], "float32")
                        )
                out[seq_idx, expert_rank, out_idx] = sum_buffer[0]

    return op.tensor_ir_op(
        _einsum_matrix,
        "moe_gemv_einsum",
        args=[input_tensor, expert_indices, weight_tensor, bias_tensor],
        out=Tensor.placeholder((seq_len, experts_per_token, out_features), "float32"),
    )
```

The MoE einsum fused with MXFP4 dequantization is described separately (see TIR-based support for MXFP4). Expert routing indices and weights are passed to that kernel, which performs dequantization and accumulation in a single pass.
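Since the fused kernel is not shown on this page, here is a minimal standalone sketch of MXFP4 dequantization, assuming the OCP Microscaling (MX) convention: 4-bit e2m1 values in blocks of 32 that share one power-of-two scale stored as a biased 8-bit exponent (E8M0). The function name and layout below are illustrative, not the project's actual code.

```python
import numpy as np

# The 16 representable e2m1 values, indexed by 4-bit code
# (codes 0-7 positive, 8-15 their negations).
FP4_VALUES = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=np.float32,
)

def dequant_mxfp4(codes, scale_exponents, block_size=32):
    """codes: uint8 array of already-unpacked 4-bit codes, one per element.
    scale_exponents: one biased E8M0 exponent per block (bias 127)."""
    values = FP4_VALUES[codes]
    scales = np.exp2(scale_exponents.astype(np.float32) - 127.0)
    return (values.reshape(-1, block_size) * scales[:, None]).reshape(-1)
```

In the fused kernel, a lookup and scale like this would happen inside the reduction loop, so the bfloat16/float32 weights never need to be materialized in memory.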
`run_gating` aggregates the outputs of the selected experts using `expert_weights`, the routing scores produced by the gating network:

- Weighted sum of experts:
  ```python
  # Weighted sum of experts
  t = torch.einsum("bec,be->bc", t, expert_weights)
  ```
- Input shapes:
  - `input_tensor`: `(seq_len, experts_per_token, out_features)`
  - `expert_weights`: `(seq_len, experts_per_token)`
- Output shape: `(seq_len, out_features)`
- Operation: computes `sum(expert_output * expert_weights)` over `experts_per_token`
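This reduction has a one-line NumPy equivalent, useful as a quick reference for what the kernel computes (a verification sketch, not the TIR code; names and toy shapes are illustrative):

```python
import numpy as np

def gating_reference(expert_out, expert_weights):
    # out[s, c] = sum over expert_rank r of expert_out[s, r, c] * expert_weights[s, r]
    return np.einsum("sec,se->sc", expert_out, expert_weights)

rng = np.random.default_rng(0)
t = rng.normal(size=(3, 2, 6)).astype(np.float32)  # (seq_len, experts_per_token, out_features)
w = rng.normal(size=(3, 2)).astype(np.float32)     # (seq_len, experts_per_token)
combined = gating_reference(t, w)                  # (3, 6)
```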
Code
```python
@staticmethod
def _run_gating(
    input_tensor: Tensor,
    expert_weights: Tensor,
):
    seq_len, experts_per_token, out_features = input_tensor.shape
    dtype = "bfloat16"

    @T.prim_func(private=True)
    def _apply_gate(
        x_handle: T.handle,
        e_weights: T.Buffer((seq_len, experts_per_token), dtype),
        out_handle: T.handle
    ):
        T.func_attr({"op_pattern": 4, "tir.noalias": True})
        x = T.match_buffer(x_handle, (seq_len, experts_per_token, out_features), "float32")
        out = T.match_buffer(
            out_handle,
            (seq_len, out_features),
            dtype
        )
        for s, c in T.grid(seq_len, out_features):
            with T.block("apply_gate"):
                seq_idx, out_idx = T.axis.remap("SS", [s, c])
                gate_buffer = T.alloc_buffer((1,), dtype="float32", scope="local")
                gate_buffer[0] = T.cast(0.0, "float32")
                for expert_rank in T.serial(experts_per_token):
                    gate_buffer[0] += (
                        T.cast(x[seq_idx, expert_rank, out_idx], "float32") *
                        T.cast(e_weights[seq_idx, expert_rank], "float32")
                    )
                out[seq_idx, out_idx] = T.cast(gate_buffer[0], dtype)

    return op.tensor_ir_op(
        _apply_gate,
        "moe_gemv_gating",
        args=[input_tensor, expert_weights],
        out=Tensor.placeholder((seq_len, out_features), dtype),
    )
```

Both functions share a similar TIR structure but differ in their reduction axis:
| | `run_einsum` | `run_gating` |
|---|---|---|
| Role | per-expert matmul | weighted aggregation |
| Outer loop | `seq, experts_per_token, out_features` | `seq, out_features` |
| Inner loop (reduction) | `in_features` | `experts_per_token` |
| Accumulator init | `bias[e_idx, out_idx]` | `0.0` |
| Core operation | `input × weight[e_idx, out_idx, in_idx]` | `input × expert_weights[seq_idx, expert_rank]` |
| Output shape | `(seq, experts_per_token, out_features)` | `(seq, out_features)` |
MLC-LLM provides MoE utilities (`moe_matmul.py`, `moe_misc.py`) with a different design:

| | MLC-LLM | This implementation |
|---|---|---|
| Expert indexing | `indptr` array (cumsum-based) | Direct `expert_indices` |
| Batching | Group tokens by expert → batched GEMM | Per-token einsum |
| Dequantization | Separate `dequantize_gemv` | Fused into einsum |
MLC-LLM groups tokens assigned to the same expert together, enabling batched matrix operations. This implementation instead processes each token independently with direct expert indexing, prioritizing a faithful port of the gpt-oss reference.
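The contrast between the two indexing schemes can be sketched in NumPy. The shapes below are hypothetical, and this only models the bookkeeping, not MLC-LLM's actual utilities:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, top_k, num_experts = 4, 2, 8
expert_indices = rng.integers(0, num_experts, size=(seq_len, top_k))

# Direct indexing (this implementation): each (token, expert_rank) slot
# looks up its expert independently; tokens are never reordered.
flat = expert_indices.reshape(-1)  # (seq_len * top_k,)

# indptr-style grouping (MLC-LLM-like): sort the slots by expert and build
# a cumulative-sum offset array so each expert's tokens sit contiguously.
order = np.argsort(flat, kind="stable")            # slots grouped by expert
counts = np.bincount(flat, minlength=num_experts)  # slots per expert
indptr = np.concatenate([[0], np.cumsum(counts)])  # (num_experts + 1,)

# Tokens routed to expert e occupy order[indptr[e]:indptr[e + 1]],
# which is what enables one batched GEMM per expert.
```

The grouped layout amortizes each expert's weight reads over all of its tokens, at the cost of the sort/scatter bookkeeping that the per-token einsum avoids.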