
Commit 2225745

hyuknchzblych authored and committed
[TRTLLM-8129][feat] Allreduce tuning and benchmark script revising (NVIDIA#7870)
We have encountered perf regressions caused by using the one-shot kernel instead of NCCL on A100/H100, so it is beneficial to have solid benchmarking of the allreduce op and to analyze the data collected from it.

Implemented a new AllreduceOp heuristic:
- Added a linear-programming-based heuristic implementation.
- Added a LUT-based heuristic implementation and the corresponding code-generation script.

AllreduceOp minor fixes:
- Fixed a minor issue in AllreduceOp where the strategy could not be overridden when ONESHOT or TWOSHOT is set.
- Fixed a minor TWOSHOT kernel perf issue.
- Cleaned up dispatching code in AllReduceOp.

This PR fixes the perf gaps reported in https://nvbugspro.nvidia.com/bug/5517023. For DeepSeek-R1, it shows a performance gain of about 3-4% at concurrency levels of 256 and 512.

Signed-off-by: Yukun He <[email protected]>
Signed-off-by: Mike Iovine <[email protected]>
1 parent 34fbc70 commit 2225745

File tree

9 files changed: +1597 -300 lines changed


cpp/tensorrt_llm/common/customAllReduceUtils.h

Lines changed: 255 additions & 0 deletions
Large diffs are not rendered by default.
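The 255 added lines in this header (including the lookup table emitted by the new code-generation script) are not rendered here. As a rough illustration only, a LUT-based selector of this kind can be organized as a table keyed by tensor-parallel size with token-count thresholds per row. The entries, thresholds, and the name selectStrategyFromLut below are invented placeholders, not the generated table or the real selectStrategyLookUpTable:

#include <cstddef>
#include <map>
#include <vector>

// Hypothetical strategy enum mirroring the names used in this PR.
enum class Strategy { NCCL, ONESHOT, TWOSHOT };

// One LUT row: for token counts up to `max_tokens`, prefer `strategy`.
struct LutEntry
{
    std::size_t max_tokens;
    Strategy strategy;
};

// Table keyed by tensor-parallel size; rows sorted by ascending max_tokens.
// The numbers are placeholders, not the tuned values generated for the real LUT.
static const std::map<int, std::vector<LutEntry>> kExampleLut = {
    {4, {{128, Strategy::ONESHOT}, {1024, Strategy::TWOSHOT}}},
    {8, {{64, Strategy::ONESHOT}, {512, Strategy::TWOSHOT}}},
};

Strategy selectStrategyFromLut(std::size_t seq_len, int tp_size)
{
    auto it = kExampleLut.find(tp_size);
    if (it == kExampleLut.end())
    {
        return Strategy::NCCL; // Unknown configuration: fall back to NCCL.
    }
    for (auto const& entry : it->second)
    {
        if (seq_len <= entry.max_tokens)
        {
            return entry.strategy;
        }
    }
    return Strategy::NCCL; // Large messages default to NCCL.
}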

cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu

Lines changed: 5 additions & 1 deletion
@@ -134,8 +134,12 @@ public:
         // corresponding CTA has not been launched.
         for (int flag_idx = blockIdx.x; flag_idx < kBarrierFlagCount; flag_idx += gridDim.x)
         {
-            st_flag(m_target_flag + flag_idx * NRanks, m_flag_value);
+            asm volatile(
+                "st.global.relaxed.sys.b32 [%1], %0;" ::"r"(m_flag_value), "l"(m_target_flag + flag_idx * NRanks));
         }
+        // Single release fence
+        asm volatile("fence.release.sys;");
+
         while (ld_flag(m_current_flag) == prev_flag(m_flag_value))
         {
         }
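The hunk above replaces a release-ordered store per flag (st_flag) with relaxed system-scope stores followed by a single fence.release.sys for the whole batch. A generic, standalone sketch of that store pattern follows; the flag array and helper name are hypothetical, not the fusion kernel's actual state:

#include <cstdint>

// Hypothetical helper: publish `value` to every flag slot owned by this block,
// using relaxed stores and one trailing release fence instead of paying
// release ordering on every store.
__device__ void publish_flags_relaxed(uint32_t* flags, int num_flags, uint32_t value)
{
    for (int i = blockIdx.x; i < num_flags; i += gridDim.x)
    {
        asm volatile("st.global.relaxed.sys.b32 [%1], %0;" ::"r"(value), "l"(flags + i));
    }
    // One system-scope release fence for the whole batch of flag stores.
    asm volatile("fence.release.sys;");
}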

cpp/tensorrt_llm/kernels/customAllReduceKernels.h

Lines changed: 24 additions & 0 deletions
@@ -106,6 +106,30 @@ inline std::string toString(AllReduceFusionOp op)
     return oss.str();
 }

+inline std::ostream& operator<<(std::ostream& os, AllReduceStrategyType op)
+{
+    switch (op)
+    {
+    case AllReduceStrategyType::NCCL: os << "NCCL"; break;
+    case AllReduceStrategyType::MIN_LATENCY: os << "MIN_LATENCY"; break;
+    case AllReduceStrategyType::UB: os << "UB"; break;
+    case AllReduceStrategyType::AUTO: os << "AUTO"; break;
+    case AllReduceStrategyType::ONESHOT: os << "ONESHOT"; break;
+    case AllReduceStrategyType::TWOSHOT: os << "TWOSHOT"; break;
+    case AllReduceStrategyType::LOWPRECISION: os << "LOWPRECISION"; break;
+    case AllReduceStrategyType::MNNVL: os << "MNNVL"; break;
+    case AllReduceStrategyType::NCCL_SYMMETRIC: os << "NCCL_SYMMETRIC"; break;
+    }
+    return os;
+}
+
+inline std::string toString(AllReduceStrategyType op)
+{
+    std::ostringstream oss;
+    oss << op;
+    return oss.str();
+}
+
 struct AllReduceFusionParams
 {
     AllReduceFusionParams()
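A quick usage sketch of the new stream operator and toString overload; the include path and namespace are assumed from the file location above:

#include <iostream>

#include "tensorrt_llm/kernels/customAllReduceKernels.h"

int main()
{
    using tensorrt_llm::kernels::AllReduceStrategyType;

    // Both forms route through the operator<< added in this hunk.
    std::cout << AllReduceStrategyType::MIN_LATENCY << "\n";                      // prints "MIN_LATENCY"
    std::cout << tensorrt_llm::kernels::toString(AllReduceStrategyType::TWOSHOT); // prints "TWOSHOT"
    return 0;
}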

cpp/tensorrt_llm/thop/allreduceOp.cpp

Lines changed: 37 additions & 172 deletions
@@ -273,14 +273,16 @@ class AllreduceOp
     {
         size_t size = input.numel();
         size_t seq_len = input.size(0);
+        size_t hidden_size = input.size(-1);
         size_t bytes_per_element = input.element_size();
         TLLM_LOG_DEBUG("All reduce message size is %zu", size * bytes_per_element);

-        AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size);
+        AllReduceStrategyType runtime_strategy = selectImplementation(seq_len, hidden_size);

         // Log runtime strategy
         auto const rank = getRank();
-        logRunTimeStrategy(runtime_strategy, rank);
+        TLLM_LOG_DEBUG(
+            "AllReduceOp runtime strategy for rank %d: " + tensorrt_llm::kernels::toString(runtime_strategy), rank);

         // Dispatch to different allreduce implementations
         switch (runtime_strategy)
@@ -584,10 +586,11 @@
         allreduce_fusion_params.norm_out = nullptr;
         allreduce_fusion_params.trigger_completion_at_end = trigger_completion_at_end;

-        // Determine if using oneshot or twoshot allreduce kernel
+        // Determine if using oneshot or twoshot allreduce kernel in case using MIN_LATENCY strategy.
         if (strategy == AllReduceStrategyType::MIN_LATENCY)
         {
-            allreduce_fusion_params.use_oneshot = seq_len <= tensorrt_llm::kernels::ar_fusion::kOneShotMaxToken;
+            allreduce_fusion_params.use_oneshot = seq_len <= tensorrt_llm::kernels::ar_fusion::kOneShotMaxToken
+                || hidden_size < static_cast<int64_t>(tp_size);
         }
         else
         {
@@ -794,70 +797,6 @@
         return {};
     }

-    AllReduceStrategyType getRuntimeStrategy(size_t seq_len, size_t size)
-    {
-        AllReduceStrategyType runtime_strategy;
-        if (mStrategy == AllReduceStrategyType::UB)
-        {
-            runtime_strategy = AllReduceStrategyType::UB;
-        }
-        else if (mStrategy == AllReduceStrategyType::NCCL)
-        {
-            runtime_strategy = AllReduceStrategyType::NCCL;
-        }
-        else if (mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
-        {
-            runtime_strategy = AllReduceStrategyType::NCCL_SYMMETRIC;
-        }
-        else
-        {
-            // This is for DEBUG and BENCHMARK purpose. It will overried the strategy if AUTO is set.
-            static char* ifForBenchMark = std::getenv("OVERRIDE_HEURISTIC_ALLREDUCE_STRATEGY");
-            if (ifForBenchMark != nullptr)
-            {
-                runtime_strategy = mStrategy;
-            }
-            else
-            {
-                runtime_strategy = selectImplementation(seq_len, size, mGroup.size(), mType);
-            }
-        }
-        return runtime_strategy;
-    }
-
-    void logRunTimeStrategy(AllReduceStrategyType strategy, int rank)
-    {
-        switch (strategy)
-        {
-        case AllReduceStrategyType::NCCL:
-        {
-            TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL", rank);
-            break;
-        }
-        case AllReduceStrategyType::NCCL_SYMMETRIC:
-        {
-            TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL_SYMMETRIC", rank);
-            break;
-        }
-        case AllReduceStrategyType::MIN_LATENCY:
-        {
-            TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: MIN_LATENCY", rank);
-            break;
-        }
-        case AllReduceStrategyType::UB:
-        {
-            TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UB", rank);
-            break;
-        }
-        case AllReduceStrategyType::LOWPRECISION:
-        {
-            TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: LOWPRECISION", rank);
-            break;
-        }
-        default: TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UNKNOWN: %d", rank, strategy); break;
-        }
-    }
-
     void initGroupTopology()
     {
         static std::map<std::set<int>, std::tuple<bool, bool>> cache;
@@ -985,134 +924,60 @@
         }
     }

-    bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size, bool is_auto)
+    AllReduceStrategyType selectImplementation(size_t seq_len, size_t hidden_size)
     {
-        // If messageSize is less than maxWorkspaceSize, use NCCL, regardless of the fusion type.
-        if (message_size_bytes > max_workspace_size)
-        {
-            if (!is_auto)
-            {
-                TLLM_LOG_WARNING(
-                    "Since messageSize is greater than maxWorkspaceSize, fallback to AllReduceStrategy: NCCL");
-            }
-            return true;
-        }
-
-        // If Peer to Peer is not supported, fallback to NCCL.
-        if (!mIsP2PSupported)
-        {
-            if (!is_auto)
-            {
-                TLLM_LOG_WARNING("Since Peer to Peer not supported, fallback to AllReduceStrategy: NCCL");
-            }
-            return true;
-        }
-
-        // If NVLINK is not supported, fallback to NCCL.
-        if (!mIsNVLINKSupported)
+        if (mStrategy != AllReduceStrategyType::AUTO)
         {
-            if (!is_auto)
+            // For UB,NCCL,NCCL_SYMMETRIC, the correctness of the strategy dispatching is guaranteed by the user.
+            if (mStrategy == AllReduceStrategyType::UB || mStrategy == AllReduceStrategyType::NCCL
+                || mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
             {
-                TLLM_LOG_WARNING("Since NVLINK not supported, fallback to AllReduceStrategy: NCCL");
+                return mStrategy;
             }
-            return true;
         }
-        return false;
-    }

-    AllReduceStrategyType selectImplementation(
-        size_t seq_len, size_t message_size, int world_size, nvinfer1::DataType type)
-    {
+        // For ONESHOT, TWOSHOT, LOWPRECISION, fallback is allowed.
+        auto const message_size = seq_len * hidden_size;

-        if (isUsingLowPrecision(message_size))
+        // Check if LOWPRECISION is supported.
+        if (isUsingLowPrecision(hidden_size))
         {
             return AllReduceStrategyType::LOWPRECISION;
         }
-        else
-        {
-            if (mStrategy == AllReduceStrategyType::LOWPRECISION)
-            {
-                mStrategy = AllReduceStrategyType::AUTO;
-            }
-        }

-        // Check that heuristic is only applied when AUTO is set.
-        // Use Auto select
-        bool const is_auto = (mStrategy == AllReduceStrategyType::AUTO);
-        auto const message_size_bytes = message_size * tensorrt_llm::common::getDTypeSize(type);
+        auto const message_size_bytes = message_size * tensorrt_llm::common::getDTypeSize(mType);
         auto const max_workspace_size
-            = tensorrt_llm::utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(world_size);
+            = tensorrt_llm::utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(mGroup.size());

-        if (ifFallbackToNCCL(seq_len, message_size_bytes, max_workspace_size, is_auto))
+        if (ifFallbackToNCCL(seq_len, message_size_bytes, max_workspace_size))
         {
             return AllReduceStrategyType::NCCL;
         }

         // This rule based heuristic only chooses between NCCL and MIN_LATENCY strategies.
-
-        // Heurisitic will only be applied on NONE and RESIDUAL_RMS_NORM fusion types.
-        // Because NCCL might be faster on some large messageSize cases.
-        // Otherwise, MIN_LATENCY strategy will be directly returned due to more fusions it can support.
-        // TODO: NCCL AllReduce + subsequent quantization ops (as fallback) can also support the fusion types.
-        // This should be compared with MIN_LATENCY fused kernels to determine the best strategy.
-        switch (mOp)
-        {
-        case AllReduceFusionOp::NONE:
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM: break;
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8:
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8:
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4:
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: return AllReduceStrategyType::MIN_LATENCY;
-        // Suppose NCCL has fallback implementations for all fusion types.
-        default: return AllReduceStrategyType::NCCL;
-        }
-
-        // Check mOp to be supported by the heuristic.
-        TORCH_CHECK(mOp == AllReduceFusionOp::NONE || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM,
-            "Only NONE and RESIDUAL_RMS_NORM are supported for NCCL/MIN_LATENCY heuristic.");
-
-        // Default to NCCL.
-        AllReduceStrategyType strategy = AllReduceStrategyType::NCCL;
-
-        // Currently we will not remove ONESHOT and TWOSHOT from the strategy list
-        // But torch flow user should not use them, but use AUTO or MIN_LATENCY instead.
-        // NOTICE: When a fusion type is not supported by the corresponding strategy but strategy is not AUTO,
-        // user should guarantee the correctness of the fusion pattern dispatching.
-        if (!is_auto)
+        // From this point, all fusion patterns are supported by all these strategies: NCCL, ONESHOT, TWOSHOT and
+        // MIN_LATENCY.
+        if (mStrategy != AllReduceStrategyType::AUTO)
         {
-            if (mStrategy == AllReduceStrategyType::ONESHOT || mStrategy == AllReduceStrategyType::TWOSHOT)
-            {
-                strategy = AllReduceStrategyType::MIN_LATENCY;
-            }
-            else
-            {
-                strategy = mStrategy;
-            }
+            return mStrategy;
         }
-        else if (world_size <= 2)
+        else
         {
-            strategy = AllReduceStrategyType::MIN_LATENCY;
+            return tensorrt_llm::utils::customAllReduceUtils::selectStrategyLookUpTable(
+                seq_len, hidden_size, mOp, mGroup.size());
         }
-        else
+        return AllReduceStrategyType::NCCL;
+    }
+
+    bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size)
+    {
+        // If messageSize is less than maxWorkspaceSize, use NCCL, regardless of the fusion type.
+        if (message_size_bytes > max_workspace_size || !mIsP2PSupported || !mIsNVLINKSupported)
         {
-            static char* threshold_ptr = std::getenv("ALLREDUCE_AUTO_HEURISTIC_MIN_LATENCY_THRESHOLD_TOKEN_NUM");
-            size_t threshold = 128;
-            if (threshold_ptr)
-            {
-                threshold = static_cast<size_t>(std::atoi(threshold_ptr));
-            }
-            // Generally, NCCL is faster than MIN_LATENCY when the token number is greater than 256. I conservatively
-            // set the threshold here to 128 tokens.
-            if (seq_len > threshold)
-            {
-                strategy = AllReduceStrategyType::NCCL;
-            }
-            else
-            {
-                strategy = AllReduceStrategyType::MIN_LATENCY;
-            }
+            return true;
         }
-        return strategy;
+
+        return false;
     }

     bool isUsingLowPrecision(size_t message_size) const noexcept
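Taken together, the reworked selection path in this file behaves roughly as follows. This is a condensed sketch of the hunks above, written as a member function that relies on the class fields shown in the diff (mStrategy, mGroup, mOp, mType); it is not a drop-in replacement for the actual code:

    // Condensed sketch of the new selection flow shown in the hunks above.
    AllReduceStrategyType selectImplementationSketch(size_t seq_len, size_t hidden_size)
    {
        // UB, NCCL and NCCL_SYMMETRIC are honored as-is; the user guarantees correctness.
        if (mStrategy == AllReduceStrategyType::UB || mStrategy == AllReduceStrategyType::NCCL
            || mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
        {
            return mStrategy;
        }

        // LOWPRECISION is picked whenever it applies.
        if (isUsingLowPrecision(hidden_size))
        {
            return AllReduceStrategyType::LOWPRECISION;
        }

        // Oversized messages or missing P2P/NVLink force the NCCL fallback.
        auto const message_size_bytes = seq_len * hidden_size * tensorrt_llm::common::getDTypeSize(mType);
        auto const max_workspace_size
            = tensorrt_llm::utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(mGroup.size());
        if (ifFallbackToNCCL(seq_len, message_size_bytes, max_workspace_size))
        {
            return AllReduceStrategyType::NCCL;
        }

        // Explicit ONESHOT/TWOSHOT/MIN_LATENCY choices are respected; AUTO consults the LUT.
        if (mStrategy != AllReduceStrategyType::AUTO)
        {
            return mStrategy;
        }
        return tensorrt_llm::utils::customAllReduceUtils::selectStrategyLookUpTable(
            seq_len, hidden_size, mOp, mGroup.size());
    }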
