MoE Implementation

Park Woorak edited this page Feb 3, 2026 · 7 revisions

Mixture-of-Experts (MoE)

MoE in gpt-oss

gpt-oss uses an MoE architecture to reduce the number of parameters that are active when processing each token.

Here you can see the MoE configurations of gpt-oss:

| Model | Layers | Total Params | Active Params Per Token | Total Experts | Active Experts Per Token | Context Length |
|---|---|---|---|---|---|---|
| gpt-oss-120b | 36 | 117B | 5.1B | 128 | 4 | 128k |
| gpt-oss-20b | 24 | 21B | 3.6B | 32 | 4 | 128k |

Reference: Introducing gpt-oss (OpenAI News - Product)
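The "active experts per token" column comes from top-k routing: a gating network scores all experts, but only the k best are evaluated per token. A minimal NumPy sketch of that selection (the helper name is hypothetical, and it assumes softmax is applied over the selected experts, after top-k):

```python
import numpy as np

def route_top_k(logits: np.ndarray, k: int = 4):
    """Pick the top-k experts per token and softmax-normalize their scores.

    logits: (seq_len, num_experts) gating-network outputs.
    Returns (indices, weights), each of shape (seq_len, k).
    """
    # Indices of the k largest logits per token (order within the k is arbitrary).
    indices = np.argpartition(logits, -k, axis=-1)[:, -k:]
    picked = np.take_along_axis(logits, indices, axis=-1)
    # Softmax over only the selected experts (assumed post-top-k normalization).
    picked = picked - picked.max(axis=-1, keepdims=True)
    weights = np.exp(picked) / np.exp(picked).sum(axis=-1, keepdims=True)
    return indices, weights

rng = np.random.default_rng(0)
idx, w = route_top_k(rng.standard_normal((3, 32)), k=4)  # e.g. gpt-oss-20b: 4 of 32
```

For gpt-oss-20b this is why only ~3.6B of the 21B parameters touch any given token: 28 of the 32 experts per layer sit idle.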

TIR-based MoE

During early implementation, we encountered `buf != nil is False` assertion failures caused by memory limitations.
This led us to implement custom TIR kernels for the MoE operations instead of relying on high-level operators (`nn.op.take` and `relax.op.einsum`).

Rather than gathering the tokens routed to each expert, the kernels compute directly over per-token expert indices, mirroring the high-level operators in the gpt-oss reference.
This is because we prioritized homogeneity with the gpt-oss reference. (c.f. Design Philosophy)

How to read the code

  1. Check T.grid(...) → outer loop dimensions (output shape)
  2. Find T.serial(...) → inner loop (reduction axis)
  3. Look at sum_buffer[0] += ... → core computation
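The three steps above map onto an ordinary loop nest. A plain-Python analogue of the kernel shape (illustrative only, not the shipped TIR):

```python
import numpy as np

def tir_pattern_demo(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Mirrors the structure of the TIR kernels on this page:
    outer loops correspond to T.grid (one iteration per output element),
    the inner loop to T.serial (the reduction axis), and `acc` to sum_buffer[0]."""
    seq_len, in_features = x.shape
    out_features = w.shape[0]
    out = np.zeros((seq_len, out_features), dtype=np.float32)
    for s in range(seq_len):                 # T.grid: output shape
        for c in range(out_features):
            acc = 0.0                        # sum_buffer[0] init
            for k in range(in_features):     # T.serial: reduction
                acc += float(x[s, k]) * float(w[c, k])
            out[s, c] = acc
    return out

x = np.arange(6, dtype=np.float32).reshape(2, 3)
w = np.ones((4, 3), dtype=np.float32)
y = tir_pattern_demo(x, w)  # each output row holds the sum of the input row
```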

Matrix Multiplication between input and expert parameters

`run_einsum` performs matrix multiplication between the input tensor and the selected experts' parameters.

Corresponding PyTorch operations

  1. MLP #1
    # MLP #1
    t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
  2. MLP #2
    # MLP #2
    t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias
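The two subscripts differ only in the input tensor's rank: in MLP #1 every selected expert sees the same token vector (`bk`), while in MLP #2 each expert consumes its own hidden activation (`bek`). A NumPy sanity check of both contractions (shapes are illustrative, not the model's):

```python
import numpy as np

rng = np.random.default_rng(1)
b, e, c, k = 2, 4, 5, 3  # seq, experts_per_token, out_features, in_features

# MLP #1: "beck,bk->bec" -- one shared input vector per token.
mlp1_w = rng.standard_normal((b, e, c, k)).astype(np.float32)
mlp1_b = rng.standard_normal((b, e, c)).astype(np.float32)
t = rng.standard_normal((b, k)).astype(np.float32)
t1 = np.einsum("beck,bk->bec", mlp1_w, t) + mlp1_b

# MLP #2: "beck,bek->bec" -- a distinct input vector per (token, expert).
mlp2_w = rng.standard_normal((b, e, c, k)).astype(np.float32)
t2_in = rng.standard_normal((b, e, k)).astype(np.float32)
t2 = np.einsum("beck,bek->bec", mlp2_w, t2_in)
```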

Implementation

  • Input shapes:
    • input_tensor: (seq_len, in_features) or (seq_len, experts_per_token, in_features)
    • expert_indices: (seq_len, experts_per_token)
    • weight_tensor: (num_all_experts, out_features, in_features)
    • bias_tensor: (num_all_experts, out_features)
  • Output shape: (seq_len, experts_per_token, out_features)
  • Operation: For each selected expert, computes sum(input * weight) over in_features
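The shapes and the operation above can be modeled in NumPy before reading the TIR: gather each token's selected expert weights via `expert_indices`, then reduce over `in_features`. This is a sketch of the kernel's semantics, not the shipped code:

```python
import numpy as np

def run_einsum_reference(x, expert_indices, weight, bias):
    """NumPy model of the TIR kernel below. Accumulation is float32,
    matching the kernel's sum_buffer; x may be 2-D or 3-D as described above."""
    seq_len, experts_per_token = expert_indices.shape
    out_features = weight.shape[1]
    out = np.zeros((seq_len, experts_per_token, out_features), dtype=np.float32)
    for s in range(seq_len):
        for r in range(experts_per_token):
            e = expert_indices[s, r]                 # which expert this slot uses
            row = x[s, r] if x.ndim == 3 else x[s]   # per-expert vs shared input
            out[s, r] = weight[e] @ row + bias[e]    # reduce over in_features
    return out

x = np.ones((2, 3), dtype=np.float32)
weight = np.ones((4, 2, 3), dtype=np.float32)        # (num_all_experts, out, in)
bias = np.zeros((4, 2), dtype=np.float32)
indices = np.array([[0, 1], [2, 3]], dtype=np.int32)
out = run_einsum_reference(x, indices, weight, bias)  # every element is 3.0
```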
Code
    @staticmethod
    def _run_einsum(
        input_tensor: Tensor,
        expert_indices: Tensor,
        weight_tensor: Tensor,
        bias_tensor: Tensor
    ) -> Tensor:
        threads_per_block = 256

        # here, len(input_tensor.shape) is 2 or 3
        # (seq_len, 2880); normed tensor in 1st operation
        # (seq_len, top_k, 2880)

        seq_len = input_tensor.shape[0]
        in_features = input_tensor.shape[-1]
        input_shape = input_tensor.shape
        input_is_3d = len(input_shape) == 3
        dtype = input_tensor.dtype

        _, experts_per_token = expert_indices.shape  # (seq_len, experts_per_token)
        num_all_experts, out_features, _ = weight_tensor.shape  # (32, 5760, 2880) | (32, 2880, 2880)

        @T.prim_func(private=True)
        def _einsum_matrix(
            x_handle: T.handle,
            e_indices: T.Buffer((seq_len, experts_per_token), "int32"),
            weight: T.Buffer((num_all_experts, out_features, in_features), dtype),
            bias: T.Buffer((num_all_experts, out_features), dtype),
            out_handle: T.handle
        ):
            T.func_attr({"op_pattern": 4, "tir.noalias": True})
            x = T.match_buffer(x_handle, input_shape, dtype)
            out = T.match_buffer(
                out_handle,
                (seq_len, experts_per_token, out_features),
                "float32"
            )

            for s, r, c in T.grid(seq_len, experts_per_token, out_features):
                # s, r, c for seq_idx, expert_rank, out_channel, resp.
                with T.block("compute_matmul"):
                    # S for spatial axis in the argument `kinds`
                    seq_idx, expert_rank, out_idx = T.axis.remap("SSS", [s, r, c])
                    # e_idx = e_indices[seq_idx, expert_rank]

                    sum_buffer = T.alloc_buffer((1, ), dtype="float32", scope="local")
                    sum_buffer[0] = T.cast(bias[e_indices[seq_idx, expert_rank], out_idx], "float32")

                    for in_idx in T.serial(in_features):
                        if input_is_3d:
                            sum_buffer[0] += (
                                T.cast(x[seq_idx, expert_rank, in_idx], "float32") *
                                T.cast(weight[e_indices[seq_idx, expert_rank], out_idx, in_idx], "float32")
                            )
                        else:
                            sum_buffer[0] += (
                                T.cast(x[seq_idx, in_idx], "float32") *
                                T.cast(weight[e_indices[seq_idx, expert_rank], out_idx, in_idx], "float32")
                            )

                    out[seq_idx, expert_rank, out_idx] = sum_buffer[0]

        return op.tensor_ir_op(
            _einsum_matrix,
            "moe_gemv_einsum",
            args=[input_tensor, expert_indices, weight_tensor, bias_tensor],
            out=Tensor.placeholder((seq_len, experts_per_token, out_features), "float32"),
        )

Einsum with Fused MXFP4 Dequantization

MoE einsum with fused MXFP4 dequantization is described here. Expert routing indices and weights are passed to the kernel, which performs dequantization and accumulation in a single pass.
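For orientation, MXFP4 (per the OCP Microscaling spec) stores 4-bit FP4 (E2M1) codes plus one shared power-of-two E8M0 scale per 32-element block. A standalone dequantization sketch, with the packing already undone for clarity (the fused kernel instead interleaves this with the reduction loop; the helper name is hypothetical):

```python
import numpy as np

# FP4 (E2M1) code points: sign bit + 3-bit magnitude index.
FP4_VALUES = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=np.float32)

def dequantize_mxfp4(codes, scale_exponents, block_size=32):
    """codes: uint8 array of 4-bit indices (one per element, already unpacked);
    scale_exponents: one E8M0 exponent per block of `block_size` elements.
    Each value decodes as fp4_value * 2**(exponent - 127)."""
    values = FP4_VALUES[codes]
    scales = np.float32(2.0) ** (scale_exponents.astype(np.int32) - 127)
    return values.reshape(-1, block_size) * scales[:, None]

codes = np.full(32, 2, dtype=np.uint8)          # every element encodes 1.0
out = dequantize_mxfp4(codes, np.array([128]))  # block scale = 2**(128-127) = 2
```

Fusing this into the einsum avoids materializing the full-precision weight tensor, which is what made the memory-limited high-level path fail in the first place.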

Gating Network Implementation with TIR

`run_gating` aggregates the outputs of the selected experts using `expert_weights`, the routing scores produced by the gating network.

Corresponding PyTorch operations

  1. Weighted sum of experts
    # Weighted sum of experts
    t = torch.einsum("bec,be->bc", t, expert_weights)

Implementation

  • Input shapes:
    • input_tensor: (seq_len, experts_per_token, out_features)
    • expert_weights: (seq_len, experts_per_token)
  • Output shape: (seq_len, out_features)
  • Operation: Computes sum(expert_output * expert_weights) over experts_per_token
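The operation above is exactly the `"bec,be->bc"` contraction from the PyTorch reference; a NumPy model of the kernel's semantics (sketch, not the shipped code):

```python
import numpy as np

def run_gating_reference(expert_out, expert_weights):
    """Weighted sum over the experts_per_token axis, accumulated in float32
    like the TIR kernel's gate_buffer."""
    return np.einsum("bec,be->bc",
                     expert_out.astype(np.float32),
                     expert_weights.astype(np.float32))

x = np.ones((2, 4, 3), dtype=np.float32)     # (seq, experts_per_token, out_features)
w = np.full((2, 4), 0.25, dtype=np.float32)  # uniform routing weights summing to 1
y = run_gating_reference(x, w)               # each output element is 1.0
```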
Code
    @staticmethod
    def _run_gating(
        input_tensor: Tensor,
        expert_weights: Tensor,
    ):
        seq_len, experts_per_token, out_features = input_tensor.shape
        dtype = "bfloat16"

        @T.prim_func(private=True)
        def _apply_gate(
            x_handle: T.handle,
            e_weights: T.Buffer((seq_len, experts_per_token), dtype),
            out_handle: T.handle
        ):
            T.func_attr({"op_pattern": 4, "tir.noalias": True})
            x = T.match_buffer(x_handle, (seq_len, experts_per_token, out_features), "float32")
            out = T.match_buffer(
                out_handle,
                (seq_len, out_features),
                dtype
            )

            for s, c in T.grid(seq_len, out_features):
                with T.block("apply_gate"):
                    seq_idx, out_idx = T.axis.remap("SS", [s, c])

                    gate_buffer = T.alloc_buffer((1, ), dtype="float32", scope="local")
                    gate_buffer[0] = T.cast(0.0, "float32")

                    for expert_rank in T.serial(experts_per_token):
                        gate_buffer[0] += (
                                T.cast(x[seq_idx, expert_rank, out_idx], "float32") *
                                T.cast(e_weights[seq_idx, expert_rank], "float32")
                        )

                    out[seq_idx, out_idx] = T.cast(gate_buffer[0], dtype)


        return op.tensor_ir_op(
            _apply_gate,
            "moe_gemv_gating",
            args=[input_tensor, expert_weights],
            out=Tensor.placeholder((seq_len, out_features), dtype),
        )

Comparing the Two

Both functions share a similar TIR structure but differ in their reduction axis:

|  | run_einsum | run_gating |
|---|---|---|
| Role | per-expert matmul | weighted aggregation |
| Outer loop | seq, experts_per_token, out_features | seq, out_features |
| Inner loop (reduction) | in_features | experts_per_token |
| Accumulator init | bias[e_idx, out_idx] | 0.0 |
| Core operation | input × weight[e_idx, out_idx, in_idx] | input × expert_weights[seq_idx, expert_rank] |
| Output shape | (seq, experts_per_token, out_features) | (seq, out_features) |

Comparison with Standard TVM Approaches

MLC-LLM provides MoE utilities (moe_matmul.py, moe_misc.py) with a different design:

|  | MLC-LLM | This implementation |
|---|---|---|
| Expert indexing | indptr array (cumsum-based) | Direct expert_indices |
| Batching | Group tokens by expert → batched GEMM | Per-token einsum |
| Dequantization | Separate dequantize_gemv | Fused into einsum |

MLC-LLM groups tokens assigned to the same expert together, enabling batched matrix operations. This implementation instead processes each token independently with direct expert indexing, prioritizing a faithful port of the gpt-oss reference.
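To make the contrast concrete, here is a sketch of the cumsum-based grouping MLC-LLM's indptr scheme enables: sort the flattened (token, expert) assignments by expert, then build an indptr so rows `order[indptr[e]:indptr[e+1]]` all belong to expert `e` and can feed one batched GEMM. This is a hypothetical helper for illustration, not MLC-LLM's actual API:

```python
import numpy as np

def group_tokens_by_expert(expert_indices, num_experts):
    """expert_indices: (seq_len, experts_per_token) int array.
    Returns (order, indptr): `order` permutes the flattened assignments so
    same-expert rows are contiguous; `indptr` is the cumulative-sum offsets."""
    flat = expert_indices.reshape(-1)
    order = np.argsort(flat, kind="stable")           # group assignments by expert
    counts = np.bincount(flat, minlength=num_experts) # tokens per expert
    indptr = np.concatenate([[0], np.cumsum(counts)]) # expert e owns [e, e+1)
    return order, indptr

idx = np.array([[0, 2], [2, 1]], dtype=np.int32)  # 2 tokens, top-2 of 3 experts
order, indptr = group_tokens_by_expert(idx, 3)
```

The einsum-style kernels on this page skip this step entirely: each `(seq_idx, expert_rank)` pair looks up its expert directly, which is simpler and closer to the gpt-oss reference, at the cost of the batched-GEMM opportunity.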

