@@ -26,8 +26,8 @@ def trtllm_moe_fused(
     routing_weights: torch.Tensor,
     w3_w1_stacked_weight: torch.Tensor,
     w2_stacked_weight: torch.Tensor,
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    is_gated_mlp: bool = True,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     x_shape = x.shape
     x = x.view(-1, x_shape[-1])
@@ -37,24 +37,24 @@ def trtllm_moe_fused(
     quant_scales = []

     # Determine activation type
-    mlp_style = mlp_style.lower()
-    act_fn = act_fn.lower()

     activation_type = ActivationType.Swiglu
-    if mlp_style == "gated_mlp":
+    if is_gated_mlp:
         # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
-        if act_fn == "silu":
+        if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
             activation_type = ActivationType.Swiglu
         else:
-            raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
-    elif mlp_style == "mlp":
+            raise ValueError(
+                f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
+            )
+    else:
         # For non-gated MLP with ReLU^2
-        if act_fn == "relu2":
+        if act_fn == ActivationType.Relu2:
             activation_type = ActivationType.Relu2
         else:
-            raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
-    else:
-        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+            raise ValueError(
+                f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
+            )

     return torch.ops.trtllm.fused_moe(
         x,
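
For readers tracking the new interface: a minimal, self-contained sketch of the mapping this branch implements. The `ActivationType` enum below is a hypothetical stand-in (the real enum ships with TensorRT-LLM and its member values may differ); only the selection logic mirrors the diff.

```python
from enum import IntEnum


class ActivationType(IntEnum):
    # Hypothetical stand-in for the real TensorRT-LLM enum; member values are illustrative.
    Swiglu = 0
    Relu2 = 1
    Silu = 2


def resolve_activation(is_gated_mlp: bool, act_fn: int) -> ActivationType:
    """Mirror of the branch above: gated MLP -> Swiglu, non-gated MLP -> Relu2."""
    if is_gated_mlp:
        if act_fn in (ActivationType.Silu, ActivationType.Swiglu):
            return ActivationType.Swiglu  # silu(x @ w1.T) * (x @ w3.T)
        raise ValueError(f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp.")
    if act_fn == ActivationType.Relu2:
        return ActivationType.Relu2  # relu(x @ w1.T) ** 2
    raise ValueError(f"Unsupported activation '{ActivationType(act_fn).name}' for mlp.")


assert resolve_activation(True, int(ActivationType.Silu)) is ActivationType.Swiglu
assert resolve_activation(False, int(ActivationType.Relu2)) is ActivationType.Relu2
```
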
@@ -77,8 +77,8 @@ def trtllm_moe_fused_fake(
     routing_weights: torch.Tensor,
     w3_w1_stacked_weight: torch.Tensor,
     w2_stacked_weight: torch.Tensor,
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    is_gated_mlp: bool = True,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     return torch.empty_like(x)

@@ -93,21 +93,12 @@ def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)


-def _validate_mlp_style_and_act_fn(mlp_style: str, act_fn: str) -> None:
-    supported_combinations = {
-        "gated_mlp": ["silu"],
-        "mlp": ["relu2"],
-    }
-    supported_act_fns = [
-        act_fn for act_fn_list in supported_combinations.values() for act_fn in act_fn_list
-    ]
-    assert mlp_style in supported_combinations.keys(), (
-        f"Unknown mlp_style '{mlp_style}'. Use {supported_combinations.keys()}."
-    )
-    assert act_fn in supported_act_fns, f"Unknown act_fn '{act_fn}'. Use {supported_act_fns}."
-    assert act_fn in supported_combinations[mlp_style], (
-        f"Unsupported combination: mlp_style='{mlp_style}', act_fn='{act_fn}'. "
-        f"Supported combinations: {supported_combinations}"
+def _validate_mlp_style_and_act_fn(is_gated_mlp: bool, act_fn: int) -> None:
+    assert (is_gated_mlp and act_fn == ActivationType.Silu) or (
+        not is_gated_mlp and act_fn == ActivationType.Relu2
+    ), (
+        f"Unsupported combination: is_gated_mlp='{is_gated_mlp}', act_fn='{act_fn}'. "
+        f"Supported combinations: gated mlp with silu or mlp with relu2."
     )


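
Since `_quantize_fp8` is the only quantization helper visible in this hunk, here is a hedged sketch of the per-tensor FP8 round trip it performs. The amax-based scale is an illustrative choice, and the `FP8_MIN`/`FP8_MAX` values are assumed to match `torch.finfo` for e4m3; the real module defines its own constants.

```python
import torch

# Representable range of float8_e4m3fn (assumed to match the module's FP8_MIN/FP8_MAX).
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
FP8_MIN = torch.finfo(torch.float8_e4m3fn).min


def quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Same operation as `_quantize_fp8` above: scale, clamp to the e4m3 range, cast.
    return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)


x = torch.randn(4, 8)
scale = x.abs().amax() / FP8_MAX                         # illustrative per-tensor amax scale
x_dq = quantize_fp8(x, scale).to(torch.float32) * scale  # dequantize for comparison
print((x - x_dq).abs().max())                            # small quantization error
```

Note that the rewritten `_validate_mlp_style_and_act_fn` now admits exactly two combinations: `is_gated_mlp=True` with `ActivationType.Silu`, or `is_gated_mlp=False` with `ActivationType.Relu2`.
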
@@ -128,8 +119,8 @@ def trtllm_quant_fp8_moe_fused(
     gemm1_dequant: torch.Tensor,  # [E]
     gemm2_act_quant: torch.Tensor,  # [E]
     gemm2_dequant: torch.Tensor,  # [E]
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    is_gated_mlp: bool = True,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     """
     TensorRT-LLM Cutlass FP8 W8A8 MoE for gated and non-gated MLP.
@@ -149,8 +140,8 @@ def trtllm_quant_fp8_moe_fused(
         gemm1_dequant: Precomputed gemm1 dequant scale [E]
         gemm2_act_quant: Precomputed gemm2 act quant scale [1]
         gemm2_dequant: Precomputed gemm2 dequant scale [E]
-        mlp_style: "gated_mlp" or "mlp"
-        act_fn: "silu" for gated_mlp, "relu2" for mlp
+        is_gated_mlp: True for gated_mlp, False for mlp
+        act_fn: ActivationType.Silu for gated_mlp, ActivationType.Relu2 for mlp

     Non-Gated MLP:
         activation_fn(expert_inputs @ w1_expert.t()) @ w2_expert.t()
@@ -159,7 +150,7 @@ def trtllm_quant_fp8_moe_fused(
         activation_fn(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) @ w2_expert.t()
     """

-    _validate_mlp_style_and_act_fn(mlp_style, act_fn)
+    _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)

     # Store original shape and flatten to 2D
     x_shape = x.shape
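
To make the docstring formulas concrete, a hedged, unfused single-expert reference in plain PyTorch follows. It sketches the math only (no routing, no FP8 quantization), and the weight shapes are assumptions chosen to be consistent with the `x @ w.t()` convention above.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def expert_mlp_reference(
    x: torch.Tensor,             # [tokens, H]
    w1: torch.Tensor,            # [I, H]
    w2: torch.Tensor,            # [H, I]
    w3: Optional[torch.Tensor],  # [I, H] for gated MLP, None otherwise
    is_gated_mlp: bool,
) -> torch.Tensor:
    if is_gated_mlp:
        # Gated MLP: silu(x @ w1.T) * (x @ w3.T), then project down with w2.
        return (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()
    # Non-gated MLP with ReLU^2: relu(x @ w1.T) ** 2, then project down with w2.
    return F.relu(x @ w1.t()).pow(2) @ w2.t()
```
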
@@ -190,28 +181,27 @@ def trtllm_quant_fp8_moe_fused(
     # Todo: refactor this repeating code block

     # Determine activation type
-    mlp_style = mlp_style.lower()
-    act_fn = act_fn.lower()
-
     activation_type = ActivationType.Swiglu
-    if mlp_style == "gated_mlp":
+    if is_gated_mlp:
         # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
         # For gated MLP, concatenate w1 and w3 as [w3, w1]
         w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous()  # [E, 2*I, H]
         fc1_expert_weights = w3_w1_stacked
-        if act_fn == "silu":
+        if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
             activation_type = ActivationType.Swiglu
         else:
-            raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
-    elif mlp_style == "mlp":
+            raise ValueError(
+                f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
+            )
+    else:
         # For non-gated MLP with ReLU^2
         fc1_expert_weights = w1_weight.contiguous()
-        if act_fn == "relu2":
+        if act_fn == ActivationType.Relu2:
             activation_type = ActivationType.Relu2
         else:
-            raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
-    else:
-        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+            raise ValueError(
+                f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
+            )

     # Note! Outputting Float8_e4m3fn directly is not currently supported
     output = torch.ops.trtllm.fused_moe(
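
The `[w3, w1]` concatenation above fixes the FC1 layout the fused kernel consumes on the gated path. A tiny shape check with hypothetical sizes illustrates the two layouts:

```python
import torch

E, I, H = 4, 32, 64  # experts, intermediate size, hidden size (illustrative only)
w1_weight = torch.randn(E, I, H)
w3_weight = torch.randn(E, I, H)

# Gated path: FC1 carries w3 and w1 stacked along the intermediate dimension.
w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous()
assert w3_w1_stacked.shape == (E, 2 * I, H)

# Non-gated path: FC1 is just w1.
fc1_expert_weights = w1_weight.contiguous()
assert fc1_expert_weights.shape == (E, I, H)
```
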
@@ -248,10 +238,10 @@ def trtllm_quant_fp8_moe_fused_fake(
     gemm1_dequant: torch.Tensor,
     gemm2_act_quant: torch.Tensor,
     gemm2_dequant: torch.Tensor,
-    mlp_style: str,
-    act_fn: str,
+    is_gated_mlp: bool,
+    act_fn: int,
 ) -> torch.Tensor:
-    _validate_mlp_style_and_act_fn(mlp_style, act_fn)
+    _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
     return torch.empty_like(x)


@@ -268,8 +258,8 @@ def trtllm_quant_nvfp4_moe_fused(
     fc2_act_global_scale: torch.Tensor,  # Global scale for FC2 activations
     fc1_alpha: torch.Tensor,  # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8))
     fc2_alpha: torch.Tensor,  # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8))
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    is_gated_mlp: bool = True,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     """TensorRT-LLM Cutlass NVFP4 W8A8 MoE for gated and non-gated MLP.

@@ -285,22 +275,22 @@ def trtllm_quant_nvfp4_moe_fused(

     """
     NVFP4_BLOCK_SIZE = 16
-    mlp_style = mlp_style.lower()
-    act_fn = act_fn.lower()

     activation_type = ActivationType.Swiglu
-    if mlp_style == "gated_mlp":
-        if act_fn == "silu":
+    if is_gated_mlp:
+        if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
             activation_type = ActivationType.Swiglu
         else:
-            raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
-    elif mlp_style == "mlp":
-        if act_fn == "relu2":
+            raise ValueError(
+                f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
+            )
+    else:
+        if act_fn == ActivationType.Relu2:
             activation_type = ActivationType.Relu2
         else:
-            raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
-    else:
-        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+            raise ValueError(
+                f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
+            )

     # quant_scales is described by this code:
     # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
@@ -353,7 +343,7 @@ def trtllm_quant_nvfp4_moe_fused_fake(
     fc2_act_global_scale: torch.Tensor,
     fc1_alpha: torch.Tensor,
     fc2_alpha: torch.Tensor,
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    is_gated_mlp: bool = True,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     return torch.empty_like(x)
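
Across every op touched here, the call-site change is the same pair of keyword arguments. A hedged before/after sketch (all other arguments elided; `some_moe_op` is a placeholder, and the `ActivationType` import path is whatever the real module exposes):

```python
# Before: string-valued style and activation
# some_moe_op(..., mlp_style="gated_mlp", act_fn="silu")
# some_moe_op(..., mlp_style="mlp", act_fn="relu2")

# After: boolean style flag plus integer ActivationType
# some_moe_op(..., is_gated_mlp=True, act_fn=int(ActivationType.Silu))
# some_moe_op(..., is_gated_mlp=False, act_fn=int(ActivationType.Relu2))
```
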