@@ -695,28 +695,36 @@ def _(
 class NVFP4GemmUnifiedRunner(TunableRunner):
     runner_dict = dict()
 
-    def __init__(self,
-                 to_userbuffers: bool,
-                 output_dtype: torch.dtype,
-                 backend: str = "auto"):
+    def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype,
+                 allowed_backends: List[str]):
         super().__init__()
         self.to_userbuffers = to_userbuffers
         self.output_dtype = output_dtype
-        self.backend = backend
+        self.allowed_backends = allowed_backends
 
     def unique_id(self):
-        """Include backend in cache key to avoid sharing cache across backends."""
-        return (self.to_userbuffers, self.output_dtype, self.backend)
+        """Include allowed_backends in cache key to avoid sharing cache across different backend configs."""
+        # Convert list to tuple for hashability
+        allowed_tuple = tuple(self.allowed_backends)
+        return (self.to_userbuffers, self.output_dtype, allowed_tuple)
+
+    def _is_backend_allowed(self, backend_name: str) -> bool:
+        """Check if a backend is allowed based on the allowed_backends list."""
+        return backend_name in self.allowed_backends
+
+    def _is_only_backend(self, backend_name: str) -> bool:
+        """Check if this is the only backend in allowed_backends (explicitly forced)."""
+        return self.allowed_backends == [backend_name]
 
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile,
                           **kwargs) -> List[Tuple]:
-        # return valid nvfp4 gemm implementations
+        # return valid nvfp4 gemm implementations from allowed_backends
         tactics = []
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
-        backend = self.backend
 
-        if backend in ["auto", "cuda_core"]:
+        # Add CUDA Core backend if available
+        if self._is_backend_allowed("cuda_core"):
             is_cuda_core_supported = False
             m = act_fp4.shape[0]
             sm_version = None
@@ -732,40 +740,39 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
 
             if is_cuda_core_supported:
                 tactics.append("cuda_core")
-            elif backend == "cuda_core":
-                # Explicitly requested but conditions not met - raise error
+            elif self._is_only_backend("cuda_core"):
+                # Explicitly forced but conditions not met - raise error
                 error_msg = f"CUDA Core backend requires SM >= {CudaCoreNVFP4Runner.MIN_SM_VERSION} and M <= {CudaCoreNVFP4Runner.MAX_M_DIMENSION}. "
                 error_msg += f"Current: SM={sm_version if sm_version else 'N/A'}, M={m}. "
-                error_msg += "Please use backend='auto' or another backend."
+                error_msg += "Please add other backends to allowed_backends."
                 raise ValueError(error_msg)
 
         # Add CUTLASS runner (always available)
-        if backend in ["auto", "cutlass"]:
+        if self._is_backend_allowed("cutlass"):
             tactics.append("cutlass")
 
         # Add cuBLASLt runner if available
-        if backend in ["auto", "cublaslt"]:
+        if self._is_backend_allowed("cublaslt"):
             if IS_CUBLASLT_AVAILABLE:
                 tactics.append("cublaslt")
-            elif backend == "cublaslt":
+            elif self._is_only_backend("cublaslt"):
                 raise ValueError(
                     "cuBLASLt backend is not available. "
-                    "Please check cuBLASLt installation or use backend='auto'.")
+                    "Please check cuBLASLt installation or add other backends to allowed_backends."
+                )
 
         # Add CuteDSL runner if available
-        if backend in ["auto", "cutedsl"]:
+        if self._is_backend_allowed("cutedsl"):
             if IS_CUTLASS_DSL_AVAILABLE:
                 # Check SM version first - CuteDSL NVFP4 only supports SM 100 (B200) and SM 103 (B300)
                 sm_version = get_sm_version()
                 if sm_version not in [100, 103]:
-                    if backend == "cutedsl":
-                        # Explicitly requested CuteDSL but SM version not supported
+                    if self._is_only_backend("cutedsl"):
+                        # Explicitly forced CuteDSL but SM version not supported
                         raise ValueError(
                             f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM {sm_version}. "
                             f"CuteDSL NVFP4 is not supported on this GPU architecture. "
-                            f"Please use backend='auto' to automatically select a compatible backend."
-                        )
-                    # else: backend='auto' → silently skip CuteDSL
+                            "Please add other backends to allowed_backends.")
                 else:
                     # SM version OK, check if CuteDSL supports the current shape
                     from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \
@@ -778,8 +785,8 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
                     if cutedsl_tactics:
                         # CuteDSL supports this shape
                         tactics.append("cutedsl")
-                    elif backend == "cutedsl":
-                        # Explicitly requested CuteDSL but it doesn't support this shape
+                    elif self._is_only_backend("cutedsl"):
+                        # Explicitly forced CuteDSL but it doesn't support this shape
                         m, n, k = inputs[0].shape[0], inputs[1].shape[
                             0], inputs[0].shape[1] * 2
                         raise ValueError(
@@ -788,13 +795,12 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
788795 f"CuteDSL requires 16-byte alignment for major (contiguous) dimensions:\n "
789796 f" - K must be divisible by 32 (FP4 K-major layout): K%32={ '0✓' if k % 32 == 0 else str (k % 32 )+ '✗' } \n "
790797 f" - Or the combination of (M, N, K, tiling, cluster shape) is not supported\n "
791- f"Please use backend='auto' to automatically select a compatible backend."
792- )
793- # else: backend='auto' and CuteDSL doesn't support shape → silently skip
794- elif backend == "cutedsl" :
798+ f"Please add other backends to allowed_backends." )
799+ elif self ._is_only_backend ("cutedsl" ):
795800 raise ValueError (
796801 "CuteDSL backend is not available. "
797- "Please check CuteDSL installation or use backend='auto'." )
802+ "Please check CuteDSL installation or add other backends to allowed_backends."
803+ )
798804
799805 return tactics
800806
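The gating pattern above follows a single rule: a backend that fails its availability check is silently skipped as long as other entries remain in allowed_backends, and raises only when it is the sole allowed entry. A minimal self-contained sketch of that rule (illustrative only, not TensorRT-LLM code; select_tactics and available are hypothetical names):

    from typing import List, Set

    def select_tactics(allowed_backends: List[str],
                       available: Set[str]) -> List[str]:
        tactics = []
        for name in allowed_backends:
            if name in available:
                tactics.append(name)
            elif allowed_backends == [name]:
                # Mirrors _is_only_backend(): an explicitly forced backend fails hard
                raise ValueError(f"{name} was explicitly forced but is unavailable")
        return tactics

    # An unavailable cuBLASLt is skipped because cutlass remains as an alternative
    assert select_tactics(["cutlass", "cublaslt"], {"cutlass"}) == ["cutlass"]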
@@ -807,31 +813,23 @@ def forward(
     ) -> torch.Tensor:
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
 
-        requested_backend = self.backend
-
-        # If a specific backend was requested (not 'auto') and we're using fallback tactic
-        # This can happen on cache miss, where AutoTuner uses tactic=-1 as default
-        if requested_backend != 'auto' and requested_backend != tactic and tactic == -1:
-            # User explicitly requested a backend, but we're falling back to default
-            # This might happen on cache miss. We should validate the requested backend supports this shape.
-
-            # Get valid tactics for the requested backend
+        # Handle fallback tactic (-1) on cache miss
+        if tactic == -1:
+            # Get valid tactics and use first available
             from tensorrt_llm._torch.autotuner import OptimizationProfile
             valid_tactics = self.get_valid_tactics(inputs,
                                                    OptimizationProfile())
-
-            if not valid_tactics or requested_backend not in valid_tactics:
-                # Requested backend doesn't support this shape
+            if valid_tactics:
+                # Prefer cutlass as fallback if available, otherwise use first valid tactic
+                tactic = "cutlass" if "cutlass" in valid_tactics else valid_tactics[
+                    0]
+            else:
                 m, n, k = inputs[0].shape[0], inputs[1].shape[
                     0], inputs[0].shape[1] * 2
                 raise ValueError(
-                    f"Backend '{requested_backend}' was explicitly requested but does not support the current shape:\n"
+                    f"No valid backends available for the current shape:\n"
                     f"  M={m}, N={n}, K={k}\n"
-                    f"Please use backend='auto' to automatically select a compatible backend."
-                )
-
-            # Backend supports it, use the requested backend instead of fallback
-            tactic = requested_backend
+                    f"  Allowed backends: {self.allowed_backends}")
 
         if tactic == "cuda_core":
             # Unswizzle the activation scale factors
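The fallback handling above reduces to a small pure function; a self-contained sketch of the intended behavior (hypothetical resolve_fallback, not TensorRT-LLM code):

    def resolve_fallback(tactic, valid_tactics):
        # AutoTuner passes tactic=-1 on a cache miss; prefer cutlass, fall back
        # to the first valid tactic, and fail only when nothing is valid
        if tactic == -1:
            if not valid_tactics:
                raise ValueError("No valid backends available")
            tactic = "cutlass" if "cutlass" in valid_tactics else valid_tactics[0]
        return tactic

    assert resolve_fallback(-1, ["cublaslt", "cutlass"]) == "cutlass"
    assert resolve_fallback(-1, ["cuda_core"]) == "cuda_core"
    assert resolve_fallback("cublaslt", ["cutlass"]) == "cublaslt"  # cached tactic passes through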
@@ -882,20 +880,19 @@ def nvfp4_gemm(
     alpha: torch.Tensor,
     output_dtype: torch.dtype,
     to_userbuffers: bool = False,
-    backend: str = "auto",
+    allowed_backends: str = "cutlass,cublaslt,cuda_core",
 ) -> torch.Tensor:
-    """Unified NVFP4 GEMM with automatic or manual backend selection.
+    """Unified NVFP4 GEMM with automatic backend selection.
 
-    This function can automatically choose the best backend or force a specific backend:
+    This function automatically chooses the best backend from the allowed list:
     - CUTLASS: Predefined CUTLASS configurations with auto-tuning
     - cuBLASLt: Heuristic-based algorithms from cuBLASLt library
     - CuteDSL: Blackwell-optimized persistent kernels (when available and inputs are valid)
     - CUDA Core: CUDA Core implementation (requires SM >= 100 and M <= 8)
 
     The AutoTuner profiles all available backends during the first run and caches
     the best choice for each input shape. Subsequent calls use the cached selection
-    with zero overhead. In 'auto' mode, backends are only considered if their
-    requirements are met (e.g., CUDA Core only participates when SM >= 100 and M <= 8).
+    with zero overhead. Backends are only considered if their requirements are
+    met (e.g., CUDA Core only participates when SM >= 100 and M <= 8).
 
     Args:
         act_fp4: Activation tensor [m, k] in FP4 format (packed in uint8)
@@ -905,12 +902,10 @@ def nvfp4_gemm(
         alpha: Scaling factor (as torch.Tensor for CUTLASS/cuBLASLt compatibility)
         output_dtype: Output data type
         to_userbuffers: Whether to use user buffers (CUTLASS/cuBLASLt only)
-        backend: Backend selection, one of:
-            - 'auto': AutoTuner automatically selects best backend (default)
-            - 'cutlass': Force use CUTLASS (FP4GemmRunner)
-            - 'cublaslt': Force use cuBLASLt (CublasLtFP4GemmRunner)
-            - 'cutedsl': Force use CuteDSL (CuteDSLNVFP4Wrapper)
-            - 'cuda_core': Force use CUDA Core (CudaCoreNVFP4Runner, requires SM >= 100, M <= 8)
+        allowed_backends: Comma-separated list of backends to consider for auto-selection.
+            Default: "cutlass,cublaslt,cuda_core" (excludes cutedsl for faster builds)
+            Add 'cutedsl' for maximum performance at the cost of longer build time.
+            Valid backends: 'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'.
 
     Returns:
         Output tensor [m, n] with dtype=output_dtype
@@ -919,14 +914,26 @@ def nvfp4_gemm(
919914 ValueError: If backend is invalid/unavailable
920915 """
921916
922- # Validate backend parameter
923- valid_backends = ['auto' , 'cutlass' , 'cublaslt' , 'cutedsl' , 'cuda_core' ]
924- if backend not in valid_backends :
917+ valid_individual_backends = {'cutlass' , 'cublaslt' , 'cutedsl' , 'cuda_core' }
918+
919+ # Parse comma-separated string to list
920+ backends_list = [
921+ b .strip () for b in allowed_backends .split (',' ) if b .strip ()
922+ ]
923+
924+ # Validate allowed_backends
925+ invalid_backends = set (backends_list ) - valid_individual_backends
926+ if invalid_backends :
927+ raise ValueError (
928+ f"Invalid backends in allowed_backends: { invalid_backends } . "
929+ f"Valid backends are: { sorted (valid_individual_backends )} ." )
930+ if not backends_list :
925931 raise ValueError (
926- f"Invalid backend '{ backend } '. Must be one of { valid_backends } " )
932+ f"allowed_backends cannot be empty. "
933+ f"Valid backends are: { sorted (valid_individual_backends )} ." )
927934
928- # Build list of runners based on backend parameter
929- runner = NVFP4GemmUnifiedRunner (to_userbuffers , output_dtype , backend )
935+ # Build runner with allowed backends
936+ runner = NVFP4GemmUnifiedRunner (to_userbuffers , output_dtype , backends_list )
930937
931938 # Use AutoTuner to select best runner and tactic
932939 # - For 'auto' mode: compare across all backends, find global optimum
@@ -966,7 +973,7 @@ def _(
     alpha: torch.Tensor,
     output_dtype: torch.dtype,
     to_userbuffers: bool = False,
-    backend: str = "auto",
+    allowed_backends: str = "cutlass,cublaslt,cuda_core",
 ) -> torch.Tensor:
     """Fake implementation for torch.compile support."""
     return act_fp4.new_empty((act_fp4.size(0), weight.size(0)),
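For reference, the allowed_backends handling introduced in this diff can be exercised standalone. Below is a sketch mirroring the parsing and validation above; the commented call at the end is a hypothetical call site, with the tensor arguments assumed to come from the usual NVFP4 quantization path:

    VALID_BACKENDS = {'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'}

    def parse_allowed_backends(allowed_backends: str) -> list:
        # Split on commas, trim whitespace, and drop empty entries
        backends = [b.strip() for b in allowed_backends.split(',') if b.strip()]
        invalid = set(backends) - VALID_BACKENDS
        if invalid:
            raise ValueError(f"Invalid backends in allowed_backends: {invalid}")
        if not backends:
            raise ValueError("allowed_backends cannot be empty")
        return backends

    assert parse_allowed_backends(" cutlass, cublaslt ,") == ["cutlass", "cublaslt"]

    # Hypothetical call site: opt in to CuteDSL on Blackwell for peak performance.
    # out = nvfp4_gemm(act_fp4, weight, act_sf, weight_scale, alpha,
    #                  output_dtype=torch.bfloat16,
    #                  allowed_backends="cutlass,cublaslt,cuda_core,cutedsl")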