@@ -23,11 +23,12 @@
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
+from torch.utils.cpp_extension import _get_cuda_arch_flags
 
 from ..jit import JitSpec
 from ..jit import env as jit_env
 from ..jit import gen_jit_spec, sm100a_nvcc_flags
-from ..utils import register_custom_op, round_up
+from ..utils import register_custom_op, round_up, version_at_least
 from .cuda_ipc import create_shared_buffer, cudart, free_shared_buffer
 
 
@@ -96,15 +97,18 @@ class FP4QuantizationSFLayout:
 
 
 def gen_trtllm_comm_module() -> JitSpec:
-    major, minor = torch.cuda.get_device_capability()
+    gencode_flags = _get_cuda_arch_flags()
+    has_sm100 = any(
+        "compute_100" in flag for flag in gencode_flags
+    ) and version_at_least(torch.version.cuda, "12.8")
     return gen_jit_spec(
         "trtllm_comm",
         [
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce.cu",
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce_fusion.cu",
             jit_env.FLASHINFER_CSRC_DIR / "trtllm_moe_allreduce_fusion.cu",
         ],
-        extra_cuda_cflags=sm100a_nvcc_flags if major >= 10 and minor >= 0 else [],
+        extra_cuda_cflags=sm100a_nvcc_flags if has_sm100 else [],
     )
 
 
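For context, the new gating keys off the architectures PyTorch is generating code for, rather than the compute capability of the locally visible GPU as before. Below is a minimal, self-contained sketch of the same check; the targets_sm100 helper name is made up for illustration, and the inline version comparison stands in for flashinfer's version_at_least utility:

# Illustrative sketch only -- mirrors the gating logic added in this diff.
# _get_cuda_arch_flags() is a private PyTorch helper that returns flags such as
# "-gencode=arch=compute_100,code=sm_100", derived from TORCH_CUDA_ARCH_LIST
# or the GPUs visible at build time.
from typing import Optional

import torch
from torch.utils.cpp_extension import _get_cuda_arch_flags


def targets_sm100(cuda_version: Optional[str]) -> bool:
    """True when compute capability 10.0 is among the compile targets and the
    CUDA version reported by torch is at least 12.8 (the threshold the diff uses)."""
    if cuda_version is None:  # CPU-only torch build
        return False
    has_compute_100 = any("compute_100" in flag for flag in _get_cuda_arch_flags())
    major, minor = (int(part) for part in cuda_version.split(".")[:2])
    return has_compute_100 and (major, minor) >= (12, 8)


# Usage: pass the sm100a-specific nvcc flags only when the build can use them.
# extra_cuda_cflags = sm100a_nvcc_flags if targets_sm100(torch.version.cuda) else []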