
Commit f3f0231

[None][chore]: small refactoring to auto-deploy MoE operator (#10300)
Signed-off-by: Neta Zmora <[email protected]>
1 parent db3430f commit f3f0231

2 files changed: +6 lines, -25 lines

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py

Lines changed: 3 additions & 24 deletions
@@ -94,7 +94,7 @@ def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
 
 
 def _validate_mlp_style_and_act_fn(is_gated_mlp: bool, act_fn: int) -> None:
-    assert (is_gated_mlp and act_fn == ActivationType.Silu) or (
+    assert (is_gated_mlp and act_fn in [ActivationType.Silu, ActivationType.Swiglu]) or (
         not is_gated_mlp and act_fn == ActivationType.Relu2
     ), (
         f"Unsupported combination: is_gated_mlp='{is_gated_mlp}', act_fn='{act_fn}'. "
@@ -146,6 +146,7 @@ def trtllm_quant_fp8_moe_fused(
     """
 
     _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
+    act_fn = ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn
 
     # Store original shape and flatten to 2D
     x_shape = x.shape
@@ -173,28 +174,6 @@ def trtllm_quant_fp8_moe_fused(
     selected_experts = selected_experts.int().contiguous()
     routing_weights = routing_weights.contiguous()
 
-    # Todo: refactor this repeating code block
-
-    # Determine activation type
-    activation_type = ActivationType.Swiglu
-
-    if is_gated_mlp:
-        # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
-        if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
-            activation_type = ActivationType.Swiglu
-        else:
-            raise ValueError(
-                f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
-            )
-    else:
-        # For non-gated MLP with ReLU^2
-        if act_fn == ActivationType.Relu2:
-            activation_type = ActivationType.Relu2
-        else:
-            raise ValueError(
-                f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
-            )
-
     # Note! Outputting Float8_e4m3fn directly is not currently supported
     output = torch.ops.trtllm.fused_moe(
         x_q_fp8,
@@ -206,7 +185,7 @@ def trtllm_quant_fp8_moe_fused(
         fc2_expert_biases=None,
         output_dtype=x.dtype,
         quant_scales=quant_scales,
-        activation_type=activation_type,
+        activation_type=act_fn,
     )
 
     # Restore original shape
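Taken together, the change lifts the per-call activation mapping out of trtllm_quant_fp8_moe_fused: the validator now also accepts Swiglu for gated MLPs, and Silu is simply remapped to Swiglu before the fused_moe call. A minimal sketch of the resulting flow is below; the ActivationType stand-in enum and the resolve_activation helper are illustrative assumptions, not code from the file.

from enum import IntEnum


class ActivationType(IntEnum):
    # Stand-in for the real enum used by the auto-deploy MoE op; values are arbitrary here.
    Relu2 = 0
    Silu = 1
    Swiglu = 2


def _validate_mlp_style_and_act_fn(is_gated_mlp: bool, act_fn: int) -> None:
    # Gated MLPs accept Silu or Swiglu; non-gated MLPs accept only Relu2.
    assert (is_gated_mlp and act_fn in [ActivationType.Silu, ActivationType.Swiglu]) or (
        not is_gated_mlp and act_fn == ActivationType.Relu2
    ), f"Unsupported combination: is_gated_mlp='{is_gated_mlp}', act_fn='{act_fn}'."


def resolve_activation(is_gated_mlp: bool, act_fn: ActivationType) -> ActivationType:
    # Hypothetical helper showing the net effect of the refactoring: validate once,
    # then normalize Silu to Swiglu so the kernel receives a single gated-activation id.
    _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
    return ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn


# Example: a gated MLP requested with Silu is forwarded to fused_moe as Swiglu;
# a non-gated MLP with Relu2 passes through unchanged.
assert resolve_activation(True, ActivationType.Silu) == ActivationType.Swiglu
assert resolve_activation(False, ActivationType.Relu2) == ActivationType.Relu2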

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py

Lines changed: 3 additions & 1 deletion
@@ -307,7 +307,9 @@ def get_fc1_expert_weights(
 @pytest.mark.parametrize("top_k", TOP_K_VALUES)
 @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
 @pytest.mark.parametrize("itype, otype, wtype", FP8_TEST_DTYPES)
-@pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2])
+@pytest.mark.parametrize(
+    "activation_func", [ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2]
+)
 @pytest.mark.skipif(
     not fp8_compatible() or not trtllm_ops_available(),
     reason="Requires fp8 and trtllm support",
