
Commit 2645a78

longlee0622 and kaiyux authored
[TRTLLM-9660][feat] Convert cuteDSL GEMM to opt-in feature (#9682)
Signed-off-by: Jonas Li <6110159+longlee0622@users.noreply.github.com> Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
1 parent 8d2178d commit 2645a78

File tree: 8 files changed, +216 −149 lines

tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py

Lines changed: 22 additions & 17 deletions
```diff
@@ -526,7 +526,7 @@ def register_nvfp4_gemm_prologue(custom_pass: PatternMatcherPass):
     alpha_key = KeywordArg('alpha')
     output_dtype_key = KeywordArg('output_dtype')
     to_userbuffers_key = KeywordArg('to_userbuffers')
-    backend_key = KeywordArg('backend')
+    allowed_backends_key = KeywordArg('allowed_backends')
     trtllm_nvfp4_gemm_default = CallFunction(
         torch.ops.trtllm.nvfp4_gemm.default,
         act_fp4_key,
@@ -536,7 +536,7 @@ def register_nvfp4_gemm_prologue(custom_pass: PatternMatcherPass):
         alpha_key,
         output_dtype_key,
         to_userbuffers=to_userbuffers_key,
-        backend=backend_key)
+        allowed_backends=allowed_backends_key)
     ub_copy = CallFunction(torch.ops.trtllm.copy_to_userbuffers,
                            trtllm_nvfp4_gemm_default)
 
@@ -548,7 +548,7 @@ def empty_nvfp4_gemm_prologue_pattern(
         alpha: torch.Tensor,
         output_dtype: torch.dtype,
         to_userbuffers: bool,
-        backend: str,
+        allowed_backends: str,
     ):
         return
 
@@ -560,26 +560,31 @@ def target_nvfp4_gemm_prologue_pattern(
         alpha: torch.Tensor,
         output_dtype: torch.dtype,
         to_userbuffers: bool,
-        backend: str,
+        allowed_backends: str,
     ):
         nvfp4_gemm_output = torch.ops.trtllm.nvfp4_gemm(
             act_fp4, weight, act_sf, weight_scale, alpha, output_dtype,
-            True, backend)
+            True, allowed_backends)
         return nvfp4_gemm_output
 
     def extra_check(match: Match) -> bool:
-        # Validate backend value
-        backend_value = match.kwargs.get('backend')
-        if backend_value is None:
-            # No backend specified, use default - OK
-            return True
-
-        # backend should be a string literal
-        if not isinstance(backend_value, str):
-            return False
-
-        valid_backends = {'auto', 'cutlass', 'cublaslt', 'cutedsl'}
-        return backend_value in valid_backends
+        # Validate allowed_backends if present (now a comma-separated string)
+        allowed_backends_value = match.kwargs.get('allowed_backends')
+        if allowed_backends_value is not None:
+            # allowed_backends should be a comma-separated string
+            if not isinstance(allowed_backends_value, str):
+                return False
+            backends_list = [
+                b.strip() for b in allowed_backends_value.split(',')
+                if b.strip()
+            ]
+            valid_individual = {
+                'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'
+            }
+            if not all(b in valid_individual for b in backends_list):
+                return False
+
+        return True
 
     register_replacement(
         empty_nvfp4_gemm_prologue_pattern,
```
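Note: the updated `extra_check` treats `allowed_backends` as a comma-separated string and drops the old `'auto'` sentinel. A minimal standalone sketch of the rule it enforces (the helper name here is hypothetical, not code from this diff):

```python
# Re-implementation of the acceptance rule for illustration only.
VALID_BACKENDS = {'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'}

def is_valid_allowed_backends(value: str) -> bool:
    # Split on commas, tolerate whitespace, drop empty entries.
    backends = [b.strip() for b in value.split(',') if b.strip()]
    return all(b in VALID_BACKENDS for b in backends)

assert is_valid_allowed_backends("cutlass,cublaslt,cuda_core")   # new default
assert is_valid_allowed_backends("cutlass, cutedsl")             # whitespace is stripped
assert not is_valid_allowed_backends("auto")                     # 'auto' is no longer a value
```

An empty string still passes `extra_check` (the per-entry check is vacuously true over an empty list); rejecting an empty backend list is left to the op-level validation in `nvfp4_gemm` below.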

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 73 additions & 66 deletions
```diff
@@ -695,28 +695,36 @@ def _(
 class NVFP4GemmUnifiedRunner(TunableRunner):
     runner_dict = dict()
 
-    def __init__(self,
-                 to_userbuffers: bool,
-                 output_dtype: torch.dtype,
-                 backend: str = "auto"):
+    def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype,
+                 allowed_backends: List[str]):
         super().__init__()
         self.to_userbuffers = to_userbuffers
         self.output_dtype = output_dtype
-        self.backend = backend
+        self.allowed_backends = allowed_backends
 
     def unique_id(self):
-        """Include backend in cache key to avoid sharing cache across backends."""
-        return (self.to_userbuffers, self.output_dtype, self.backend)
+        """Include allowed_backends in cache key to avoid sharing cache across different backend configs."""
+        # Convert list to tuple for hashability
+        allowed_tuple = tuple(self.allowed_backends)
+        return (self.to_userbuffers, self.output_dtype, allowed_tuple)
+
+    def _is_backend_allowed(self, backend_name: str) -> bool:
+        """Check if a backend is allowed based on allowed_backends list."""
+        return backend_name in self.allowed_backends
+
+    def _is_only_backend(self, backend_name: str) -> bool:
+        """Check if this is the only backend in allowed_backends (explicitly forced)."""
+        return self.allowed_backends == [backend_name]
 
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile,
                           **kwargs) -> List[Tuple]:
-        # return valid nvfp4 gemm implementations
+        # return valid nvfp4 gemm implementations from allowed_backends
        tactics = []
        act_fp4, weight, act_sf, weight_scale, alpha = inputs
-       backend = self.backend
 
-       if backend in ["auto", "cuda_core"]:
+       # Add CUDA Core backend if available
+       if self._is_backend_allowed("cuda_core"):
            is_cuda_core_supported = False
            m = act_fp4.shape[0]
            sm_version = None
@@ -732,40 +740,39 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
 
            if is_cuda_core_supported:
                tactics.append("cuda_core")
-           elif backend == "cuda_core":
-               # Explicitly requested but conditions not met - raise error
+           elif self._is_only_backend("cuda_core"):
+               # Explicitly forced but conditions not met - raise error
                error_msg = f"CUDA Core backend requires SM >= {CudaCoreNVFP4Runner.MIN_SM_VERSION} and M <= {CudaCoreNVFP4Runner.MAX_M_DIMENSION}. "
                error_msg += f"Current: SM={sm_version if sm_version else 'N/A'}, M={m}. "
-               error_msg += "Please use backend='auto' or another backend."
+               error_msg += "Please add other backends to allowed_backends."
                raise ValueError(error_msg)
 
        # Add CUTLASS runner (always available)
-       if backend in ["auto", "cutlass"]:
+       if self._is_backend_allowed("cutlass"):
            tactics.append("cutlass")
 
        # Add cuBLASLt runner if available
-       if backend in ["auto", "cublaslt"]:
+       if self._is_backend_allowed("cublaslt"):
            if IS_CUBLASLT_AVAILABLE:
                tactics.append("cublaslt")
-           elif backend == "cublaslt":
+           elif self._is_only_backend("cublaslt"):
                raise ValueError(
                    "cuBLASLt backend is not available. "
-                   "Please check cuBLASLt installation or use backend='auto'.")
+                   "Please check cuBLASLt installation or add other backends to allowed_backends."
+               )
 
        # Add CuteDSL runner if available
-       if backend in ["auto", "cutedsl"]:
+       if self._is_backend_allowed("cutedsl"):
            if IS_CUTLASS_DSL_AVAILABLE:
                # Check SM version first - CuteDSL NVFP4 only supports SM 100 (B200)
                sm_version = get_sm_version()
                if sm_version not in [100, 103]:
-                   if backend == "cutedsl":
-                       # Explicitly requested CuteDSL but SM version not supported
+                   if self._is_only_backend("cutedsl"):
+                       # Explicitly forced CuteDSL but SM version not supported
                        raise ValueError(
                            f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM {sm_version}. "
                            f"CuteDSL NVFP4 is not supported on this GPU architecture. "
-                           f"Please use backend='auto' to automatically select a compatible backend."
-                       )
-                   # else: backend='auto' → silently skip CuteDSL
+                           "Please add other backends to allowed_backends.")
                else:
                    # SM version OK, check if CuteDSL supports the current shape
                    from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \
@@ -778,8 +785,8 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
                    if cutedsl_tactics:
                        # CuteDSL supports this shape
                        tactics.append("cutedsl")
-                   elif backend == "cutedsl":
-                       # Explicitly requested CuteDSL but it doesn't support this shape
+                   elif self._is_only_backend("cutedsl"):
+                       # Explicitly forced CuteDSL but it doesn't support this shape
                        m, n, k = inputs[0].shape[0], inputs[1].shape[
                            0], inputs[0].shape[1] * 2
                        raise ValueError(
@@ -788,13 +795,12 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
                            f"CuteDSL requires 16-byte alignment for major (contiguous) dimensions:\n"
                            f"  - K must be divisible by 32 (FP4 K-major layout): K%32={'0✓' if k % 32 == 0 else str(k%32)+'✗'}\n"
                            f"  - Or the combination of (M, N, K, tiling, cluster shape) is not supported\n"
-                           f"Please use backend='auto' to automatically select a compatible backend."
-                       )
-                   # else: backend='auto' and CuteDSL doesn't support shape → silently skip
-           elif backend == "cutedsl":
+                           f"Please add other backends to allowed_backends.")
+           elif self._is_only_backend("cutedsl"):
                raise ValueError(
                    "CuteDSL backend is not available. "
-                   "Please check CuteDSL installation or use backend='auto'.")
+                   "Please check CuteDSL installation or add other backends to allowed_backends."
+               )
 
        return tactics
 
```
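The behavioral pivot of this change is the forced-vs-optional rule above: a backend whose requirements are unmet raises only when it is the sole entry in `allowed_backends`; otherwise it is silently skipped and auto-selection proceeds with the rest. Sketched in isolation (a hypothetical free-function form of `_is_only_backend`):

```python
# A backend is "explicitly forced" only when it is the single allowed entry.
def is_only_backend(allowed_backends: list, backend_name: str) -> bool:
    return allowed_backends == [backend_name]

assert is_only_backend(["cutedsl"], "cutedsl")                 # forced: unmet requirements raise
assert not is_only_backend(["cutlass", "cutedsl"], "cutedsl")  # optional: silently skipped
```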
```diff
@@ -807,31 +813,23 @@ def forward(
     ) -> torch.Tensor:
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
 
-        requested_backend = self.backend
-
-        # If a specific backend was requested (not 'auto') and we're using fallback tactic
-        # This can happen on cache miss, where AutoTuner uses tactic=-1 as default
-        if requested_backend != 'auto' and requested_backend != tactic and tactic == -1:
-            # User explicitly requested a backend, but we're falling back to default
-            # This might happen on cache miss. We should validate the requested backend supports this shape.
-
-            # Get valid tactics for the requested backend
+        # Handle fallback tactic (-1) on cache miss
+        if tactic == -1:
+            # Get valid tactics and use first available
             from tensorrt_llm._torch.autotuner import OptimizationProfile
             valid_tactics = self.get_valid_tactics(inputs,
                                                    OptimizationProfile())
-
-            if not valid_tactics or requested_backend not in valid_tactics:
-                # Requested backend doesn't support this shape
+            if valid_tactics:
+                # Prefer cutlass as fallback if available, otherwise use first valid tactic
+                tactic = "cutlass" if "cutlass" in valid_tactics else valid_tactics[
+                    0]
+            else:
                 m, n, k = inputs[0].shape[0], inputs[1].shape[
                     0], inputs[0].shape[1] * 2
                 raise ValueError(
-                    f"Backend '{requested_backend}' was explicitly requested but does not support the current shape:\n"
+                    f"No valid backends available for the current shape:\n"
                     f"  M={m}, N={n}, K={k}\n"
-                    f"Please use backend='auto' to automatically select a compatible backend."
-                )
-
-            # Backend supports it, use the requested backend instead of fallback
-            tactic = requested_backend
+                    f"  Allowed backends: {self.allowed_backends}")
 
         if tactic == "cuda_core":
             # Unswizzle the activation scale factors
```
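On an AutoTuner cache miss, `forward` receives the sentinel tactic `-1`; the replacement logic resolves it against whatever `get_valid_tactics` returns, preferring `cutlass`. The selection expression in isolation, with example data:

```python
# Fallback resolution for tactic == -1, isolated from forward() for clarity.
valid_tactics = ["cublaslt", "cutedsl"]   # example output of get_valid_tactics
tactic = "cutlass" if "cutlass" in valid_tactics else valid_tactics[0]
assert tactic == "cublaslt"               # cutlass not in the list, so first valid wins
```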
```diff
@@ -882,20 +880,19 @@ def nvfp4_gemm(
     alpha: torch.Tensor,
     output_dtype: torch.dtype,
     to_userbuffers: bool = False,
-    backend: str = "auto",
+    allowed_backends: str = "cutlass,cublaslt,cuda_core",
 ) -> torch.Tensor:
-    """Unified NVFP4 GEMM with automatic or manual backend selection.
+    """Unified NVFP4 GEMM with automatic backend selection.
 
-    This function can automatically choose the best backend or force a specific backend:
+    This function automatically chooses the best backend from the allowed list:
     - CUTLASS: Predefined CUTLASS configurations with auto-tuning
     - cuBLASLt: Heuristic-based algorithms from cuBLASLt library
     - CuteDSL: Blackwell-optimized persistent kernels (when available and inputs are valid)
     - CUDA Core: CUDA Core implementation (requires SM >= 100 and M <= 8)
 
     The AutoTuner profiles all available backends during the first run and caches
     the best choice for each input shape. Subsequent calls use the cached selection
-    with zero overhead. In 'auto' mode, backends are only considered if their
-    requirements are met (e.g., CUDA Core only participates when SM >= 100 and M <= 8).
+    with zero overhead.
 
     Args:
         act_fp4: Activation tensor [m, k] in FP4 format (packed in uint8)
@@ -905,12 +902,10 @@ def nvfp4_gemm(
         alpha: Scaling factor (as torch.Tensor for CUTLASS/cuBLASLt compatibility)
         output_dtype: Output data type
         to_userbuffers: Whether to use user buffers (CUTLASS/cuBLASLt only)
-        backend: Backend selection, one of:
-            - 'auto': AutoTuner automatically selects best backend (default)
-            - 'cutlass': Force use CUTLASS (FP4GemmRunner)
-            - 'cublaslt': Force use cuBLASLt (CublasLtFP4GemmRunner)
-            - 'cutedsl': Force use CuteDSL (CuteDSLNVFP4Wrapper)
-            - 'cuda_core': Force use CUDA Core (CudaCoreNVFP4Runner, requires SM >= 100, M <= 8)
+        allowed_backends: Comma-separated list of backends to consider for auto-selection.
+            Default: "cutlass,cublaslt,cuda_core" (excludes cutedsl for faster build)
+            Add 'cutedsl' for extreme performance at the cost of longer build time.
+            Valid backends: 'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'.
 
     Returns:
         Output tensor [m, n] with dtype=output_dtype
@@ -919,14 +914,26 @@ def nvfp4_gemm(
         ValueError: If backend is invalid/unavailable
     """
 
-    # Validate backend parameter
-    valid_backends = ['auto', 'cutlass', 'cublaslt', 'cutedsl', 'cuda_core']
-    if backend not in valid_backends:
+    valid_individual_backends = {'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'}
+
+    # Parse comma-separated string to list
+    backends_list = [
+        b.strip() for b in allowed_backends.split(',') if b.strip()
+    ]
+
+    # Validate allowed_backends
+    invalid_backends = set(backends_list) - valid_individual_backends
+    if invalid_backends:
+        raise ValueError(
+            f"Invalid backends in allowed_backends: {invalid_backends}. "
+            f"Valid backends are: {sorted(valid_individual_backends)}.")
+    if not backends_list:
         raise ValueError(
-            f"Invalid backend '{backend}'. Must be one of {valid_backends}")
+            f"allowed_backends cannot be empty. "
+            f"Valid backends are: {sorted(valid_individual_backends)}.")
 
-    # Build list of runners based on backend parameter
-    runner = NVFP4GemmUnifiedRunner(to_userbuffers, output_dtype, backend)
+    # Build runner with allowed backends
+    runner = NVFP4GemmUnifiedRunner(to_userbuffers, output_dtype, backends_list)
 
     # Use AutoTuner to select best runner and tactic
     # - For 'auto' mode: compare across all backends, find global optimum
@@ -966,7 +973,7 @@ def _(
     alpha: torch.Tensor,
     output_dtype: torch.dtype,
     to_userbuffers: bool = False,
-    backend: str = "auto",
+    allowed_backends: str = "cutlass,cublaslt,cuda_core",
 ) -> torch.Tensor:
     """Fake implementation for torch.compile support."""
     return act_fp4.new_empty((act_fp4.size(0), weight.size(0)),
```
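With the string-valued parameter, opting back in to CuteDSL is a matter of extending the default list at the call site. A call sketch (the operand tensors are assumed to have been prepared with TRT-LLM's NVFP4 quantization utilities; their layouts are not reproduced here):

```python
import torch

# act_fp4, weight, act_sf, weight_scale, alpha: assumed prepared elsewhere
# in TRT-LLM's packed-FP4/scale-factor layouts.
out = torch.ops.trtllm.nvfp4_gemm(
    act_fp4, weight, act_sf, weight_scale, alpha,
    torch.bfloat16,                         # output_dtype
    False,                                  # to_userbuffers
    "cutlass,cublaslt,cutedsl,cuda_core",   # opt in to CuteDSL explicitly
)
```

Leaving `allowed_backends` at its default keeps the pre-existing fast-build behavior; the AutoTuner then only ever profiles CUTLASS, cuBLASLt, and CUDA Core.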

tensorrt_llm/_torch/model_config.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -102,6 +102,11 @@ class ModelConfig(Generic[TConfig]):
     # If true, use low precision combine in MoE operations (only for NVFP4 quantization)
     use_low_precision_moe_combine: bool = False
 
+    # NVFP4 GEMM backend configuration - list of backends to consider for auto-selection
+    # Default excludes 'cutedsl' for faster build time. Add 'cutedsl' for extreme perf.
+    nvfp4_gemm_allowed_backends: List[str] = field(
+        default_factory=lambda: ['cutlass', 'cublaslt', 'cuda_core'])
+
     allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO
 
     # If true, enable min-latency mode. Currently only used for Llama4.
```
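The same opt-in at the config level, as a sketch (assuming the remaining `ModelConfig` fields can be left at their defaults; how this field is forwarded to the GEMM op depends on the calling code):

```python
from tensorrt_llm._torch.model_config import ModelConfig

# Opt in to CuteDSL on top of the new default backend set.
config = ModelConfig(
    nvfp4_gemm_allowed_backends=['cutlass', 'cublaslt', 'cutedsl', 'cuda_core'])
```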
