Commit cfc6143

refactor logical check for gemm+allreduce fusion
Signed-off-by: benzh-2025 <[email protected]>
1 parent: da85f02

File tree: 3 files changed (+35, -32 lines)

tensorrt_llm/_torch/models/modeling_llama.py (21 additions, 16 deletions)

@@ -672,13 +672,17 @@ def __init__(
         # Disable fusion for small models due to accuracy issues
         self.enable_fusion &= config.hidden_size > 4096

-        use_fused_gemm_allreduce = True
-        use_fused_gemm_allreduce &= (not mpi_disabled())
-        use_fused_gemm_allreduce &= (self.mapping.tp_size > 1)
-        use_fused_gemm_allreduce &= (config.torch_dtype
-                                     in (torch.float16, torch.bfloat16))
-        use_fused_gemm_allreduce &= (self.is_nvfp4 is not None
-                                     and self.is_nvfp4)
+        mpi_enabled = not mpi_disabled()
+        dtype_supported = config.torch_dtype in (torch.float16, torch.bfloat16)
+        tp_valid = self.mapping.tp_size > 1
+        quant_valid = self.is_nvfp4 is not None and self.is_nvfp4
+        use_fused_gemm_allreduce = all(
+            [mpi_enabled, dtype_supported, tp_valid, quant_valid])
+
+        def check_in_out_features(in_features, out_features):
+            in_feature_valid = in_features % 128 == 0 and in_features >= 1024
+            out_feature_valid = out_features % 64 == 0 and out_features >= 1024
+            return all([in_feature_valid, out_feature_valid])

         num_heads = config.num_attention_heads
         head_dim = getattr(config, 'head_dim', None)
@@ -687,21 +691,22 @@ def __init__(

         in_features = num_heads * head_dim
         out_features = config.hidden_size
-        in_features_div_by = 128
-        attn_fused_gemm_allreduce = use_fused_gemm_allreduce and in_features % in_features_div_by == 0 and in_features >= 1024
-        attn_fused_gemm_allreduce &= (out_features % 64 == 0
-                                      and out_features >= 1024)
+        in_out_features_valid = check_in_out_features(in_features, out_features)

+        attn_fused_gemm_allreduce = all(
+            [use_fused_gemm_allreduce, in_out_features_valid])
         self.PRE_MLP_FUSION = not attn_fused_gemm_allreduce and self.mapping.has_tp(
         ) and not self.enable_attention_dp and self.enable_fusion

         in_features = config.intermediate_size
         out_features = config.hidden_size
-        in_features_div_by = 128 * self.mapping.tp_size
-        mlp_fused_gemm_allreduce = use_fused_gemm_allreduce and in_features % in_features_div_by == 0 and in_features >= 1024
-        mlp_fused_gemm_allreduce &= (out_features % 64 == 0
-                                     and out_features >= 1024)
-
+        in_features_aligned_with_tp = in_features % self.mapping.tp_size == 0
+        in_out_features_valid = check_in_out_features(
+            in_features // self.mapping.tp_size, out_features)
+        mlp_fused_gemm_allreduce = all([
+            use_fused_gemm_allreduce, in_features_aligned_with_tp,
+            in_out_features_valid
+        ])
         self.POST_MLP_FUSION = not mlp_fused_gemm_allreduce and self.mapping.has_tp(
         ) and self.enable_fusion

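To read the refactored shape gate outside the diff, here is a minimal standalone sketch. The `check_in_out_features` helper and its thresholds are copied from the hunk above; the example shapes are hand-picked Llama-70B-style values standing in for the model's config/mapping objects, so the numbers are illustrative only:

    # Standalone sketch of the fusion shape gate; thresholds match the diff,
    # the shapes below are illustrative stand-ins for config/mapping values.
    def check_in_out_features(in_features, out_features):
        # K must be a multiple of 128, N a multiple of 64, and both at least 1024.
        in_feature_valid = in_features % 128 == 0 and in_features >= 1024
        out_feature_valid = out_features % 64 == 0 and out_features >= 1024
        return all([in_feature_valid, out_feature_valid])

    # Attention o_proj: in_features = num_heads * head_dim, out_features = hidden_size
    num_heads, head_dim, hidden_size, tp_size = 64, 128, 8192, 4
    attn_ok = check_in_out_features(num_heads * head_dim, hidden_size)            # True

    # MLP down_proj: the per-rank K is intermediate_size // tp_size, and the
    # intermediate size must first divide evenly across the TP group.
    intermediate_size = 28672
    mlp_ok = (intermediate_size % tp_size == 0 and
              check_in_out_features(intermediate_size // tp_size, hidden_size))   # True
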
tensorrt_llm/_torch/modules/linear.py (11 additions, 13 deletions)

@@ -2156,19 +2156,17 @@ def __init__(
         self.use_custom_cublas_mm = use_custom_cublas_mm
         self.lora = lora

-        use_fused_gemm_allreduce = True
-        use_fused_gemm_allreduce &= (not mpi_disabled())
-        use_fused_gemm_allreduce &= self.dtype in (torch.float16,
-                                                   torch.bfloat16)
-        use_fused_gemm_allreduce &= (self.in_features % 128 == 0)
-        use_fused_gemm_allreduce &= (self.tp_mode is not None
-                                     and self.tp_mode == TensorParallelMode.ROW)
-        use_fused_gemm_allreduce &= (self.tp_size > 1 and self.reduce_output)
-        use_fused_gemm_allreduce &= (self.out_features % 64 == 0)
-        use_fused_gemm_allreduce &= (
-            self.quant_config is not None
-            and self.quant_config.layer_quant_mode.has_nvfp4())
-        self.use_fused_gemm_allreduce = use_fused_gemm_allreduce
+        mpi_enabled = not mpi_disabled()
+        dtype_supported = self.dtype in (torch.float16, torch.bfloat16)
+        in_features_aligned = self.in_features % 128 == 0
+        out_features_aligned = self.out_features % 64 == 0
+        tp_valid = self.tp_mode is not None and self.tp_mode == TensorParallelMode.ROW and self.tp_size > 1
+        quant_valid = self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
+        )
+        self.use_fused_gemm_allreduce = all([
+            self.reduce_output, mpi_enabled, dtype_supported,
+            in_features_aligned, out_features_aligned, tp_valid, quant_valid
+        ])

         self.enable_cuda_core = False
         if torch.cuda.is_available():

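The linear.py change collapses the chained `&=` updates into one `all([...])` over named predicates. A hypothetical standalone helper (not part of the module; the real check reads `self.*` attributes inline in the constructor above) that mirrors the same decision:

    import torch

    # Hypothetical mirror of the inline check; "row" stands in for
    # TensorParallelMode.ROW and has_nvfp4 for the quant-config query.
    def is_fusion_eligible(dtype, in_features, out_features, tp_mode, tp_size,
                           reduce_output, has_nvfp4, mpi_enabled=True):
        dtype_supported = dtype in (torch.float16, torch.bfloat16)
        in_features_aligned = in_features % 128 == 0
        out_features_aligned = out_features % 64 == 0
        tp_valid = tp_mode == "row" and tp_size > 1
        return all([reduce_output, mpi_enabled, dtype_supported,
                    in_features_aligned, out_features_aligned, tp_valid,
                    has_nvfp4])

    # Eligible: NVFP4 row-parallel projection with aligned shapes.
    print(is_fusion_eligible(torch.bfloat16, 1024, 8192, "row", 2, True, True))  # True
    # Not eligible: in_features is not a multiple of 128.
    print(is_fusion_eligible(torch.bfloat16, 1000, 8192, "row", 2, True, True))  # False
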
tests/unittest/_torch/multi_gpu/test_linear.py (3 additions, 3 deletions)

@@ -420,16 +420,16 @@ def fp4_row_linear_allreduce_run_single_rank(func, tp_size, seq_len,
         func(tp_size, local_rank, seq_len, output_size, hidden_size, dtype,
              output_ref, x_sf_global, w_sf_global, x_fp4s, w_fp4, x_sf_blocks,
              w_sf_block_unswizzled)
-    except Exception:
-        traceback.print_exc()
+    except Exception as e:
+        print(f"Error: {e}")
         raise
     return True


 @skip_pre_blackwell
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason='needs 2 GPUs to run this test')
-@pytest.mark.parametrize("seq_len", [256], ids=lambda x: f"seqlen:{x}")
+@pytest.mark.parametrize("seq_len", [256, 400], ids=lambda x: f"seqlen:{x}")
 @pytest.mark.parametrize("output_size", [32, 64], ids=lambda x: f"output:{x}")
 @pytest.mark.parametrize("hidden_size", [128, 256], ids=lambda x: f"hidden:{x}")
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16],

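The test tweak swaps `traceback.print_exc()` for a one-line message but keeps the bare `raise`, so the exception, with its original traceback, still propagates out of the per-rank worker. A tiny sketch, unrelated to the repo, of why that works:

    # Minimal illustration only: a bare `raise` inside `except` re-raises the
    # same exception object, so the caller still sees the original traceback.
    def worker():
        try:
            raise ValueError("bad scale factor")  # stand-in for the rank body
        except Exception as e:
            print(f"Error: {e}")                  # short, rank-local message
            raise                                 # propagate unchanged

    try:
        worker()
    except ValueError as err:
        assert str(err) == "bad scale factor"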