Commit c4b36d3

[#10137][feat] AutoDeploy FP8 MoE refactor (#10138)
The trtllm (Cutlass) FP8 MoE operator performs the W3+W1 fusion (concatenation) during inference; we want to move this fusion to model-optimization time. The Cutlass MoE kernel is invoked through a trtllm torch operator whose implementation uses two FC operations (fc1 and fc2), while the canonical MoE API defines three GEMM operations and their associated weights (W1, W2, W3). Therefore, when switching from the torch.moe op to the trtllm.moe op, we also change terminology from w1, w2, w3 to fc1, fc2.

Signed-off-by: Neta Zmora <[email protected]>
1 parent 8614cd3 commit c4b36d3
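For context, the fusion being hoisted out of the forward pass is a per-expert concatenation of the w3 and w1 expert weights into a single FC1 weight tensor. A minimal sketch of the optimization-time step is shown below; the helper name fuse_w3_w1_for_trtllm_fc1 is hypothetical, but the concat order ([w3, w1] along dim=1) matches the runtime code removed in this diff.

    import torch

    def fuse_w3_w1_for_trtllm_fc1(w3_weight: torch.Tensor, w1_weight: torch.Tensor) -> torch.Tensor:
        """Fuse stacked per-expert w3 and w1 weights ([E, I, H] each) into the single
        fc1_expert_weights tensor ([E, 2*I, H]) consumed by the Cutlass FP8 MoE kernel.

        This runs once at model-optimization time, so the inference path no longer
        pays for the concat on every call.
        """
        # w3 rows first, then w1 rows, matching the [w3, w1] order the kernel expects.
        return torch.cat([w3_weight, w1_weight], dim=1).contiguous()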

File tree: 4 files changed, +241 −267 lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py

Lines changed: 48 additions & 62 deletions
@@ -107,72 +107,67 @@ def trtllm_quant_fp8_moe_fused(
     x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,
-    w1_weight: torch.Tensor,  # [E, I, H] stacked FP8 weights
-    w2_weight: torch.Tensor,  # [E, H, I] stacked FP8 weights
-    w3_weight: torch.Tensor,  # [E, I, H] for gated_mlp, unused for mlp
-    w1_input_scale: torch.Tensor,  # [E] stacked input scales
-    w2_input_scale: torch.Tensor,  # [E] stacked input scales
-    w3_input_scale: torch.Tensor,  # [E] or unused
-    w1_weight_scale: torch.Tensor,  # [E] stacked weight scales
-    w2_weight_scale: torch.Tensor,  # [E] stacked weight scales
-    w3_weight_scale: torch.Tensor,  # [E] or unused
-    gemm1_dequant: torch.Tensor,  # [E]
-    gemm2_act_quant: torch.Tensor,  # [E]
-    gemm2_dequant: torch.Tensor,  # [E]
+    fc1_expert_weights: torch.Tensor,
+    fc2_expert_weights: torch.Tensor,
+    fc1_act_scale: torch.Tensor,
+    fc1_dequant_scale: torch.Tensor,
+    fc2_act_scale_reciprocal: torch.Tensor,
+    fc2_dequant_scale: torch.Tensor,
     is_gated_mlp: bool = True,
     act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
-    """
-    TensorRT-LLM Cutlass FP8 W8A8 MoE for gated and non-gated MLP.
+    """TensorRT-LLM Cutlass FP8 (W8A8) MoE for gated and non-gated MLP.
+
+    Computes (per expert):
+        For gated_mlp:
+            y = (act(x @ w1.T) * (x @ w3.T)) @ w2.T  # act := SiLU
+        For mlp:
+            y = act(x @ w1.T) @ w2.T  # act := ReLU^2
+    Notes:
+        - FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T)
+        - FC2 implements: fc2_output = fc1_output @ w2.T
+        - FC1 weights are concatenated w3 and w1 if gated_mlp, otherwise w1
+
     Parameters:
         x: BF16/FP16 input tensor of shape (B, H) or (B, S, H)
         selected_experts: Expert indices (B*S, TOP_K)
         routing_weights: Routing weights (B*S, TOP_K)
-        w1_weight: FP8 w1 weights [E, I, H]
-        w2_weight: FP8 w2 weights [E, H, I]
-        w3_weight: FP8 w3 weights [E, I, H] (for gated_mlp)
-        w1_input_scale: Input scales for w1 [E]
-        w2_input_scale: Input scales for w2 [E]
-        w3_input_scale: Input scales for w3 [E]
-        w1_weight_scale: Weight scales for w1 [E]
-        w2_weight_scale: Weight scales for w2 [E]
-        w3_weight_scale: Weight scales for w3 [E]
-        gemm1_dequant: Precomputed gemm1 dequant scale [E]
-        gemm2_act_quant: Precomputed gemm2 act quant scale [1]
-        gemm2_dequant: Precomputed gemm2 dequant scale [E]
+        fc1_expert_weights: FC1 weights [E, 2*I, H] for gated_mlp, [E, I, H] for mlp
+        fc2_expert_weights: FC2 weights [E, H, I]
+        fc1_act_scale: FC1 activation scale [E]
+        fc1_dequant_scale: FC1 dequant scale [E]
+        fc2_act_scale_reciprocal: FC2 activation scale reciprocal [E]
+        fc2_dequant_scale: FC2 dequant scale [E]
         is_gated_mlp: True for gated_mlp, False for mlp
         act_fn: ActivationType.Silu for gated_mlp, ActivationType.Relu2 for mlp
 
-    Non-Gated MLP:
-        activation_fn(expert_inputs @ w1_expert.t()) @ w2_expert.t()
-
-    Gated MLP:
-        activation_fn(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) @ w2_expert.t()
+    Returns:
+        Output tensor of shape (B, H) or (B, S, H)
     """
 
     _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
 
     # Store original shape and flatten to 2D
     x_shape = x.shape
     x2d = x.view(-1, x_shape[-1])
-    # Quantize input
-    x_q_fp8 = _quantize_fp8(x2d, w1_input_scale[0])
+    # Quantize the input
+    x_q_fp8 = _quantize_fp8(x2d, fc1_act_scale[0])
 
     # Scales are stored in float32
-    w1_input_scale = w1_input_scale[0]
+    w1_input_scale = fc1_act_scale[0]
 
-    # Prepare quant_scales for TensorRT-LLM FP8 format:
-    # [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale]
+    # Prepare quant_scales for TensorRT-LLM (Cutlass) FP8 format:
+    # [fc1_dequant_scale, fc2_act_scale_reciprocal, fc2_dequant_scale, gemm1_input_dequant_scale]
     # For gated MLP:
     # These are precomputed in `fused_moe` transform
-    # - gemm1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3)
-    # - gemm2_act_quant_scale: 1 / w2_input_scale
-    # - gemm2_dequant_scale: w2_weight_scale * w2_input_scale
-    # - gemm1_input_dequant_scale: w1_input_scale
+    # - fc1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3)
+    # - fc2_act_scale_reciprocal: 1 / w2_input_scale
+    # - fc2_dequant_scale: w2_weight_scale * w2_input_scale
+    # - fc1_act_scale: w1_input_scale
 
-    assert gemm1_dequant.ndim == 1, "gemm1_dequant must be 1D"
-    assert gemm2_dequant.ndim == 1, "gemm2_dequant must be 1D"
-    quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, w1_input_scale]
+    assert fc1_dequant_scale.ndim == 1, "fc1_dequant_scale must be 1D"
+    assert fc2_dequant_scale.ndim == 1, "fc2_dequant_scale must be 1D"
+    quant_scales = [fc1_dequant_scale, fc2_act_scale_reciprocal, fc2_dequant_scale, w1_input_scale]
 
     # Ensure contiguous tensors
     selected_experts = selected_experts.int().contiguous()
@@ -182,11 +177,9 @@ def trtllm_quant_fp8_moe_fused(
 
     # Determine activation type
     activation_type = ActivationType.Swiglu
+
     if is_gated_mlp:
         # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
-        # For gated MLP, concatenate w1 and w3 as [w3, w1]
-        w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous()  # [E, 2*I, H]
-        fc1_expert_weights = w3_w1_stacked
         if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
             activation_type = ActivationType.Swiglu
         else:
@@ -195,7 +188,6 @@ def trtllm_quant_fp8_moe_fused(
             )
     else:
         # For non-gated MLP with ReLU^2
-        fc1_expert_weights = w1_weight.contiguous()
         if act_fn == ActivationType.Relu2:
             activation_type = ActivationType.Relu2
         else:
@@ -210,7 +202,7 @@ def trtllm_quant_fp8_moe_fused(
         routing_weights,
         fc1_expert_weights=fc1_expert_weights,
         fc1_expert_biases=None,
-        fc2_expert_weights=w2_weight.contiguous(),
+        fc2_expert_weights=fc2_expert_weights.contiguous(),
         fc2_expert_biases=None,
         output_dtype=x.dtype,
         quant_scales=quant_scales,
@@ -226,20 +218,14 @@ def trtllm_quant_fp8_moe_fused_fake(
     x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,
-    w1_weight: torch.Tensor,
-    w2_weight: torch.Tensor,
-    w3_weight: torch.Tensor,
-    w1_input_scale: torch.Tensor,
-    w2_input_scale: torch.Tensor,
-    w3_input_scale: torch.Tensor,
-    w1_weight_scale: torch.Tensor,
-    w2_weight_scale: torch.Tensor,
-    w3_weight_scale: torch.Tensor,
-    gemm1_dequant: torch.Tensor,
-    gemm2_act_quant: torch.Tensor,
-    gemm2_dequant: torch.Tensor,
-    is_gated_mlp: bool,
-    act_fn: int,
+    fc1_expert_weights: torch.Tensor,
+    fc2_expert_weights: torch.Tensor,
+    fc1_act_scale: torch.Tensor,
+    fc1_dequant_scale: torch.Tensor,
+    fc2_act_scale_reciprocal: torch.Tensor,
+    fc2_dequant_scale: torch.Tensor,
+    is_gated_mlp: bool = True,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
     return torch.empty_like(x)
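To make the new scale arguments concrete, the sketch below shows how the four precomputed scales relate to the original per-expert scales, following the comment block in the diff above. The helper name precompute_trtllm_fp8_moe_scales is hypothetical; in the commit these values are produced by the `fused_moe` transform (not shown in this file).

    import torch

    def precompute_trtllm_fp8_moe_scales(
        w1_input_scale: torch.Tensor,   # [E]
        w2_input_scale: torch.Tensor,   # [E]
        w1_weight_scale: torch.Tensor,  # [E], shared by w1 and w3 once fused into fc1
        w2_weight_scale: torch.Tensor,  # [E]
    ):
        """Derive the scales passed to trtllm_quant_fp8_moe_fused (illustrative only)."""
        fc1_act_scale = w1_input_scale                        # quantizes the FC1 input
        fc1_dequant_scale = w1_weight_scale * w1_input_scale  # dequantizes the FC1 output
        fc2_act_scale_reciprocal = 1.0 / w2_input_scale       # requantizes the FC2 input
        fc2_dequant_scale = w2_weight_scale * w2_input_scale  # dequantizes the FC2 output
        return fc1_act_scale, fc1_dequant_scale, fc2_act_scale_reciprocal, fc2_dequant_scale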
