@@ -273,14 +273,16 @@ class AllreduceOp
     {
         size_t size = input.numel();
         size_t seq_len = input.size(0);
+        size_t hidden_size = input.size(-1);
         size_t bytes_per_element = input.element_size();
         TLLM_LOG_DEBUG("All reduce message size is %zu", size * bytes_per_element);
 
-        AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size);
+        AllReduceStrategyType runtime_strategy = selectImplementation(seq_len, hidden_size);
 
         // Log runtime strategy
         auto const rank = getRank();
-        logRunTimeStrategy(runtime_strategy, rank);
+        TLLM_LOG_DEBUG(
+            "AllReduceOp runtime strategy for rank %d: " + tensorrt_llm::kernels::toString(runtime_strategy), rank);
 
         // Dispatch to different allreduce implementations
         switch (runtime_strategy)
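Reviewer note: the dispatch now keys on the token count and the hidden dimension instead of the raw element count. A minimal standalone sketch of how the two quantities relate, assuming a 2-D [tokens, hidden] activation; the concrete values are illustrative, not taken from this commit:

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // For a [num_tokens, hidden] activation, numel() == seq_len * hidden_size,
        // so the (seq_len, hidden_size) pair carries the same information as the
        // old (seq_len, size) pair while exposing each axis to the heuristic.
        std::size_t seq_len = 32;          // would come from input.size(0)
        std::size_t hidden_size = 4096;    // would come from input.size(-1)
        std::size_t bytes_per_element = 2; // e.g. fp16
        std::printf("All reduce message size is %zu\n", seq_len * hidden_size * bytes_per_element);
        return 0;
    }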
@@ -584,10 +586,11 @@ class AllreduceOp
         allreduce_fusion_params.norm_out = nullptr;
         allreduce_fusion_params.trigger_completion_at_end = trigger_completion_at_end;
 
-        // Determine if using oneshot or twoshot allreduce kernel
+        // Determine whether to use the oneshot or twoshot allreduce kernel when the MIN_LATENCY strategy is used.
         if (strategy == AllReduceStrategyType::MIN_LATENCY)
         {
-            allreduce_fusion_params.use_oneshot = seq_len <= tensorrt_llm::kernels::ar_fusion::kOneShotMaxToken;
+            allreduce_fusion_params.use_oneshot = seq_len <= tensorrt_llm::kernels::ar_fusion::kOneShotMaxToken
+                || hidden_size < static_cast<int64_t>(tp_size);
         }
         else
         {
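Reviewer note: the new `|| hidden_size < tp_size` term forces the oneshot kernel for very narrow tensors, presumably because the twoshot path partitions each token's hidden dimension across the tp_size ranks and therefore needs at least tp_size elements per token. A hedged restatement of the predicate; only its shape mirrors the diff, and the constant is a stand-in:

    #include <cstdint>

    bool useOneShot(std::int64_t seq_len, std::int64_t hidden_size, std::int64_t tp_size)
    {
        constexpr std::int64_t kOneShotMaxToken = 128; // assumed value, not from this commit
        // Small token counts favor oneshot; tensors narrower than the TP group
        // cannot be split one slice per rank, so twoshot is not applicable.
        return seq_len <= kOneShotMaxToken || hidden_size < tp_size;
    }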
@@ -794,70 +797,6 @@ class AllreduceOp
         return {};
     }
 
-    AllReduceStrategyType getRuntimeStrategy(size_t seq_len, size_t size)
-    {
-        AllReduceStrategyType runtime_strategy;
-        if (mStrategy == AllReduceStrategyType::UB)
-        {
-            runtime_strategy = AllReduceStrategyType::UB;
-        }
-        else if (mStrategy == AllReduceStrategyType::NCCL)
-        {
-            runtime_strategy = AllReduceStrategyType::NCCL;
-        }
-        else if (mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
-        {
-            runtime_strategy = AllReduceStrategyType::NCCL_SYMMETRIC;
-        }
-        else
-        {
-            // This is for DEBUG and BENCHMARK purpose. It will overried the strategy if AUTO is set.
-            static char* ifForBenchMark = std::getenv("OVERRIDE_HEURISTIC_ALLREDUCE_STRATEGY");
-            if (ifForBenchMark != nullptr)
-            {
-                runtime_strategy = mStrategy;
-            }
-            else
-            {
-                runtime_strategy = selectImplementation(seq_len, size, mGroup.size(), mType);
-            }
-        }
-        return runtime_strategy;
-    }
-
-    void logRunTimeStrategy(AllReduceStrategyType strategy, int rank)
-    {
-        switch (strategy)
-        {
-        case AllReduceStrategyType::NCCL:
-        {
-            TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL", rank);
-            break;
-        }
-        case AllReduceStrategyType::NCCL_SYMMETRIC:
-        {
-            TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL_SYMMETRIC", rank);
-            break;
-        }
-        case AllReduceStrategyType::MIN_LATENCY:
-        {
-            TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: MIN_LATENCY", rank);
-            break;
-        }
-        case AllReduceStrategyType::UB:
-        {
-            TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UB", rank);
-            break;
-        }
-        case AllReduceStrategyType::LOWPRECISION:
-        {
-            TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: LOWPRECISION", rank);
-            break;
-        }
-        default: TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UNKNOWN: %d", rank, strategy); break;
-        }
-    }
-
     void initGroupTopology()
     {
         static std::map<std::set<int>, std::tuple<bool, bool>> cache;
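Reviewer note: the five-case logging switch deleted above becomes redundant once the call site logs through tensorrt_llm::kernels::toString (see the first hunk). A hypothetical minimal equivalent of such a helper, to show why the per-strategy switch collapses to one log line; the real toString lives elsewhere in the codebase and may differ:

    #include <string>

    enum class AllReduceStrategyType { NCCL, NCCL_SYMMETRIC, MIN_LATENCY, UB, LOWPRECISION };

    std::string toString(AllReduceStrategyType s)
    {
        switch (s)
        {
        case AllReduceStrategyType::NCCL: return "NCCL";
        case AllReduceStrategyType::NCCL_SYMMETRIC: return "NCCL_SYMMETRIC";
        case AllReduceStrategyType::MIN_LATENCY: return "MIN_LATENCY";
        case AllReduceStrategyType::UB: return "UB";
        case AllReduceStrategyType::LOWPRECISION: return "LOWPRECISION";
        }
        return "UNKNOWN";
    }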
@@ -985,134 +924,60 @@ class AllreduceOp
         }
     }
 
-    bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size, bool is_auto)
+    AllReduceStrategyType selectImplementation(size_t seq_len, size_t hidden_size)
     {
-        // If messageSize is less than maxWorkspaceSize, use NCCL, regardless of the fusion type.
-        if (message_size_bytes > max_workspace_size)
-        {
-            if (!is_auto)
-            {
-                TLLM_LOG_WARNING(
-                    "Since messageSize is greater than maxWorkspaceSize, fallback to AllReduceStrategy: NCCL");
-            }
-            return true;
-        }
-
-        // If Peer to Peer is not supported, fallback to NCCL.
-        if (!mIsP2PSupported)
-        {
-            if (!is_auto)
-            {
-                TLLM_LOG_WARNING("Since Peer to Peer not supported, fallback to AllReduceStrategy: NCCL");
-            }
-            return true;
-        }
-
-        // If NVLINK is not supported, fallback to NCCL.
-        if (!mIsNVLINKSupported)
+        if (mStrategy != AllReduceStrategyType::AUTO)
         {
-            if (!is_auto)
+            // For UB, NCCL, NCCL_SYMMETRIC, the correctness of the strategy dispatching is guaranteed by the user.
+            if (mStrategy == AllReduceStrategyType::UB || mStrategy == AllReduceStrategyType::NCCL
+                || mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
             {
-                TLLM_LOG_WARNING("Since NVLINK not supported, fallback to AllReduceStrategy: NCCL");
+                return mStrategy;
             }
-            return true;
         }
-        return false;
-    }
 
-    AllReduceStrategyType selectImplementation(
-        size_t seq_len, size_t message_size, int world_size, nvinfer1::DataType type)
-    {
+        // For ONESHOT, TWOSHOT, LOWPRECISION, fallback is allowed.
+        auto const message_size = seq_len * hidden_size;
 
-        if (isUsingLowPrecision(message_size))
+        // Check if LOWPRECISION is supported.
+        if (isUsingLowPrecision(hidden_size))
         {
             return AllReduceStrategyType::LOWPRECISION;
         }
-        else
-        {
-            if (mStrategy == AllReduceStrategyType::LOWPRECISION)
-            {
-                mStrategy = AllReduceStrategyType::AUTO;
-            }
-        }
 
-        // Check that heuristic is only applied when AUTO is set.
-        // Use Auto select
-        bool const is_auto = (mStrategy == AllReduceStrategyType::AUTO);
-        auto const message_size_bytes = message_size * tensorrt_llm::common::getDTypeSize(type);
+        auto const message_size_bytes = message_size * tensorrt_llm::common::getDTypeSize(mType);
         auto const max_workspace_size
-            = tensorrt_llm::utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(world_size);
+            = tensorrt_llm::utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(mGroup.size());
 
-        if (ifFallbackToNCCL(seq_len, message_size_bytes, max_workspace_size, is_auto))
+        if (ifFallbackToNCCL(seq_len, message_size_bytes, max_workspace_size))
         {
             return AllReduceStrategyType::NCCL;
         }
 
         // This rule based heuristic only chooses between NCCL and MIN_LATENCY strategies.
-
-        // Heurisitic will only be applied on NONE and RESIDUAL_RMS_NORM fusion types.
-        // Because NCCL might be faster on some large messageSize cases.
-        // Otherwise, MIN_LATENCY strategy will be directly returned due to more fusions it can support.
-        // TODO: NCCL AllReduce + subsequent quantization ops (as fallback) can also support the fusion types.
-        // This should be compared with MIN_LATENCY fused kernels to determine the best strategy.
-        switch (mOp)
-        {
-        case AllReduceFusionOp::NONE:
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM: break;
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8:
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8:
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4:
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: return AllReduceStrategyType::MIN_LATENCY;
-        // Suppose NCCL has fallback implementations for all fusion types.
-        default: return AllReduceStrategyType::NCCL;
-        }
-
-        // Check mOp to be supported by the heuristic.
-        TORCH_CHECK(mOp == AllReduceFusionOp::NONE || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM,
-            "Only NONE and RESIDUAL_RMS_NORM are supported for NCCL/MIN_LATENCY heuristic.");
-
-        // Default to NCCL.
-        AllReduceStrategyType strategy = AllReduceStrategyType::NCCL;
-
-        // Currently we will not remove ONESHOT and TWOSHOT from the strategy list
-        // But torch flow user should not use them, but use AUTO or MIN_LATENCY instead.
-        // NOTICE: When a fusion type is not supported by the corresponding strategy but strategy is not AUTO,
-        // user should guarantee the correctness of the fusion pattern dispatching.
-        if (!is_auto)
+        // From this point, all fusion patterns are supported by all these strategies: NCCL, ONESHOT, TWOSHOT and
+        // MIN_LATENCY.
+        if (mStrategy != AllReduceStrategyType::AUTO)
         {
-            if (mStrategy == AllReduceStrategyType::ONESHOT || mStrategy == AllReduceStrategyType::TWOSHOT)
-            {
-                strategy = AllReduceStrategyType::MIN_LATENCY;
-            }
-            else
-            {
-                strategy = mStrategy;
-            }
+            return mStrategy;
         }
-        else if (world_size <= 2)
+        else
         {
-            strategy = AllReduceStrategyType::MIN_LATENCY;
+            return tensorrt_llm::utils::customAllReduceUtils::selectStrategyLookUpTable(
+                seq_len, hidden_size, mOp, mGroup.size());
         }
-        else
+        return AllReduceStrategyType::NCCL;
+    }
+
+    bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size)
+    {
+        // Fall back to NCCL if the message size exceeds the max workspace size, or if P2P or NVLink is not supported.
+        if (message_size_bytes > max_workspace_size || !mIsP2PSupported || !mIsNVLINKSupported)
         {
-            static char* threshold_ptr = std::getenv("ALLREDUCE_AUTO_HEURISTIC_MIN_LATENCY_THRESHOLD_TOKEN_NUM");
-            size_t threshold = 128;
-            if (threshold_ptr)
-            {
-                threshold = static_cast<size_t>(std::atoi(threshold_ptr));
-            }
-            // Generally, NCCL is faster than MIN_LATENCY when the token number is greater than 256. I conservatively
-            // set the threshold here to 128 tokens.
-            if (seq_len > threshold)
-            {
-                strategy = AllReduceStrategyType::NCCL;
-            }
-            else
-            {
-                strategy = AllReduceStrategyType::MIN_LATENCY;
-            }
+            return true;
         }
-        return strategy;
+
+        return false;
     }
 
     bool isUsingLowPrecision(size_t message_size) const noexcept
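Reviewer note: the rewritten selectImplementation folds the old env-var heuristic into a fixed decision cascade; note also that its trailing `return AllReduceStrategyType::NCCL;` is unreachable, since both branches above it return. A condensed restatement of the new ordering, with the predicates stubbed out; everything here is illustrative, only the ordering mirrors the diff:

    enum class Strategy { AUTO, UB, NCCL, NCCL_SYMMETRIC, ONESHOT, TWOSHOT, MIN_LATENCY, LOWPRECISION };

    Strategy select(Strategy requested, bool low_precision_ok, bool must_fallback_to_nccl, Strategy from_lookup_table)
    {
        // 1. UB / NCCL / NCCL_SYMMETRIC are honored unconditionally.
        if (requested == Strategy::UB || requested == Strategy::NCCL || requested == Strategy::NCCL_SYMMETRIC)
            return requested;
        // 2. Low-precision kernels take priority when supported.
        if (low_precision_ok)
            return Strategy::LOWPRECISION;
        // 3. Oversized messages or missing P2P/NVLink force NCCL.
        if (must_fallback_to_nccl)
            return Strategy::NCCL;
        // 4. Any other explicit choice (ONESHOT/TWOSHOT/MIN_LATENCY) is kept;
        // 5. AUTO defers to the lookup table keyed on (seq_len, hidden_size, fusion op, tp size).
        return requested != Strategy::AUTO ? requested : from_lookup_table;
    }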