Commit 8186180 (parent a070b9d)

support gemm+allreduce only on arch >= blackwell

Signed-off-by: benzh-2025 <[email protected]>

2 files changed: 24 additions, 3 deletions

tensorrt_llm/_torch/models/modeling_llama.py (13 additions, 2 deletions)

@@ -676,8 +676,19 @@ def __init__(
         dtype_supported = config.torch_dtype in (torch.float16, torch.bfloat16)
         tp_valid = self.mapping.tp_size > 1
         quant_valid = self.is_nvfp4 is not None and self.is_nvfp4
-        use_fused_gemm_allreduce = all(
-            [mpi_enabled, dtype_supported, tp_valid, quant_valid])
+
+        device_supported = False
+        if torch.cuda.is_available():
+            capability = torch.cuda.get_device_capability(
+                torch.device('cuda:0'))
+            sm_version = capability[0] * 10 + capability[1]
+            if sm_version >= 100:
+                device_supported = True
+
+        use_fused_gemm_allreduce = all([
+            mpi_enabled, dtype_supported, tp_valid, quant_valid,
+            device_supported
+        ])
 
         def check_in_out_features(in_features, out_features):
             in_feature_valid = in_features % 128 == 0 and in_features >= 1024

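Both hunks gate the fused GEMM+AllReduce path on compute capability. As a minimal standalone sketch of that check (the helper name is_blackwell_or_newer is ours, not from the commit): torch.cuda.get_device_capability returns a (major, minor) tuple, e.g. (9, 0) on Hopper (SM90) and (10, 0) on Blackwell B100/B200 (SM100), so major * 10 + minor >= 100 accepts Blackwell and newer parts.

import torch

def is_blackwell_or_newer(device_index: int = 0) -> bool:
    """True when the CUDA device is SM100 (Blackwell) or newer.

    A sketch; the helper name is hypothetical and not part of this commit.
    """
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(
        torch.device(f'cuda:{device_index}'))
    # SM version packs (major, minor) as major*10 + minor, e.g. (10, 0) -> 100.
    return major * 10 + minor >= 100
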
tensorrt_llm/_torch/modules/linear.py (11 additions, 1 deletion)

@@ -2163,9 +2163,19 @@ def __init__(
         tp_valid = self.tp_mode is not None and self.tp_mode == TensorParallelMode.ROW and self.tp_size > 1
         quant_valid = self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
         )
+
+        device_supported = False
+        if torch.cuda.is_available():
+            capability = torch.cuda.get_device_capability(
+                torch.device('cuda:0'))
+            sm_version = capability[0] * 10 + capability[1]
+            if sm_version >= 100:
+                device_supported = True
+
         self.use_fused_gemm_allreduce = all([
             self.reduce_output, mpi_enabled, dtype_supported,
-            in_features_aligned, out_features_aligned, tp_valid, quant_valid
+            in_features_aligned, out_features_aligned, tp_valid, quant_valid,
+            device_supported
         ])
 
         self.enable_cuda_core = False

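To confirm whether the new device gate passes on a given machine, here is a quick runtime probe (a sketch, not part of the commit; get_device_name and get_device_capability are standard torch.cuda calls):

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(torch.device('cuda:0'))
    sm_version = major * 10 + minor
    # The fused path additionally requires MPI, fp16/bf16, NVFP4, row-parallel
    # TP > 1, and aligned feature sizes; this only probes the device gate.
    print(f'{torch.cuda.get_device_name(0)}: SM{sm_version} '
          f'(device gate passes: {sm_version >= 100})')
else:
    print('No CUDA device: fused GEMM+AllReduce is disabled.')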