@@ -107,72 +107,67 @@ def trtllm_quant_fp8_moe_fused(
     x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,
-    w1_weight: torch.Tensor,  # [E, I, H] stacked FP8 weights
-    w2_weight: torch.Tensor,  # [E, H, I] stacked FP8 weights
-    w3_weight: torch.Tensor,  # [E, I, H] for gated_mlp, unused for mlp
-    w1_input_scale: torch.Tensor,  # [E] stacked input scales
-    w2_input_scale: torch.Tensor,  # [E] stacked input scales
-    w3_input_scale: torch.Tensor,  # [E] or unused
-    w1_weight_scale: torch.Tensor,  # [E] stacked weight scales
-    w2_weight_scale: torch.Tensor,  # [E] stacked weight scales
-    w3_weight_scale: torch.Tensor,  # [E] or unused
-    gemm1_dequant: torch.Tensor,  # [E]
-    gemm2_act_quant: torch.Tensor,  # [E]
-    gemm2_dequant: torch.Tensor,  # [E]
+    fc1_expert_weights: torch.Tensor,
+    fc2_expert_weights: torch.Tensor,
+    fc1_act_scale: torch.Tensor,
+    fc1_dequant_scale: torch.Tensor,
+    fc2_act_scale_reciprocal: torch.Tensor,
+    fc2_dequant_scale: torch.Tensor,
     is_gated_mlp: bool = True,
     act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
125- """
126- TensorRT-LLM Cutlass FP8 W8A8 MoE for gated and non-gated MLP.
119+ """TensorRT-LLM Cutlass FP8 (W8A8) MoE for gated and non-gated MLP.
120+
121+ Computes (per expert):
122+ For gated_mlp:
123+ y = (act(x @ w1.T) * (x @ w3.T)) @ w2.T # act := SiLU
124+ For mlp:
125+ y = act(x @ w1.T) @ w2.T # act := ReLU^2
126+ Notes:
127+ - FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T)
128+ - FC2 implements: fc2_output = fc1_output @ w2.T
+        - FC1 weights are the concatenation [w3, w1] for gated_mlp, otherwise just w1
+
     Parameters:
         x: BF16/FP16 input tensor of shape (B, H) or (B, S, H)
         selected_experts: Expert indices (B*S, TOP_K)
         routing_weights: Routing weights (B*S, TOP_K)
-        w1_weight: FP8 w1 weights [E, I, H]
-        w2_weight: FP8 w2 weights [E, H, I]
-        w3_weight: FP8 w3 weights [E, I, H] (for gated_mlp)
-        w1_input_scale: Input scales for w1 [E]
-        w2_input_scale: Input scales for w2 [E]
-        w3_input_scale: Input scales for w3 [E]
-        w1_weight_scale: Weight scales for w1 [E]
-        w2_weight_scale: Weight scales for w2 [E]
-        w3_weight_scale: Weight scales for w3 [E]
-        gemm1_dequant: Precomputed gemm1 dequant scale [E]
-        gemm2_act_quant: Precomputed gemm2 act quant scale [1]
-        gemm2_dequant: Precomputed gemm2 dequant scale [E]
+        fc1_expert_weights: FC1 weights [E, 2*I, H] for gated_mlp, [E, I, H] for mlp
+        fc2_expert_weights: FC2 weights [E, H, I]
+        fc1_act_scale: FC1 activation scale [E]
+        fc1_dequant_scale: FC1 dequant scale [E]
+        fc2_act_scale_reciprocal: FC2 activation scale reciprocal [E]
+        fc2_dequant_scale: FC2 dequant scale [E]
         is_gated_mlp: True for gated_mlp, False for mlp
         act_fn: ActivationType.Silu for gated_mlp, ActivationType.Relu2 for mlp
 
-    Non-Gated MLP:
-        activation_fn(expert_inputs @ w1_expert.t())@ w2_expert.t()
-
-    Gated MLP:
-        activation_fn(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) @ w2_expert.t()
+    Returns:
+        Output tensor of shape (B, H) or (B, S, H)
     """
 
     _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
 
     # Store original shape and flatten to 2D
     x_shape = x.shape
     x2d = x.view(-1, x_shape[-1])
-    # Quantize input
-    x_q_fp8 = _quantize_fp8(x2d, w1_input_scale[0])
+    # Quantize the input
+    x_q_fp8 = _quantize_fp8(x2d, fc1_act_scale[0])
 
     # Scales are stored in float32
-    w1_input_scale = w1_input_scale[0]
+    w1_input_scale = fc1_act_scale[0]
 
-    # Prepare quant_scales for TensorRT-LLM FP8 format:
-    # [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale]
+    # Prepare quant_scales for TensorRT-LLM (Cutlass) FP8 format:
+    # [fc1_dequant_scale, fc2_act_scale_reciprocal, fc2_dequant_scale, gemm1_input_dequant_scale]
     # For gated MLP:
     # These are precomputed in `fused_moe` transform
-    # - gemm1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3)
-    # - gemm2_act_quant_scale: 1 / w2_input_scale
-    # - gemm2_dequant_scale: w2_weight_scale * w2_input_scale
-    # - gemm1_input_dequant_scale: w1_input_scale
+    # - fc1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3)
+    # - fc2_act_scale_reciprocal: 1 / w2_input_scale
+    # - fc2_dequant_scale: w2_weight_scale * w2_input_scale
+    # - fc1_act_scale: w1_input_scale
 
-    assert gemm1_dequant.ndim == 1, "gemm1_dequant must be 1D"
-    assert gemm2_dequant.ndim == 1, "gemm2_dequant must be 1D"
-    quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, w1_input_scale]
+    assert fc1_dequant_scale.ndim == 1, "fc1_dequant_scale must be 1D"
+    assert fc2_dequant_scale.ndim == 1, "fc2_dequant_scale must be 1D"
+    quant_scales = [fc1_dequant_scale, fc2_act_scale_reciprocal, fc2_dequant_scale, w1_input_scale]
 
     # Ensure contiguous tensors
     selected_experts = selected_experts.int().contiguous()
@@ -182,11 +177,9 @@ def trtllm_quant_fp8_moe_fused(
 
     # Determine activation type
     activation_type = ActivationType.Swiglu
+
     if is_gated_mlp:
         # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
-        # For gated MLP, concatenate w1 and w3 as [w3, w1]
-        w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous()  # [E, 2*I, H]
-        fc1_expert_weights = w3_w1_stacked
         if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
             activation_type = ActivationType.Swiglu
         else:
@@ -195,7 +188,6 @@ def trtllm_quant_fp8_moe_fused(
             )
     else:
         # For non-gated MLP with ReLU^2
-        fc1_expert_weights = w1_weight.contiguous()
         if act_fn == ActivationType.Relu2:
             activation_type = ActivationType.Relu2
         else:
@@ -210,7 +202,7 @@ def trtllm_quant_fp8_moe_fused(
         routing_weights,
         fc1_expert_weights=fc1_expert_weights,
         fc1_expert_biases=None,
-        fc2_expert_weights=w2_weight.contiguous(),
+        fc2_expert_weights=fc2_expert_weights.contiguous(),
         fc2_expert_biases=None,
         output_dtype=x.dtype,
         quant_scales=quant_scales,
@@ -226,20 +218,14 @@ def trtllm_quant_fp8_moe_fused_fake(
     x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,
-    w1_weight: torch.Tensor,
-    w2_weight: torch.Tensor,
-    w3_weight: torch.Tensor,
-    w1_input_scale: torch.Tensor,
-    w2_input_scale: torch.Tensor,
-    w3_input_scale: torch.Tensor,
-    w1_weight_scale: torch.Tensor,
-    w2_weight_scale: torch.Tensor,
-    w3_weight_scale: torch.Tensor,
-    gemm1_dequant: torch.Tensor,
-    gemm2_act_quant: torch.Tensor,
-    gemm2_dequant: torch.Tensor,
-    is_gated_mlp: bool,
-    act_fn: int,
+    fc1_expert_weights: torch.Tensor,
+    fc2_expert_weights: torch.Tensor,
+    fc1_act_scale: torch.Tensor,
+    fc1_dequant_scale: torch.Tensor,
+    fc2_act_scale_reciprocal: torch.Tensor,
+    fc2_dequant_scale: torch.Tensor,
+    is_gated_mlp: bool = True,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
     return torch.empty_like(x)
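
With this change the op expects the caller to supply the fused FC1/FC2 tensors and the precomputed scales. Below is a minimal sketch of that preparation, assuming per-expert FP8 weights and float32 scales shaped as in the previous signature; the helper name is hypothetical, and the layout and formulas follow the removed concatenation code and the quant_scales comments in the diff above.

# Hedged sketch (not part of this commit): assemble the fused arguments for
# trtllm_quant_fp8_moe_fused from per-expert tensors. Helper name is hypothetical.
import torch

def build_trtllm_fp8_moe_args(
    w1_weight: torch.Tensor,        # [E, I, H] FP8
    w2_weight: torch.Tensor,        # [E, H, I] FP8
    w3_weight: torch.Tensor,        # [E, I, H] FP8 (gated_mlp only)
    w1_input_scale: torch.Tensor,   # [E] float32
    w2_input_scale: torch.Tensor,   # [E] float32
    w1_weight_scale: torch.Tensor,  # [E] float32
    w2_weight_scale: torch.Tensor,  # [E] float32
    is_gated_mlp: bool = True,
):
    # FC1 weights: [w3, w1] concatenated along the intermediate dim for gated MLP,
    # plain w1 otherwise (this concatenation used to live inside the op).
    if is_gated_mlp:
        fc1_expert_weights = torch.cat([w3_weight, w1_weight], dim=1).contiguous()  # [E, 2*I, H]
    else:
        fc1_expert_weights = w1_weight.contiguous()  # [E, I, H]
    fc2_expert_weights = w2_weight.contiguous()  # [E, H, I]

    # Precomputed scales, per the quant_scales comments above.
    fc1_act_scale = w1_input_scale                        # [E]
    fc1_dequant_scale = w1_weight_scale * w1_input_scale  # [E]
    fc2_act_scale_reciprocal = 1.0 / w2_input_scale       # [E]
    fc2_dequant_scale = w2_weight_scale * w2_input_scale  # [E]
    return (
        fc1_expert_weights,
        fc2_expert_weights,
        fc1_act_scale,
        fc1_dequant_scale,
        fc2_act_scale_reciprocal,
        fc2_dequant_scale,
    )

The returned tuple can then be unpacked into trtllm_quant_fp8_moe_fused together with x, selected_experts, routing_weights, is_gated_mlp, and act_fn.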