
Commit ece1689

dtrifiro and yzh119 authored
bugfix: gen_trtllm_comm_module: fix device capability detection (#1356)
Fixes #1256

Co-authored-by: Zihao Ye <[email protected]>
1 parent e458896 · commit ece1689

File tree

1 file changed (+7, −3 lines)


flashinfer/comm/trtllm_ar.py

Lines changed: 7 additions & 3 deletions
@@ -23,11 +23,12 @@
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
+from torch.utils.cpp_extension import _get_cuda_arch_flags

 from ..jit import JitSpec
 from ..jit import env as jit_env
 from ..jit import gen_jit_spec, sm100a_nvcc_flags
-from ..utils import register_custom_op, round_up
+from ..utils import register_custom_op, round_up, version_at_least
 from .cuda_ipc import create_shared_buffer, cudart, free_shared_buffer


@@ -96,15 +97,18 @@ class FP4QuantizationSFLayout:


 def gen_trtllm_comm_module() -> JitSpec:
-    major, minor = torch.cuda.get_device_capability()
+    gencode_flags = _get_cuda_arch_flags()
+    has_sm100 = any(
+        "compute_100" in flag for flag in gencode_flags
+    ) and version_at_least(torch.version.cuda, "12.8")
     return gen_jit_spec(
         "trtllm_comm",
         [
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce.cu",
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce_fusion.cu",
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_moe_allreduce_fusion.cu",
         ],
-        extra_cuda_cflags=sm100a_nvcc_flags if major >= 10 and minor >= 0 else [],
+        extra_cuda_cflags=sm100a_nvcc_flags if has_sm100 else [],
     )
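
Why the change: the old guard major >= 10 and minor >= 0 is equivalent to major >= 10 and relies on torch.cuda.get_device_capability(), i.e. on the GPU that happens to be visible at JIT-compile time. The new code instead inspects the -gencode flags PyTorch will compile for, plus the CUDA toolkit version. Below is a minimal sketch of that detection logic, assuming PyTorch's private helper torch.utils.cpp_extension._get_cuda_arch_flags and an inline version comparison in place of flashinfer's version_at_least; the helper name is_sm100_build is hypothetical and only for illustration.

# Sketch only: mirrors the detection in the diff above; is_sm100_build is a hypothetical name.
import torch
from torch.utils.cpp_extension import _get_cuda_arch_flags  # private PyTorch helper


def is_sm100_build() -> bool:
    # _get_cuda_arch_flags() returns -gencode flags derived from
    # TORCH_CUDA_ARCH_LIST (or from the GPUs PyTorch detects), e.g.
    # ['-gencode=arch=compute_100,code=sm_100'], so the target arch can
    # come from the environment rather than from an attached GPU.
    gencode_flags = _get_cuda_arch_flags()
    targets_sm100 = any("compute_100" in flag for flag in gencode_flags)

    # The sm100a kernels are only enabled with CUDA 12.8 or newer;
    # torch.version.cuda is a string like "12.8", or None for CPU-only builds.
    cuda = torch.version.cuda
    cuda_ok = cuda is not None and tuple(int(p) for p in cuda.split(".")[:2]) >= (12, 8)

    return targets_sm100 and cuda_ok

With this check, gen_trtllm_comm_module() passes sm100a_nvcc_flags to gen_jit_spec only when an SM100 target and a new enough CUDA toolkit are both present, and falls back to an empty flag list otherwise.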