@@ -268,8 +268,7 @@ def get_valid_tactics(
268268 ** kwargs ,
269269 ) -> List [Tuple [int , int ]]:
270270 # Early exit: Check SM version - CuteDSL NVFP4 only supports SM 100 and SM 103
271- sm_version = get_sm_version ()
272- if sm_version not in [100 , 103 ]:
271+ if (sm_version := get_sm_version ()) not in (100 , 103 ):
273272 logger .debug (
274273 f"CuteDSL: SM version { sm_version } is not supported. "
275274 f"CuteDSL NVFP4 only supports SM 100 (B200) and SM 103 (B300). Skipping all tactics."
@@ -597,8 +596,7 @@ def cute_dsl_nvfp4_gemm_blackwell(
597596 for automatic backend selection with better performance.
598597 """
599598 # Validate SM version before attempting to use CuteDSL
600- sm_version = get_sm_version ()
601- if sm_version not in [100 , 103 ]:
599+ if (sm_version := get_sm_version ()) not in (100 , 103 ):
602600 raise ValueError (
603601 f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM { sm_version } . "
604602 f"Please use nvfp4_gemm with backend='auto' for automatic backend selection."
@@ -660,9 +658,9 @@ def __init__(self,
660658 self .output_dtype = output_dtype
661659 self .scaling_vector_size = scaling_vector_size
662660
663- if get_sm_version () != 100 :
661+ if ( sm_version := get_sm_version ()) not in ( 100 , 103 ) :
664662 raise ValueError (
665- f"SM version { get_sm_version () } is not supported for { self . __class__ . __name__ } , it only supports SM 100 "
663+ f"{ self . __class__ . kernel_class . __name__ } supports SM 100 (B200) and SM 103 (B300) only, but got SM { sm_version } "
666664 )
667665
668666 def unique_id (self ):
@@ -947,9 +945,9 @@ def __init__(self,
947945 self .output_dtype = output_dtype
948946 self .scaling_vector_size = scaling_vector_size
949947
950- if get_sm_version () != 100 :
948+ if ( sm_version := get_sm_version ()) not in ( 100 , 103 ) :
951949 raise ValueError (
952- f"SM version { get_sm_version () } is not supported for { self . __class__ . __name__ } , it only supports SM 100 "
950+ f"{ self . __class__ . kernel_class . __name__ } supports SM 100 (B200) and SM 103 (B300) only, but got SM { sm_version } "
953951 )
954952
955953 def unique_id (self ):
@@ -1326,9 +1324,9 @@ def __init__(self,
13261324 self .tile_size = tile_size
13271325 self .scaling_vector_size = scaling_vector_size
13281326
1329- if get_sm_version () != 100 :
1327+ if ( sm_version := get_sm_version ()) not in ( 100 , 103 ) :
13301328 raise ValueError (
1331- f"SM version { get_sm_version () } is not supported for { self . __class__ . __name__ } , it only supports SM 100 "
1329+ f"{ self . __class__ . kernel_class . __name__ } supports SM 100 (B200) and SM 103 (B300) only, but got SM { sm_version } "
13321330 )
13331331
13341332 def unique_id (self ):
0 commit comments