
Commit 12e1cb8

[#9717][chore] Refactor MoE code to use enums (#9910)
Signed-off-by: Tal Cherckez <[email protected]>
1 parent: aaa87ab

13 files changed: +246 −246 lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py

Lines changed: 86 additions & 89 deletions
Large diffs are not rendered by default.
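
The diff for torch_moe.py is not rendered here, but the pattern visible in the other files depends on ActivationType from tensorrt_llm._torch.utils, which the diffs use with members Silu, Swiglu, and Relu2 and with int()/name round-trips. A minimal stand-in sketch of the assumption behind those usages (an IntEnum whose members mirror the C++ activation list in cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h; the member values below are placeholders, not the library's real numbering):

# Hedged sketch only: the real ActivationType lives in tensorrt_llm._torch.utils.
# The values here are illustrative placeholders.
from enum import IntEnum

class ActivationType(IntEnum):  # hypothetical stand-in for the imported enum
    Relu2 = 1
    Silu = 2
    Swiglu = 3

# Custom-op schemas take plain ints, so call sites pass int(ActivationType.Silu),
# and IntEnum semantics keep `act_fn == ActivationType.Silu` valid when act_fn is an int.
assert int(ActivationType.Silu) == ActivationType.Silu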

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py

Lines changed: 12 additions & 12 deletions
@@ -14,6 +14,8 @@
 import triton
 import triton.language as tl

+from tensorrt_llm._torch.utils import ActivationType  # noqa: F401
+
 from ...utils.logger import ad_logger


@@ -601,15 +603,13 @@ def triton_fused_moe(
     routing_weights: torch.Tensor,
     w1_stacked_weight: torch.Tensor,
     w2_stacked_weight: torch.Tensor,
-    mlp_style: str = "mlp",
-    act_fn: str = "relu2",
+    is_gated_mlp: bool = False,
+    act_fn: int = int(ActivationType.Relu2),
 ) -> torch.Tensor:
     """Triton unquantized MoE with 2-layer MLP and ReLU^2 activation."""

-    mlp_style = mlp_style.lower()
-    act_fn = act_fn.lower()
-    assert mlp_style == "mlp", "Triton backend only supports mlp style."
-    assert act_fn == "relu2", "Triton backend only supports relu2 activation."
+    assert not is_gated_mlp, "Triton backend only supports non gated MLP style."
+    assert act_fn == ActivationType.Relu2, "Triton backend only supports relu2 activation."

     x_shape = x.shape
     x2d = x.view(-1, x_shape[-1])
@@ -661,12 +661,12 @@ def triton_quant_fp8_moe(
     w1_weight_scale: torch.Tensor,  # [E] stacked weight scales
     w2_weight_scale: torch.Tensor,  # [E] stacked weight scales
     w3_weight_scale: torch.Tensor,  # unused
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    is_gated_mlp: bool = False,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     """Triton FP8 W8A8 MoE with 2-layer MLP and ReLU^2 activation."""
-    if mlp_style != "mlp":
-        raise NotImplementedError("triton_quant_fp8_moe currently supports mlp_style=='mlp' only")
+    if is_gated_mlp:
+        raise NotImplementedError("triton_quant_fp8_moe currently supports mlp only")

     x_shape = x.shape
     x2d = x.view(-1, x_shape[-1])
@@ -760,7 +760,7 @@ def triton_quant_fp8_moe(
     w1_weight_scale: torch.Tensor,
     w2_weight_scale: torch.Tensor,
     w3_weight_scale: torch.Tensor,
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    is_gated_mlp: bool = False,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     return torch.empty_like(x)
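
For existing string-based call sites of these Triton ops, the migration is mechanical: mlp_style becomes the is_gated_mlp flag and act_fn becomes an ActivationType value. A small helper, not part of this commit (the names _ACT_BY_NAME and legacy_moe_kwargs are illustrative), covering only the combinations that appear in the diff:

# Hypothetical migration helper; assumes only the "gated_mlp"/"silu" and "mlp"/"relu2"
# combinations that this commit supports.
from tensorrt_llm._torch.utils import ActivationType

_ACT_BY_NAME = {"silu": ActivationType.Silu, "relu2": ActivationType.Relu2}

def legacy_moe_kwargs(mlp_style: str, act_fn: str) -> dict:
    """Translate the old string arguments into the new is_gated_mlp/act_fn pair."""
    return {
        "is_gated_mlp": mlp_style.lower() == "gated_mlp",
        "act_fn": int(_ACT_BY_NAME[act_fn.lower()]),
    }

# e.g. the Triton path, previously mlp_style="mlp", act_fn="relu2":
assert legacy_moe_kwargs("mlp", "relu2") == {
    "is_gated_mlp": False,
    "act_fn": int(ActivationType.Relu2),
}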

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py

Lines changed: 52 additions & 62 deletions
@@ -26,8 +26,8 @@ def trtllm_moe_fused(
     routing_weights: torch.Tensor,
     w3_w1_stacked_weight: torch.Tensor,
     w2_stacked_weight: torch.Tensor,
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    is_gated_mlp: bool = True,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     x_shape = x.shape
     x = x.view(-1, x_shape[-1])
@@ -37,24 +37,24 @@ def trtllm_moe_fused(
     quant_scales = []

     # Determine activation type
-    mlp_style = mlp_style.lower()
-    act_fn = act_fn.lower()

     activation_type = ActivationType.Swiglu
-    if mlp_style == "gated_mlp":
+    if is_gated_mlp:
         # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
-        if act_fn == "silu":
+        if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
             activation_type = ActivationType.Swiglu
         else:
-            raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
-    elif mlp_style == "mlp":
+            raise ValueError(
+                f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
+            )
+    else:
         # For non-gated MLP with ReLU^2
-        if act_fn == "relu2":
+        if act_fn == ActivationType.Relu2:
             activation_type = ActivationType.Relu2
         else:
-            raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
-    else:
-        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+            raise ValueError(
+                f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
+            )

     return torch.ops.trtllm.fused_moe(
         x,
@@ -77,8 +77,8 @@ def trtllm_moe_fused_fake(
     routing_weights: torch.Tensor,
     w3_w1_stacked_weight: torch.Tensor,
     w2_stacked_weight: torch.Tensor,
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    is_gated_mlp: bool = True,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     return torch.empty_like(x)

@@ -93,21 +93,12 @@ def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)


-def _validate_mlp_style_and_act_fn(mlp_style: str, act_fn: str) -> None:
-    supported_combinations = {
-        "gated_mlp": ["silu"],
-        "mlp": ["relu2"],
-    }
-    supported_act_fns = [
-        act_fn for act_fn_list in supported_combinations.values() for act_fn in act_fn_list
-    ]
-    assert mlp_style in supported_combinations.keys(), (
-        f"Unknown mlp_style '{mlp_style}'. Use {supported_combinations.keys()}."
-    )
-    assert act_fn in supported_act_fns, f"Unknown act_fn '{act_fn}'. Use {supported_act_fns}."
-    assert act_fn in supported_combinations[mlp_style], (
-        f"Unsupported combination: mlp_style='{mlp_style}', act_fn='{act_fn}'. "
-        f"Supported combinations: {supported_combinations}"
+def _validate_mlp_style_and_act_fn(is_gated_mlp: bool, act_fn: int) -> None:
+    assert (is_gated_mlp and act_fn == ActivationType.Silu) or (
+        not is_gated_mlp and act_fn == ActivationType.Relu2
+    ), (
+        f"Unsupported combination: is_gated_mlp='{is_gated_mlp}', act_fn='{act_fn}'. "
+        f"Supported combinations: gated mlp with silu or mlp with relu2."
     )


@@ -128,8 +119,8 @@ def trtllm_quant_fp8_moe_fused(
     gemm1_dequant: torch.Tensor,  # [E]
     gemm2_act_quant: torch.Tensor,  # [E]
     gemm2_dequant: torch.Tensor,  # [E]
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    is_gated_mlp: bool = True,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     """
     TensorRT-LLM Cutlass FP8 W8A8 MoE for gated and non-gated MLP.
@@ -149,8 +140,8 @@ def trtllm_quant_fp8_moe_fused(
         gemm1_dequant: Precomputed gemm1 dequant scale [E]
         gemm2_act_quant: Precomputed gemm2 act quant scale [1]
         gemm2_dequant: Precomputed gemm2 dequant scale [E]
-        mlp_style: "gated_mlp" or "mlp"
-        act_fn: "silu" for gated_mlp, "relu2" for mlp
+        is_gated_mlp: True for gated_mlp, False for mlp
+        act_fn: ActivationType.Silu for gated_mlp, ActivationType.Relu2 for mlp

     Non-Gated MLP:
         activation_fn(expert_inputs @ w1_expert.t())@ w2_expert.t()
@@ -159,7 +150,7 @@ def trtllm_quant_fp8_moe_fused(
         activation_fn(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) @ w2_expert.t()
     """

-    _validate_mlp_style_and_act_fn(mlp_style, act_fn)
+    _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)

     # Store original shape and flatten to 2D
     x_shape = x.shape
@@ -190,28 +181,27 @@ def trtllm_quant_fp8_moe_fused(
     # Todo: refactor this repeating code block

     # Determine activation type
-    mlp_style = mlp_style.lower()
-    act_fn = act_fn.lower()
-
     activation_type = ActivationType.Swiglu
-    if mlp_style == "gated_mlp":
+    if is_gated_mlp:
         # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
         # For gated MLP, concatenate w1 and w3 as [w3, w1]
         w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous()  # [E, 2*I, H]
         fc1_expert_weights = w3_w1_stacked
-        if act_fn == "silu":
+        if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
             activation_type = ActivationType.Swiglu
         else:
-            raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
-    elif mlp_style == "mlp":
+            raise ValueError(
+                f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
+            )
+    else:
         # For non-gated MLP with ReLU^2
         fc1_expert_weights = w1_weight.contiguous()
-        if act_fn == "relu2":
+        if act_fn == ActivationType.Relu2:
             activation_type = ActivationType.Relu2
         else:
-            raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
-    else:
-        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+            raise ValueError(
+                f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
+            )

     # Note! Outputting Float8_e4m3fn directly is not currently supported
     output = torch.ops.trtllm.fused_moe(
@@ -248,10 +238,10 @@ def trtllm_quant_fp8_moe_fused_fake(
     gemm1_dequant: torch.Tensor,
     gemm2_act_quant: torch.Tensor,
     gemm2_dequant: torch.Tensor,
-    mlp_style: str,
-    act_fn: str,
+    is_gated_mlp: bool,
+    act_fn: int,
 ) -> torch.Tensor:
-    _validate_mlp_style_and_act_fn(mlp_style, act_fn)
+    _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
     return torch.empty_like(x)


@@ -268,8 +258,8 @@ def trtllm_quant_nvfp4_moe_fused(
     fc2_act_global_scale: torch.Tensor,  # Global scale for FC2 activations
     fc1_alpha: torch.Tensor,  # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8))
     fc2_alpha: torch.Tensor,  # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8))
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    is_gated_mlp: bool = True,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     """TensorRT-LLM Cutlass NVFP4 W8A8 MoE for gated and non-gated MLP.

@@ -285,22 +275,22 @@ def trtllm_quant_nvfp4_moe_fused(

     """
     NVFP4_BLOCK_SIZE = 16
-    mlp_style = mlp_style.lower()
-    act_fn = act_fn.lower()

     activation_type = ActivationType.Swiglu
-    if mlp_style == "gated_mlp":
-        if act_fn == "silu":
+    if is_gated_mlp:
+        if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
             activation_type = ActivationType.Swiglu
         else:
-            raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
-    elif mlp_style == "mlp":
-        if act_fn == "relu2":
+            raise ValueError(
+                f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
+            )
+    else:
+        if act_fn == ActivationType.Relu2:
             activation_type = ActivationType.Relu2
         else:
-            raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
-    else:
-        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+            raise ValueError(
+                f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
+            )

     # quant_scales is described by this code:
     # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
@@ -353,7 +343,7 @@ def trtllm_quant_nvfp4_moe_fused_fake(
     fc2_act_global_scale: torch.Tensor,
     fc1_alpha: torch.Tensor,
     fc2_alpha: torch.Tensor,
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    is_gated_mlp: bool = True,
+    act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
     return torch.empty_like(x)
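
The same activation-resolution branch now appears in trtllm_moe_fused, trtllm_quant_fp8_moe_fused, and trtllm_quant_nvfp4_moe_fused, and the code itself flags this ("Todo: refactor this repeating code block"). One possible consolidation, sketched here but not part of the commit (_resolve_activation_type is a hypothetical name):

# Sketch only: factor the repeated mapping (is_gated_mlp, act_fn) -> kernel ActivationType
# into a single helper, mirroring the branches in the diff above.
from tensorrt_llm._torch.utils import ActivationType

def _resolve_activation_type(is_gated_mlp: bool, act_fn: int) -> ActivationType:
    if is_gated_mlp:
        # Gated MLP: silu(x @ w1.T) * (x @ w3.T) -> fused Swiglu kernel
        if act_fn in (ActivationType.Silu, ActivationType.Swiglu):
            return ActivationType.Swiglu
        raise ValueError(
            f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
        )
    # Non-gated MLP with ReLU^2
    if act_fn == ActivationType.Relu2:
        return ActivationType.Relu2
    raise ValueError(
        f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
    )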

tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py

Lines changed: 4 additions & 1 deletion
@@ -13,6 +13,7 @@
 from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils_base import BatchEncoding

+from tensorrt_llm._torch.utils import ActivationType
 from tensorrt_llm.inputs.utils import HF_CHAT_TEMPLATE_EXCEPTIONS

 from ..nemotron_flash import NemotronFlashForCausalLMFactory
@@ -182,6 +183,8 @@ def __init__(
         self.qk_activation = qk_activation
         self.qk_norm = qk_norm

+        # can't use ActivationType enum here,
+        # because there is no Elu defined in cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
         assert self.qk_activation in ["silu", "relu", "elu", "identity"]
         assert self.qk_norm in ["l2", "sum"]

@@ -331,7 +334,7 @@ def __init__(
         self.num_heads = self.d_inner // self.headdim
         self.rmsnorm = rmsnorm
         self.dt_limit = dt_limit
-        self.activation = "silu"
+        self.activation = ActivationType.Silu
         self.chunk_size = chunk_size
         self.layer_idx = layer_idx
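
As the added comment notes, qk_activation stays a string because the enum has no Elu member. A hedged illustration (not in the commit; maybe_activation_enum is a hypothetical helper) of how a caller could opt into the enum only where a member exists:

# Illustration only: fall back to the original string when ActivationType lacks a member.
from tensorrt_llm._torch.utils import ActivationType

def maybe_activation_enum(name: str):
    """Return the matching ActivationType member, or the original string if none exists."""
    return getattr(ActivationType, name.capitalize(), name)

maybe_activation_enum("silu")  # -> ActivationType.Silu
maybe_activation_enum("elu")   # -> "elu" (no Elu member, per the comment above)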

tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py

Lines changed: 3 additions & 2 deletions
@@ -33,6 +33,7 @@

 from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import gated_rms_norm_ref
 from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
+from tensorrt_llm._torch.utils import ActivationType


 class MambaRMSNormGated(torch.nn.Module):
@@ -308,8 +309,8 @@ def forward(self, hidden_states: torch.Tensor):
             w1_weight=[e.up_proj.weight for e in self.experts],
             w2_weight=[e.down_proj.weight for e in self.experts],
             w3_weight=[],
-            act_fn="relu2",
-            mlp_style="mlp",
+            act_fn=ActivationType.Relu2,
+            is_gated_mlp=False,
         )

         if has_latent_proj:
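
This call site passes the enum member directly (act_fn=ActivationType.Relu2) where the op signatures declare act_fn: int, while the op defaults spell it as int(ActivationType.Silu). Both forms only interchange cleanly if ActivationType behaves like an IntEnum, which is an assumption this sketch makes explicit rather than a documented guarantee:

# Assumption check, not part of the commit.
from tensorrt_llm._torch.utils import ActivationType

act_fn = ActivationType.Relu2
assert isinstance(act_fn, int)                 # the member is usable as a plain int
assert ActivationType(int(act_fn)) is act_fn   # and the int converts back to the same member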
