Merged
Changes from all commits
25 commits
46205b4
Using ActivationType instead of GatedActType, added compiled kernels,…
amitz-nv Jan 28, 2026
b8eac34
Add actType and eltwiseActType to 'no kernel found' message, move is_…
amitz-nv Jan 28, 2026
f771e0c
Update remaining GatedActType uses to ActivationType, remove GatedAct…
amitz-nv Jan 28, 2026
440c062
Use ActivationType in benchmarks, add missing activation_type argument
amitz-nv Jan 28, 2026
5725739
Minor fixes
amitz-nv Jan 28, 2026
c2c8531
Fix activation_type default value to Swiglu on trtllm_fp4_block_scale…
amitz-nv Jan 28, 2026
bb4e821
Minor improvement
amitz-nv Jan 28, 2026
c6ac4af
Support non-gated activation in NVFP4 block scale MoE
amitz-nv Jan 28, 2026
3bf918e
Rename useShuffledMatrixA to useShuffledMatrix (remove the 'A' suffix)
amitz-nv Jan 28, 2026
1193b02
Add FP4_NVFP4_NVFP4 parameterization to test_llama4_routing, update t…
amitz-nv Jan 28, 2026
b0e6d59
Increase supported topK and num experts in deepseek routing for nemotron
amitz-nv Jan 28, 2026
d4182ae
Commit more files for increase supported topK and num experts in deep…
amitz-nv Jan 28, 2026
8ee2193
Fix formatting
amitz-nv Jan 28, 2026
c899d16
Change TODO to comment
amitz-nv Jan 28, 2026
0f6f15c
Change default activation_type to Swiglu
amitz-nv Jan 28, 2026
cf6f76b
Restore intermediate size factor of 2 for gated activation in getWork…
amitz-nv Jan 28, 2026
e63e17d
Formatting fixes
amitz-nv Jan 28, 2026
8398e20
Treat SwigluBias as gated activation
amitz-nv Jan 28, 2026
ea67cef
Fix use of ActivationType enum in CLI
amitz-nv Jan 28, 2026
abefe22
Fix activation-type command line argument handling in benchmarks
amitz-nv Jan 29, 2026
da35764
Fix choices of activation-type command line argument handling in benc…
amitz-nv Jan 29, 2026
205989f
GEMM (non batched) still has mUseShuffledMatrixA member (with 'A' suf…
amitz-nv Jan 29, 2026
e467f1d
Update bench_trtllm_gen_fused_moe_autotuner.py to support more activa…
amitz-nv Jan 29, 2026
80d1b53
Revert activation_Type check in bench_trtllm_gen_fused_moe_autotuner.…
amitz-nv Jan 29, 2026
21e0e08
Include activation type in results in benchmarks/routings/moe.py
amitz-nv Jan 29, 2026
26 changes: 24 additions & 2 deletions benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
@@ -4,7 +4,7 @@
import numpy as np
from flashinfer import (
RoutingMethodType,
GatedActType,
ActivationType,
fp4_quantize,
mxfp8_quantize,
)
@@ -17,6 +17,7 @@
from flashinfer.autotuner import autotune
from flashinfer.testing.utils import bench_gpu_time
from flashinfer.utils import device_support_pdl
from routines.flashinfer_benchmark_utils import enum_type

FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
FLOAT4_E2M1_MAX = 6.0
@@ -39,6 +40,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
top_k: int,
warmups: int,
iterations: int,
activation_type: ActivationType,
):
device = torch.device("cuda:0")
enable_pdl = device_support_pdl(device)
@@ -97,6 +99,10 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
)

if is_block_scale:
if activation_type != ActivationType.Swiglu:
raise ValueError(
"Only Swiglu activation is supported for FP8 block scale MoE."
)
fn = lambda: trtllm_fp8_block_scale_moe(
routing_logits,
routing_bias,
@@ -144,6 +150,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
RoutingMethodType.TopK.value,
enable_pdl,
num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
activation_type.value,
)

def bench(do_autotune):
@@ -175,6 +182,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
top_k: int,
warmups: int,
iterations: int,
activation_type: ActivationType,
):
device = torch.device("cuda:0")
enable_pdl = device_support_pdl(device)
@@ -234,6 +242,10 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
w13_global_scale = 1.0 / 448.0 / 6.0
w2_global_scale = 1.0 / 448.0 / 6.0
else:
if activation_type == ActivationType.Relu2:
raise ValueError(
"Relu2 activation is supported for FP4 only with 'NvFP4xNvFP4' quant mode"
)
w13, w13_scale = fp4_quantize(
w13, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True
)
@@ -288,7 +300,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
RoutingMethodType.Renormalize.value,
True,
enable_pdl,
GatedActType.SwiGlu.value, # gated_act_type
activation_type.value, # act_type
None,
num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
)
@@ -348,6 +360,14 @@ def bench(do_autotune):
parser.add_argument(
"--iterations", type=int, default=100, help="Number of benchmark iterations"
)
parser.add_argument(
"--activation-type",
type=enum_type(ActivationType),
metavar=str([e.name for e in ActivationType]),
required=False,
default=ActivationType.Swiglu,
help=f"Type of activation function: {[e.name for e in ActivationType]}",
)
args = parser.parse_args()
if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]:
bench_trtllm_gen_fused_moe_autotuner_fp8(
@@ -360,6 +380,7 @@ def bench(do_autotune):
args.top_k,
args.warmups,
args.iterations,
args.activation_type,
)
else:
bench_trtllm_gen_fused_moe_autotuner_fp4(
@@ -372,4 +393,5 @@ def bench(do_autotune):
args.top_k,
args.warmups,
args.iterations,
args.activation_type,
)
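
For reference, the new flag can be exercised directly from the command line. A hedged example: only --activation-type and --iterations are confirmed by this diff, every other option is left at its default, and member names are matched case-insensitively by the enum_type converter added below (so swiglu works as well as Swiglu):

python benchmarks/bench_trtllm_gen_fused_moe_autotuner.py --activation-type Swiglu --iterations 100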
16 changes: 16 additions & 0 deletions benchmarks/routines/flashinfer_benchmark_utils.py
@@ -1,3 +1,4 @@
import argparse
import torch

from flashinfer.testing.utils import set_seed
@@ -453,3 +454,18 @@ def filter_backends_by_compute_capability(backends, routine, device):
f"[WARNING] {backend} for routine {routine} is not supported on compute capability {compute_capability}. Skipping."
)
return backends


def enum_type(enum_class):
"""Generic factory for argparse enum types."""

def converter(value):
try:
lower_name_to_member = {m.name.lower(): m for m in enum_class}
return lower_name_to_member[value.lower()]
except KeyError as e:
raise argparse.ArgumentTypeError(
f"Invalid value '{value}'. Must be one of: {', '.join([m.name for m in enum_class])}"
) from e

return converter
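
To make the behavior of enum_type concrete, here is a minimal, self-contained sketch of how it plugs into argparse. ExampleActivation is a stand-in enum used purely for illustration; the benchmarks above pass the real flashinfer.ActivationType, and the import path mirrors the one used in the benchmark script:

import argparse
import enum

from routines.flashinfer_benchmark_utils import enum_type  # same import path as the benchmark script


class ExampleActivation(enum.Enum):
    # Stand-in members for illustration only; not the real flashinfer.ActivationType.
    Swiglu = 0
    Geglu = 1
    Relu2 = 2


parser = argparse.ArgumentParser()
parser.add_argument(
    "--activation-type",
    type=enum_type(ExampleActivation),
    metavar=str([e.name for e in ExampleActivation]),
    default=ExampleActivation.Swiglu,
)

# Name matching is case-insensitive: 'relu2', 'Relu2', and 'RELU2' all resolve to the same member.
args = parser.parse_args(["--activation-type", "relu2"])
assert args.activation_type is ExampleActivation.Relu2

# An unrecognized name raises argparse.ArgumentTypeError inside the converter,
# which argparse surfaces as a standard usage error.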
29 changes: 14 additions & 15 deletions benchmarks/routines/moe.py
@@ -5,6 +5,7 @@
import torch

import flashinfer
from flashinfer import ActivationType
from flashinfer.autotuner import autotune
from flashinfer.fused_moe import (
WeightLayout,
@@ -23,6 +24,7 @@

from .flashinfer_benchmark_utils import (
dtype_str_to_torch_dtype,
enum_type,
get_device,
print_perf_metrics,
filter_backends_by_compute_capability,
@@ -175,12 +177,12 @@ def parse_moe_args(line, parser):
help="Data type of the weights (before quantization).",
)
parser.add_argument(
"--gated_act",
type=str,
"--activation-type",
type=enum_type(ActivationType),
metavar=str([e.name for e in ActivationType]),
required=False,
default="swiglu",
choices=["swiglu", "geglu"],
help="Type of gated activation function: swiglu | geglu.",
default=ActivationType.Swiglu,
help=f"Type of activation function: {[e.name for e in ActivationType]}",
)
parser.add_argument(
"--autotune",
@@ -247,13 +249,6 @@ def parse_moe_args(line, parser):
}
args.routing_method_type = routing_method_name_to_type[args.routing_method]

# Normalize gated act type (map string to internal int expected by kernels)
gated_act_name_to_type = {
"swiglu": 0,
"geglu": 1,
}
args.gated_act_type = gated_act_name_to_type[args.gated_act]

if args.verbose >= 1:
print(f"[INFO] {args = }")
return args
@@ -630,7 +625,7 @@ def testTrtllmFp4BlockScaleMoe(args):
use_shuffled_weight = args.use_shuffled_weight
weight_layout = args.weight_layout
is_cuda_graph_compatible = not args.no_cuda_graph
gated_act_type = args.gated_act_type
activation_type = args.activation_type
res = []

backends = ["trtllm"]
@@ -795,7 +790,7 @@ def run_fp4_moe(
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=routing_method_type,
gated_act_type=gated_act_type,
activation_type=activation_type.value,
do_finalize=True,
)

@@ -900,7 +895,7 @@ def run_fp4_moe(
cur_res["use_routing_scales_on_input"] = args.use_routing_scales_on_input
cur_res["input_dtype"] = input_dtype
cur_res["weight_dtype"] = weight_dtype
cur_res["gated_act"] = args.gated_act
cur_res["activation_type"] = args.activation_type.name
res.append(cur_res)

return res
@@ -1671,6 +1666,7 @@ def run_fp8_per_tensor_moe(
output1_scales_gate_scalar,
gemm2_weights_fp8,
output2_scales_scalar,
activation_type,
):
# Note: FP8 per-tensor MOE expects int64_t for n_group/topk_group, not Optional[int64_t]
# So we convert None to 0 to indicate "no groups" mode
@@ -1693,6 +1689,7 @@ def run_fp8_per_tensor_moe(
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input,
routing_method_type=routing_method_type,
activation_type=activation_type.value,
)

# Benchmark timing
@@ -1713,6 +1710,7 @@ def run_fp8_per_tensor_moe(
output1_scales_gate_scalar,
gemm2_weights_fp8,
output2_scales_scalar,
args.activation_type,
),
)

@@ -1764,6 +1762,7 @@ def run_fp8_per_tensor_moe(
cur_res["use_routing_scales_on_input"] = use_routing_scales_on_input
cur_res["input_dtype"] = input_dtype
cur_res["weight_dtype"] = weight_dtype
cur_res["activation_type"] = args.activation_type.name
res.append(cur_res)

return res
10 changes: 8 additions & 2 deletions csrc/trtllm_batched_gemm_runner.cu
@@ -101,14 +101,16 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
options.mTransposeMmaOutput == mOptions.transposeMmaOutput &&
(!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct &&
options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch &&
tileSize == mOptions.tileSize &&
options.mUseShuffledMatrix == mOptions.useShuffledMatrixA &&
tileSize == mOptions.tileSize && options.mUseShuffledMatrix == mOptions.useShuffledMatrix &&
options.mLayoutA == mOptions.weightLayout) {
if (options.mFusedAct) {
if (options.mActType != static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType)) {
continue;
}
}
if ((int64_t)options.mEltwiseActType != (int64_t)mOptions.eltwiseActType) {
continue;
}

if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM) {
mPassingConfigIndices.push_back(i);
@@ -122,6 +124,8 @@
<< ", mDtypeB: " << tg::dtypeToString(mOptions.dtypeB)
<< ", mDtypeC: " << tg::dtypeToString(mOptions.dtypeC)
<< ", mUseDeepSeekFp8: " << mOptions.deepSeekFp8
<< ", mActType: " << (int64_t)mOptions.actType
<< ", mEltwiseActType: " << (int64_t)mOptions.eltwiseActType
<< ", mTransposeMmaOutput: " << mOptions.transposeMmaOutput
<< ", mRouteAct: " << mOptions.routeAct << ", mFusedAct: " << mOptions.fusedAct
<< ", mIsStaticBatch: " << mOptions.staticBatch << ", mTileSize: " << mOptions.tileSize;
@@ -219,6 +223,8 @@ void TrtllmGenBatchedGemmRunner::run(
gemmData.mInputBuffers.mPtrSfB = mOptions.transposeMmaOutput ? sfA : sfB;
gemmData.mInputBuffers.mPtrScaleC = scaleC;
gemmData.mInputBuffers.mPtrScaleGate = scaleGateC;
// For simplicity, set scaleAct to scaleGateC
gemmData.mInputBuffers.mPtrScaleAct = scaleGateC;
Contributor Author @amitz-nv commented on Jan 7, 2026:
Decide whether it's OK or fix in the future?

gemmData.mInputBuffers.mPtrPerTokenSfA =
mOptions.transposeMmaOutput ? perTokensSfB : perTokensSfA;
gemmData.mInputBuffers.mPtrPerTokenSfB =