
Commit 464847c

[#9717][chore] Standardize MoE weights interface (#10295)
Signed-off-by: Tal Cherckez <[email protected]>
1 parent ef1d4a4 commit 464847c

6 files changed: +556, -614 lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py

Lines changed: 14 additions & 74 deletions
@@ -102,98 +102,38 @@ def torch_moe(
     Unified Mixture-of-Experts (MoE) operator that uses a Mixtral-style dispatch
     (token routing + index_add_ accumulation) and a selectable per-expert MLP.
 
-    Supports both:
-    - Standard MoE with per-expert weight lists (apply_routing_on_input=False)
-    - Llama4 MoE with stacked weight tensors (apply_routing_on_input=True)
-
     Parameters:
         x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size,
             S is the sequence length, and H is the hidden size.
         selected_experts (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the indices
             of the selected experts for each token. Only experts within range [0,num_experts) is processed
         routing_weights (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the normalized
             routing weights for the selected experts.
-            - Standard MoE: softmax normalized weights
-            - Llama4 MoE: sigmoid activated weights
-        w1_weight:
-            For per-expert lists:
-              • is_gated_mlp==True: List of W1 with shape (I, H) — "gate" projection.
-              • is_gated_mlp==False: List of W_up with shape (I, H) — up projection.
-            For stacked tensors (Llama4):
-              • Single-element list containing stacked w3_w1 tensor with shape (E, 2*I, H) in TRT-LLM format
-        w2_weight:
-            For per-expert lists:
-              • List of W2/W_down with shape (H, I) — down projection.
-            For stacked tensors (Llama4):
-              • Single-element list containing stacked w2 tensor with shape (E, H, I) in TRT-LLM format
-        w3_weight:
-            For per-expert lists with gated_mlp:
-              • List of W3 with shape (I, H) — "up" (second) projection in gated MLP.
-            For is_gated_mlp==False or stacked tensors:
-              • pass an empty list []; ignored.
-        is_gated_mlp:
-            Selects the per-expert MLP computation:
-              • is_gated_mlp==True (default, Mixtral/DeepSeek/Llama4-style):
-                  y = W2( act(W1 x) * (W3 x) )
-              • is_gated_mlp==False (NemotronH-style 2-layer MLP):
-                  y = W_down( act(W_up x) )
-        act_fn:
-            Elementwise activation applied inside the expert MLP.
+        w1_weight: List of per-expert weight tensors of up projection.
+        w2_weight: List of per-expert weight tensors of down projection.
+        w3_weight: List of per-expert weight tensors of gate projection.
+        is_gated_mlp: If True, use a gated MLP. If False, use a simple MLP.
+        act_fn: Activation function applied inside the expert MLP.
             Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square).
-        apply_routing_on_input:
-            If True (Llama4 pattern): multiply routing weights with INPUT before MLP
-                Result: act(input * routing_weight) - routing affects activation
-            If False (standard pattern): multiply routing weights with OUTPUT after MLP
-                Result: act(input) * routing_weight - routing scales output
+        apply_routing_on_input: If True, multiply routing weights with INPUT before MLP
+            This means: silu(input * routing_weight)
+            If False, multiply routing weights with OUTPUT after MLP
+            This means: silu(input) * routing_weight
     Returns:
         torch.Tensor: Output tensor with the same shape as the input x.
     """
     torch_act_fn = _resolve_torch_fn(act_fn)
 
-    # Detect if using stacked tensor format (Llama4) vs per-expert lists (standard)
-    is_stacked = len(w1_weight) == 1 and w1_weight[0].ndim == 3
-
-    # Todo: either change torch_moe to use a single condition, or refactor this code.
-    # it should be :
-    # is_gated_mlp:
-    #   stacked:
-    #     ...
-    #   not stacked:
-    #     .
-    # else:
-    #   assert (not stacked)
-    #   ...
-    #   .
-    if is_stacked:
-        # Llama4 stacked tensor format - only supports gated_mlp
-        if not is_gated_mlp:
-            raise ValueError("Stacked tensor format only supports gated MLP style")
-
-        w3_w1_stacked = w1_weight[0]  # (E, 2*I, H)
-        intermediate_size = w3_w1_stacked.shape[1] // 2
-        w2_stacked = w2_weight[0]  # (E, H, I)
-
-        def make_mlp(idx: int):
-            gate_up = w3_w1_stacked[idx]  # (2*I, H)
-            W3 = gate_up[:intermediate_size, :]  # (I, H)
-            W1 = gate_up[intermediate_size:, :]  # (I, H)
-            W2 = w2_stacked[idx]  # (H, I)
-            weight_dtype = W1.dtype
-            return lambda inp: F.linear(
-                torch_act_fn(F.linear(inp.to(weight_dtype), W1))
-                * F.linear(inp.to(weight_dtype), W3),
-                W2,
-            )
-
-        mlps = [make_mlp(idx) for idx in range(w3_w1_stacked.shape[0])]
-
-    elif is_gated_mlp:
+    mlps = []
+    if is_gated_mlp:
         # Standard per-expert list format with gated MLP
         def make_mlp(i: int):
             W1 = w1_weight[i]  # (I, H)
             W2 = w2_weight[i]  # (H, I)
             W3 = w3_weight[i]  # (I, H)
-            return lambda inp: F.linear(torch_act_fn(F.linear(inp, W1)) * F.linear(inp, W3), W2)
+            return lambda inp: F.linear(
+                torch_act_fn(F.linear(inp.to(W1.dtype), W1)) * F.linear(inp.to(W3.dtype), W3), W2
+            )
 
         mlps = [make_mlp(i) for i in range(len(w1_weight))]
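For readers skimming the standardized interface, the sketch below is a minimal, self-contained illustration of the semantics the simplified docstring describes: per-expert weight lists, Mixtral-style dispatch with index_add_ accumulation, a gated MLP expert, and the two placements of the routing weights selected by apply_routing_on_input. The function name moe_reference, the SiLU activation, and the toy shapes are illustrative assumptions; this is not the op registered by torch_moe.py.

```python
# Illustrative reference only -- not the registered torch_moe custom op.
from typing import List

import torch
import torch.nn.functional as F


def moe_reference(
    x: torch.Tensor,                 # (T, H) flattened token activations
    selected_experts: torch.Tensor,  # (T, TOP_K) expert index per token/slot
    routing_weights: torch.Tensor,   # (T, TOP_K) normalized routing weights
    w1_weight: List[torch.Tensor],   # per-expert (I, H); activation is applied to this branch
    w2_weight: List[torch.Tensor],   # per-expert (H, I); down projection
    w3_weight: List[torch.Tensor],   # per-expert (I, H); multiplied with the activated branch
    apply_routing_on_input: bool = False,
) -> torch.Tensor:
    out = torch.zeros_like(x)
    for e in range(len(w1_weight)):
        # All (token, slot) pairs routed to expert e.
        token_idx, slot_idx = (selected_experts == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        inp = x[token_idx]
        scale = routing_weights[token_idx, slot_idx].unsqueeze(-1)
        if apply_routing_on_input:
            inp = inp * scale  # silu(input * routing_weight): routing scales the MLP input
        # Gated MLP expert: y = W2( act(W1 x) * (W3 x) )
        y = F.linear(
            F.silu(F.linear(inp, w1_weight[e])) * F.linear(inp, w3_weight[e]), w2_weight[e]
        )
        if not apply_routing_on_input:
            y = y * scale  # silu(input) * routing_weight: routing scales the expert output
        # Accumulate each expert's contribution back into its token rows.
        out.index_add_(0, token_idx, y)
    return out


# Toy usage with hypothetical sizes.
T, H, I, E, TOP_K = 8, 16, 32, 4, 2
x = torch.randn(T, H)
w1 = [torch.randn(I, H) for _ in range(E)]
w2 = [torch.randn(H, I) for _ in range(E)]
w3 = [torch.randn(I, H) for _ in range(E)]
routing_weights, selected_experts = torch.topk(torch.softmax(torch.randn(T, E), dim=-1), TOP_K)
y = moe_reference(x, selected_experts, routing_weights, w1, w2, w3)
assert y.shape == x.shape
```

Note that with a non-linear activation the two routing placements are not equivalent, which is why the docstring spells out both forms.

Since the removed branch accepted stacked Llama4-style tensors, a caller holding weights in that layout would now split them into per-expert lists before calling the op. The snippet below is a hypothetical adaptation, continuing the toy sizes above and using the split convention from the removed code (W3 in the first I rows, W1 in the last I rows of the stacked (E, 2*I, H) tensor).

```python
# Hypothetical conversion from stacked tensors to per-expert lists (assumed caller-side code).
w3_w1_stacked = torch.randn(E, 2 * I, H)  # (E, 2*I, H): [W3; W1] per expert
w2_stacked = torch.randn(E, H, I)         # (E, H, I)

w3_list = [t[:I] for t in w3_w1_stacked.unbind(0)]  # each (I, H)
w1_list = [t[I:] for t in w3_w1_stacked.unbind(0)]  # each (I, H)
w2_list = list(w2_stacked.unbind(0))                # each (H, I)
```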