Skip to content

Commit c32e54a

Browse files
committed
enable sm103
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
1 parent 574aa70 commit c32e54a

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,7 @@ def get_valid_tactics(
268268
**kwargs,
269269
) -> List[Tuple[int, int]]:
270270
# Early exit: Check SM version - CuteDSL NVFP4 only supports SM 100 and SM 103
271-
sm_version = get_sm_version()
272-
if sm_version not in [100, 103]:
271+
if (sm_version := get_sm_version()) not in (100, 103):
273272
logger.debug(
274273
f"CuteDSL: SM version {sm_version} is not supported. "
275274
f"CuteDSL NVFP4 only supports SM 100 (B200) and SM 103 (B300). Skipping all tactics."
@@ -597,8 +596,7 @@ def cute_dsl_nvfp4_gemm_blackwell(
597596
for automatic backend selection with better performance.
598597
"""
599598
# Validate SM version before attempting to use CuteDSL
600-
sm_version = get_sm_version()
601-
if sm_version not in [100, 103]:
599+
if (sm_version := get_sm_version()) not in (100, 103):
602600
raise ValueError(
603601
f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM {sm_version}. "
604602
f"Please use nvfp4_gemm with backend='auto' for automatic backend selection."
@@ -660,9 +658,9 @@ def __init__(self,
660658
self.output_dtype = output_dtype
661659
self.scaling_vector_size = scaling_vector_size
662660

663-
if get_sm_version() != 100:
661+
if (sm_version := get_sm_version()) not in (100, 103):
664662
raise ValueError(
665-
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
663+
f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}"
666664
)
667665

668666
def unique_id(self):
@@ -947,9 +945,9 @@ def __init__(self,
947945
self.output_dtype = output_dtype
948946
self.scaling_vector_size = scaling_vector_size
949947

950-
if get_sm_version() != 100:
948+
if (sm_version := get_sm_version()) not in (100, 103):
951949
raise ValueError(
952-
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
950+
f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}"
953951
)
954952

955953
def unique_id(self):
@@ -1326,9 +1324,9 @@ def __init__(self,
13261324
self.tile_size = tile_size
13271325
self.scaling_vector_size = scaling_vector_size
13281326

1329-
if get_sm_version() != 100:
1327+
if (sm_version := get_sm_version()) not in (100, 103):
13301328
raise ValueError(
1331-
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
1329+
f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}"
13321330
)
13331331

13341332
def unique_id(self):

tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,9 +1368,10 @@ def test_fused_moe_nvfp4(dtype, moe_backend):
13681368
if dtype == torch.float16:
13691369
pytest.skip(
13701370
"CUTEDSL NVFP4 MoE backend does not support float16 yet")
1371-
if get_sm_version() != 100:
1371+
if get_sm_version() not in (100, 103):
13721372
pytest.skip(
1373-
"CUTEDSL NVFP4 MoE backend is only supported on SM 100 GPUs")
1373+
"CUTEDSL NVFP4 MoE backend supports SM 100 (B200) and SM 103 (B300) only"
1374+
)
13741375

13751376
test_all_kernels = True
13761377
if get_sm_version() == 120:

0 commit comments

Comments (0)