@@ -94,7 +94,7 @@ def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
 
 
 def _validate_mlp_style_and_act_fn(is_gated_mlp: bool, act_fn: int) -> None:
-    assert (is_gated_mlp and act_fn == ActivationType.Silu) or (
+    assert (is_gated_mlp and act_fn in [ActivationType.Silu, ActivationType.Swiglu]) or (
         not is_gated_mlp and act_fn == ActivationType.Relu2
     ), (
         f"Unsupported combination: is_gated_mlp='{is_gated_mlp}', act_fn='{act_fn}'. "
@@ -146,6 +146,7 @@ def trtllm_quant_fp8_moe_fused(
     """
 
     _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
+    act_fn = ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn
 
     # Store original shape and flatten to 2D
     x_shape = x.shape
@@ -173,28 +174,6 @@ def trtllm_quant_fp8_moe_fused(
     selected_experts = selected_experts.int().contiguous()
     routing_weights = routing_weights.contiguous()
 
-    # Todo: refactor this repeating code block
-
-    # Determine activation type
-    activation_type = ActivationType.Swiglu
-
-    if is_gated_mlp:
-        # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
-        if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
-            activation_type = ActivationType.Swiglu
-        else:
-            raise ValueError(
-                f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
-            )
-    else:
-        # For non-gated MLP with ReLU^2
-        if act_fn == ActivationType.Relu2:
-            activation_type = ActivationType.Relu2
-        else:
-            raise ValueError(
-                f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
-            )
-
     # Note! Outputting Float8_e4m3fn directly is not currently supported
     output = torch.ops.trtllm.fused_moe(
         x_q_fp8,
@@ -206,7 +185,7 @@ def trtllm_quant_fp8_moe_fused(
         fc2_expert_biases=None,
         output_dtype=x.dtype,
         quant_scales=quant_scales,
-        activation_type=activation_type,
+        activation_type=act_fn,
     )
 
     # Restore original shape
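
For reference, a minimal, self-contained sketch of the validation-and-remap behavior this diff introduces: gated MLPs accept Silu or Swiglu and are forwarded to the fused kernel as Swiglu, while non-gated MLPs accept only Relu2. The ActivationType enum below is a stand-in with assumed member values (the real enum ships with TensorRT-LLM), and resolve_activation is a hypothetical helper added only for illustration, not part of this PR.

from enum import IntEnum


class ActivationType(IntEnum):
    # Stand-in enum; member values are assumed and may differ from TensorRT-LLM's.
    Relu2 = 1
    Silu = 2
    Swiglu = 3


def _validate_mlp_style_and_act_fn(is_gated_mlp: bool, act_fn: int) -> None:
    # Gated MLP accepts Silu or Swiglu; non-gated MLP accepts only Relu2.
    assert (is_gated_mlp and act_fn in [ActivationType.Silu, ActivationType.Swiglu]) or (
        not is_gated_mlp and act_fn == ActivationType.Relu2
    ), f"Unsupported combination: is_gated_mlp='{is_gated_mlp}', act_fn='{act_fn}'."


def resolve_activation(is_gated_mlp: bool, act_fn: int) -> ActivationType:
    # Hypothetical helper: validate, then map Silu to Swiglu so the fused
    # kernel always receives a single gated activation type.
    _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
    return ActivationType.Swiglu if act_fn == ActivationType.Silu else ActivationType(act_fn)


# Silu on a gated MLP is forwarded to the kernel as Swiglu.
assert resolve_activation(True, ActivationType.Silu) == ActivationType.Swiglu
# Relu2 on a non-gated MLP passes through unchanged.
assert resolve_activation(False, ActivationType.Relu2) == ActivationType.Relu2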