diff --git a/cpp/include/tensorrt_llm/runtime/worldConfig.h b/cpp/include/tensorrt_llm/runtime/worldConfig.h index ca6c878378c..9ff2d0970df 100644 --- a/cpp/include/tensorrt_llm/runtime/worldConfig.h +++ b/cpp/include/tensorrt_llm/runtime/worldConfig.h @@ -104,12 +104,14 @@ class WorldConfig [[nodiscard]] SizeType32 constexpr getTensorParallelRank() const noexcept { - return mRank % mTensorParallelism; + // Layout: pp is outermost, then tp, then cp is innermost (consecutive). + return (mRank % (mTensorParallelism * mContextParallelism)) / mContextParallelism; } [[nodiscard]] SizeType32 constexpr getContextParallelRank() const noexcept { - return (mRank % (mTensorParallelism * mContextParallelism)) / mTensorParallelism; + // Layout: pp is outermost, then tp, then cp is innermost (consecutive). + return mRank % mContextParallelism; } [[nodiscard]] SizeType32 constexpr getLocalRank() const noexcept diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index 66083a55eb7..96eec0fd04c 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -154,7 +154,8 @@ bool CacheFormatter::needSendCache( return true; } - int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism; + int selfCpSize = selfConfig.getParallelConfig().mContextParallelism; + int selfTpRank = (selfIdx % (selfConfig.getParallelConfig().mTensorParallelism * selfCpSize)) / selfCpSize; int selfTpRankInDpGroup = selfTpRank; if (selfConfig.getParallelConfig().mEnableAttentionDP) { diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index b2aaa4983ec..b2a60a3eda0 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -60,7 +60,8 @@ std::vector MLACacheFormatter::pickRecvConnections( bool MLACacheFormatter::needSendCache( CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx) { - int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism; + int selfCpSize = selfConfig.getParallelConfig().mContextParallelism; + int selfTpRank = (selfIdx % (selfConfig.getParallelConfig().mTensorParallelism * selfCpSize)) / selfCpSize; int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize diff --git a/cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu b/cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu index e01228c5d3a..e16e00670b7 100644 --- a/cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu +++ b/cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu @@ -107,9 +107,9 @@ TargetRanksInfo TargetRanksInfoForDP( auto const peerCPNum = peerParConfig.mContextParallelism; auto const selfCPNum = selfParConfig.mContextParallelism; - auto const selfTPRank = selfRank % selfTPNum; + auto const selfCPRank = selfRank % selfCPNum; + auto const selfTPRank = (selfRank % (selfTPNum * selfCPNum)) / selfCPNum; auto const selfPPRank = selfRank / (selfTPNum * selfCPNum); - auto const selfCPRank = (selfRank % (selfTPNum * selfCPNum)) / selfTPNum; int peerPPRankStart = 0; int mDomainPPSize = 1; @@ -205,13 +205,14 @@ TargetRanksInfo TargetRanksInfoForDP( } std::vector retRanks; - for (int i = peerTPRankStart; i < peerTPRankEnd; i++) + for (int i = peerCPRankStart; i < peerCPRankEnd; i++) { - for (int j = peerCPRankStart; j < peerCPRankEnd; j++) + for (int j = peerTPRankStart; j < peerTPRankEnd; j++) { for (int k = peerPPRankStart; k < peerPPRankEnd; k++) { - int irank = (k * peerTPNum * peerCPNum) + (j * peerTPNum) + i; + // Rank formula: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank. + int irank = (k * peerTPNum * peerCPNum) + (j * peerCPNum) + i; retRanks.push_back(irank); } } diff --git a/cpp/tensorrt_llm/runtime/worldConfig.cpp b/cpp/tensorrt_llm/runtime/worldConfig.cpp index 396c2bd6d09..3ca85269296 100644 --- a/cpp/tensorrt_llm/runtime/worldConfig.cpp +++ b/cpp/tensorrt_llm/runtime/worldConfig.cpp @@ -142,6 +142,9 @@ WorldConfig WorldConfig::mpi(SizeType32 gpusPerNode, std::optional t std::vector WorldConfig::getPipelineParallelGroup() const { + // Layout: pp is outermost, then tp, then cp is innermost (consecutive). + // rank = ppRank * (tp * cp) + tpRank * cp + cpRank + // PP group: all ranks with same (tpRank, cpRank) but different ppRank. auto const pp = getPipelineParallelism(); auto const tp = getTensorParallelism(); auto const cp = getContextParallelism(); @@ -157,29 +160,35 @@ std::vector WorldConfig::getPipelineParallelGroup() const std::vector WorldConfig::getTensorParallelGroup() const { + // Layout: pp is outermost, then tp, then cp is innermost (consecutive). + // rank = ppRank * (tp * cp) + tpRank * cp + cpRank + // TP group: all ranks with same (ppRank, cpRank) but different tpRank. auto const tp = getTensorParallelism(); + auto const cp = getContextParallelism(); auto const rank = getRank(); auto const tpRank = getTensorParallelRank(); std::vector group; group.reserve(tp); for (SizeType32 idx = 0; idx < tp; idx++) { - group.push_back(rank - tpRank + idx); + group.push_back(rank - tpRank * cp + idx * cp); } return group; } std::vector WorldConfig::getContextParallelGroup() const { + // Layout: pp is outermost, then tp, then cp is innermost (consecutive). + // rank = ppRank * (tp * cp) + tpRank * cp + cpRank + // CP group: all ranks with same (ppRank, tpRank) but different cpRank. auto const cp = getContextParallelism(); - auto const tp = getTensorParallelism(); - auto const pp = getPipelineParallelism(); auto const rank = getRank(); + auto const cpRank = getContextParallelRank(); std::vector group; group.reserve(cp); for (SizeType32 idx = 0; idx < cp; idx++) { - group.push_back(rank + cp % (tp * pp)); + group.push_back(rank - cpRank + idx); } return group; } diff --git a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp index 782a10d0cb8..5e7528b8dd4 100644 --- a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp +++ b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp @@ -2029,16 +2029,16 @@ TEST(targetTest, CacheStateNODP) tr::WorldConfig const contextWC{/*tpSize*/ 2, /*ppSize*/ 2, /*cpSize*/ 1}; tr::WorldConfig const genWC{/*tpSize*/ 2, /*ppSize*/ 2, /*cpSize*/ 2}; verifyContext( - /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 2}, + /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1, 3}, + /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {4, 6}, + /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {4, 5}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {5, 7}, + /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {6, 7}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); } @@ -2047,19 +2047,19 @@ TEST(targetTest, CacheStateNODP) tr::WorldConfig const contextWC{/*tpSize*/ 2, /*ppSize*/ 2, /*cpSize*/ 1}; tr::WorldConfig const genWC{/*tpSize*/ 4, /*ppSize*/ 2, /*cpSize*/ 2}; verifyContext( - /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 4, 1, 5}, + /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 2, 1, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 6, 3, 7}, + /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {4, 6, 5, 7}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {8, 12, 9, 13}, + /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {8, 10, 9, 11}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {10, 14, 11, 15}, + /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {12, 14, 13, 15}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); } @@ -2069,16 +2069,16 @@ TEST(targetTest, CacheStateNODP) tr::WorldConfig const contextWC{/*tpSize*/ 4, /*ppSize*/ 1, /*cpSize*/ 1}; tr::WorldConfig const genWC{/*tpSize*/ 2, /*ppSize*/ 1, /*cpSize*/ 2}; verifyContext( - /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 2}, /*expectPPDomain*/ 1, + /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 2}, /*expectPPDomain*/ 1, + /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ false); verifyContext( - /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1, 3}, /*expectPPDomain*/ 1, + /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1, 3}, /*expectPPDomain*/ 1, + /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ false); } @@ -2087,19 +2087,19 @@ TEST(targetTest, CacheStateNODP) tr::WorldConfig const contextWC{/*tpSize*/ 2, /*ppSize*/ 2, /*cpSize*/ 1}; tr::WorldConfig const genWC{/*tpSize*/ 2, /*ppSize*/ 4, /*cpSize*/ 2}; verifyContext( - /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 4, 2, 6}, + /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 4, 1, 5}, /*expectPPDomain*/ 2, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1, 5, 3, 7}, + /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 6, 3, 7}, /*expectPPDomain*/ 2, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {8, 12, 10, 14}, + /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {8, 12, 9, 13}, /*expectPPDomain*/ 2, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {9, 13, 11, 15}, + /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {10, 14, 11, 15}, /*expectPPDomain*/ 2, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); } @@ -2109,28 +2109,28 @@ TEST(targetTest, CacheStateNODP) tr::WorldConfig const contextWC{/*tpSize*/ 2, /*ppSize*/ 4, /*cpSize*/ 1}; tr::WorldConfig const genWC{/*tpSize*/ 2, /*ppSize*/ 2, /*cpSize*/ 2}; verifyContext( - /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 2}, /*expectPPDomain*/ 1, + /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1, 3}, /*expectPPDomain*/ 1, + /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 2}, /*expectPPDomain*/ 1, + /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1, 3}, /*expectPPDomain*/ 1, + /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 4, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {4, 6}, /*expectPPDomain*/ 1, + /*contextRank*/ 4, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {4, 5}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 5, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {5, 7}, /*expectPPDomain*/ 1, + /*contextRank*/ 5, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {6, 7}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 6, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {4, 6}, /*expectPPDomain*/ 1, + /*contextRank*/ 6, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {4, 5}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 7, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {5, 7}, /*expectPPDomain*/ 1, + /*contextRank*/ 7, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {6, 7}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); } @@ -2139,28 +2139,28 @@ TEST(targetTest, CacheStateNODP) tr::WorldConfig const contextWC{/*tpSize*/ 4, /*ppSize*/ 2, /*cpSize*/ 1}; tr::WorldConfig const genWC{/*tpSize*/ 2, /*ppSize*/ 1, /*cpSize*/ 2}; verifyContext( - /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 2}, /*expectPPDomain*/ 1, + /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 2}, /*expectPPDomain*/ 1, + /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ false); verifyContext( - /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1, 3}, /*expectPPDomain*/ 1, + /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1, 3}, /*expectPPDomain*/ 1, + /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ false); verifyContext( - /*contextRank*/ 4, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 2}, /*expectPPDomain*/ 1, + /*contextRank*/ 4, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 5, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 2}, /*expectPPDomain*/ 1, + /*contextRank*/ 5, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ false); verifyContext( - /*contextRank*/ 6, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1, 3}, /*expectPPDomain*/ 1, + /*contextRank*/ 6, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 7, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1, 3}, /*expectPPDomain*/ 1, + /*contextRank*/ 7, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ false); } @@ -2169,19 +2169,19 @@ TEST(targetTest, CacheStateNODP) tr::WorldConfig const contextWC{/*tpSize*/ 2, /*ppSize*/ 2, /*cpSize*/ 1}; tr::WorldConfig const genWC{/*tpSize*/ 4, /*ppSize*/ 1, /*cpSize*/ 2}; verifyContext( - /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 4, 1, 5}, + /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 2, 1, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 6, 3, 7}, + /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {4, 6, 5, 7}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 4, 1, 5}, + /*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 2, 1, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); verifyContext( - /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 6, 3, 7}, + /*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {4, 6, 5, 7}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectCPDomain*/ 2, /*expectNeedSend*/ true); } diff --git a/cpp/tests/unit_tests/runtime/worldConfigTest.cpp b/cpp/tests/unit_tests/runtime/worldConfigTest.cpp index 03d6eef9b03..66157d6a3ce 100644 --- a/cpp/tests/unit_tests/runtime/worldConfigTest.cpp +++ b/cpp/tests/unit_tests/runtime/worldConfigTest.cpp @@ -56,3 +56,139 @@ TEST(WorldConfig, DeviceIds) EXPECT_NO_THROW(tr::WorldConfig(tensorParallelism, pipelineParallelism, contextParallelism, rank, gpusPerNode, std::vector{0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12})); } + +// Test for parallel rank calculations and group membership. +// Layout: pp is outermost, then tp, then cp is innermost (consecutive). +// rank = ppRank * (tp * cp) + tpRank * cp + cpRank +TEST(WorldConfig, ParallelRanks) +{ + auto constexpr tp = 2; + auto constexpr pp = 2; + auto constexpr cp = 2; + auto constexpr gpusPerNode = 16; + + // Test all 8 ranks in a tp=2, pp=2, cp=2 configuration. + // Rank 0: ppRank=0, tpRank=0, cpRank=0 + { + tr::WorldConfig config(tp, pp, cp, 0, gpusPerNode); + EXPECT_EQ(config.getPipelineParallelRank(), 0); + EXPECT_EQ(config.getTensorParallelRank(), 0); + EXPECT_EQ(config.getContextParallelRank(), 0); + } + // Rank 1: ppRank=0, tpRank=0, cpRank=1 + { + tr::WorldConfig config(tp, pp, cp, 1, gpusPerNode); + EXPECT_EQ(config.getPipelineParallelRank(), 0); + EXPECT_EQ(config.getTensorParallelRank(), 0); + EXPECT_EQ(config.getContextParallelRank(), 1); + } + // Rank 2: ppRank=0, tpRank=1, cpRank=0 + { + tr::WorldConfig config(tp, pp, cp, 2, gpusPerNode); + EXPECT_EQ(config.getPipelineParallelRank(), 0); + EXPECT_EQ(config.getTensorParallelRank(), 1); + EXPECT_EQ(config.getContextParallelRank(), 0); + } + // Rank 3: ppRank=0, tpRank=1, cpRank=1 + { + tr::WorldConfig config(tp, pp, cp, 3, gpusPerNode); + EXPECT_EQ(config.getPipelineParallelRank(), 0); + EXPECT_EQ(config.getTensorParallelRank(), 1); + EXPECT_EQ(config.getContextParallelRank(), 1); + } + // Rank 4: ppRank=1, tpRank=0, cpRank=0 + { + tr::WorldConfig config(tp, pp, cp, 4, gpusPerNode); + EXPECT_EQ(config.getPipelineParallelRank(), 1); + EXPECT_EQ(config.getTensorParallelRank(), 0); + EXPECT_EQ(config.getContextParallelRank(), 0); + } + // Rank 5: ppRank=1, tpRank=0, cpRank=1 + { + tr::WorldConfig config(tp, pp, cp, 5, gpusPerNode); + EXPECT_EQ(config.getPipelineParallelRank(), 1); + EXPECT_EQ(config.getTensorParallelRank(), 0); + EXPECT_EQ(config.getContextParallelRank(), 1); + } + // Rank 6: ppRank=1, tpRank=1, cpRank=0 + { + tr::WorldConfig config(tp, pp, cp, 6, gpusPerNode); + EXPECT_EQ(config.getPipelineParallelRank(), 1); + EXPECT_EQ(config.getTensorParallelRank(), 1); + EXPECT_EQ(config.getContextParallelRank(), 0); + } + // Rank 7: ppRank=1, tpRank=1, cpRank=1 + { + tr::WorldConfig config(tp, pp, cp, 7, gpusPerNode); + EXPECT_EQ(config.getPipelineParallelRank(), 1); + EXPECT_EQ(config.getTensorParallelRank(), 1); + EXPECT_EQ(config.getContextParallelRank(), 1); + } +} + +TEST(WorldConfig, ParallelGroups) +{ + auto constexpr tp = 2; + auto constexpr pp = 2; + auto constexpr cp = 2; + auto constexpr gpusPerNode = 16; + + // Test group membership for rank 3 (ppRank=0, tpRank=1, cpRank=1). + // CP group: all ranks with same (ppRank=0, tpRank=1) = [2, 3]. + // TP group: all ranks with same (ppRank=0, cpRank=1) = [1, 3]. + // PP group: all ranks with same (tpRank=1, cpRank=1) = [3, 7]. + { + tr::WorldConfig config(tp, pp, cp, 3, gpusPerNode); + auto cpGroup = config.getContextParallelGroup(); + auto tpGroup = config.getTensorParallelGroup(); + auto ppGroup = config.getPipelineParallelGroup(); + + EXPECT_EQ(cpGroup, (std::vector{2, 3})); + EXPECT_EQ(tpGroup, (std::vector{1, 3})); + EXPECT_EQ(ppGroup, (std::vector{3, 7})); + } + + // Test group membership for rank 5 (ppRank=1, tpRank=0, cpRank=1). + // CP group: all ranks with same (ppRank=1, tpRank=0) = [4, 5]. + // TP group: all ranks with same (ppRank=1, cpRank=1) = [5, 7]. + // PP group: all ranks with same (tpRank=0, cpRank=1) = [1, 5]. + { + tr::WorldConfig config(tp, pp, cp, 5, gpusPerNode); + auto cpGroup = config.getContextParallelGroup(); + auto tpGroup = config.getTensorParallelGroup(); + auto ppGroup = config.getPipelineParallelGroup(); + + EXPECT_EQ(cpGroup, (std::vector{4, 5})); + EXPECT_EQ(tpGroup, (std::vector{5, 7})); + EXPECT_EQ(ppGroup, (std::vector{1, 5})); + } +} + +TEST(WorldConfig, ParallelGroupsLargerConfig) +{ + // Test with tp=2, pp=2, cp=4, worldSize=16. + auto constexpr tp = 2; + auto constexpr pp = 2; + auto constexpr cp = 4; + auto constexpr gpusPerNode = 16; + + // Rank 9: ppRank = 9 / (2*4) = 1, tpRank = (9 % 8) / 4 = 0, cpRank = 9 % 4 = 1. + // CP group: ranks with same (ppRank=1, tpRank=0) = [8, 9, 10, 11]. + // TP group: ranks with same (ppRank=1, cpRank=1) = [9, 13]. + // PP group: ranks with same (tpRank=0, cpRank=1) = [1, 9]. + { + tr::WorldConfig config(tp, pp, cp, 9, gpusPerNode); + + EXPECT_EQ(config.getPipelineParallelRank(), 1); + EXPECT_EQ(config.getTensorParallelRank(), 0); + EXPECT_EQ(config.getContextParallelRank(), 1); + + auto cpGroup = config.getContextParallelGroup(); + auto tpGroup = config.getTensorParallelGroup(); + auto ppGroup = config.getPipelineParallelGroup(); + + EXPECT_EQ(cpGroup, (std::vector{8, 9, 10, 11})); + EXPECT_EQ(tpGroup, (std::vector{9, 13})); + EXPECT_EQ(ppGroup, (std::vector{1, 9})); + } +} diff --git a/tensorrt_llm/_torch/device_mesh.py b/tensorrt_llm/_torch/device_mesh.py index b5034f8ef72..bdc9d94f8f5 100644 --- a/tensorrt_llm/_torch/device_mesh.py +++ b/tensorrt_llm/_torch/device_mesh.py @@ -118,8 +118,10 @@ def build_mesh(self): "DeviceMesh creation requested but torch.distributed process group " "has not been initialised.") - dims = ["cp", "pp"] - shape = [self.cp_size, self.pp_size] + # Dimensions go from slowest-varying (outermost) to fastest-varying (innermost). + # Layout: pp is outermost, then tp, then cp is innermost (consecutive). + dims = ["pp"] + shape = [self.pp_size] if self.moe_ep_size > 1: dims += ["moe_tp", "moe_ep"] @@ -128,6 +130,9 @@ def build_mesh(self): dims += ["tp"] shape += [self.tp_size] + dims += ["cp"] + shape += [self.cp_size] + cls.device_mesh = init_device_mesh( "cuda", mesh_shape=tuple(shape), diff --git a/tensorrt_llm/mapping.py b/tensorrt_llm/mapping.py index 386d18da747..818ee33dce9 100644 --- a/tensorrt_llm/mapping.py +++ b/tensorrt_llm/mapping.py @@ -292,18 +292,16 @@ def has_cp(self): return self.cp_size > 1 def prev_cp_rank(self): - p = self.rank - self.tp_size - if p // (self.tp_size * self.cp_size) < self.rank // (self.tp_size * - self.cp_size): - return p + self.tp_size * self.cp_size - return p + # cp ranks are consecutive, so prev is rank - 1 with wraparound within cp group. + if self.cp_rank == 0: + return self.rank + self.cp_size - 1 + return self.rank - 1 def next_cp_rank(self): - p = self.rank + self.tp_size - if p // (self.tp_size * self.cp_size) > self.rank // (self.tp_size * - self.cp_size): - return p - self.tp_size * self.cp_size - return p + # cp ranks are consecutive, so next is rank + 1 with wraparound within cp group. + if self.cp_rank == self.cp_size - 1: + return self.rank - self.cp_size + 1 + return self.rank + 1 def has_moe_cluster(self): return self.moe_cluster_size > 1 @@ -378,17 +376,17 @@ class Mapping(MappingBase): A node with 8 GPUs, tp_size = 4, cp_size = 2, pp_size = 1 - 2 tp groups: + 4 cp groups: - - [0, 1, 2, 3] - - [4, 5, 6, 7] + - [0, 1] + - [2, 3] + - [4, 5] + - [6, 7] - 4 cp groups: + 2 tp groups: - - [0, 4] - - [1, 5] - - [2, 6] - - [3, 7] + - [0, 2, 4, 6] + - [1, 3, 5, 7] A node with 8 GPUs, moe_tp_size = 2, moe_ep_size = 4 @@ -437,23 +435,23 @@ class Mapping(MappingBase): 2 nodes with 8 GPUs, tp_size 2, pp_size 2, cp_size 2 - 4 tp groups: + 4 cp groups: - [0, 1] - [2, 3] - [4, 5] - [6, 7] + 4 tp groups: + - [0, 2] + - [1, 3] + - [4, 6] + - [5, 7] + 4 pp groups: - [0, 4] - [1, 5] - [2, 6] - [3, 7] - - 4 cp groups: - - [0, 2] - - [1, 3] - - [4, 6] - - [5, 7] """ def __new__(cls, *args, **kwargs): @@ -551,7 +549,7 @@ def __init__(self, *args, **kwargs): @property def tp_rank(self) -> int: - return self.rank % self.tp_size + return self.rank % (self.tp_size * self.cp_size) // self.cp_size @property def pp_rank(self) -> int: @@ -559,7 +557,7 @@ def pp_rank(self) -> int: @property def cp_rank(self) -> int: - return self.rank % (self.tp_size * self.cp_size) // self.tp_size + return self.rank % self.cp_size @property def tp_group(self) -> List[int]: @@ -567,7 +565,7 @@ def tp_group(self) -> List[int]: @property def pp_group(self) -> List[int]: - return self.pp_groups[self.cp_rank * self.tp_size + self.tp_rank] + return self.pp_groups[self.tp_rank * self.cp_size + self.cp_rank] @property def cp_group(self) -> List[int]: @@ -598,20 +596,20 @@ def _init_parallel_groups(self): ranks = range(i, self.world_size, self.tp_size * self.cp_size) self.pp_groups.append(list(ranks)) - # init cp group + # init cp group (consecutive ranks within each tp slice). for i in range(self.pp_size): for j in range(self.tp_size): - ranks = range(i * self.tp_size * self.cp_size + j, - (i + 1) * self.tp_size * self.cp_size + j, - self.tp_size) + ranks = range( + i * self.tp_size * self.cp_size + j * self.cp_size, + i * self.tp_size * self.cp_size + (j + 1) * self.cp_size) self.cp_groups.append(list(ranks)) - # init tp group + # init tp group (interleaved ranks with stride of cp_size). for i in range(self.pp_size): for j in range(self.cp_size): - ranks = range( - i * self.tp_size * self.cp_size + j * self.tp_size, - i * self.tp_size * self.cp_size + (j + 1) * self.tp_size) + ranks = range(i * self.tp_size * self.cp_size + j, + (i + 1) * self.tp_size * self.cp_size + j, + self.cp_size) self.tp_groups.append(list(ranks)) # init moe tp group diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 03c1ee60ae5..39866156384 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -740,10 +740,11 @@ def from_checkpoint( rank = config.mapping.rank if config.mapping.cp_size > 1: - # tp_cp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt - tp_size = config.mapping.tp_size + # cp_tp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt. cp_size = config.mapping.cp_size - rank = rank % tp_size + rank // (tp_size * cp_size) * tp_size + # rank = pp_rank × tp_size × cp_size + tp_rank × cp_size + cp_rank. + # rank // cp_size is equivalent to pp_rank × tp_size + tp_rank. + rank = rank // cp_size weights_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors') assert os.path.isfile(weights_path) diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 8c7111f4490..664e6349682 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -872,8 +872,9 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn): task.evaluate(llm) @pytest.mark.skip_less_device(8) - @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 2, 2), (2, 1, 2)], - ids=["pp1tp2cp2", "pp2tp1cp2"]) + @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 1, 4), (1, 2, 2), + (2, 1, 2)], + ids=["pp1tp1cp4", "pp1tp2cp2", "pp2tp1cp2"]) @pytest.mark.parametrize("cuda_graph_config", [ None, { diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 93a5a8120c8..265acc23463 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -540,6 +540,14 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp1tp2cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp1tp2cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp1tp1cp4] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp1tp1cp4] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp2tp1cp2] diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 04c2285858d..bc8f3978ffe 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -67,6 +67,8 @@ l0_dgx_b200: orchestrator: mpi tests: - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60) + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60) + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60) @@ -94,6 +96,8 @@ l0_dgx_b200: orchestrator: mpi tests: - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60) + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60) + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60) diff --git a/tests/unittest/others/test_mapping.py b/tests/unittest/others/test_mapping.py index bc9839239bf..b6ab8d9f8c1 100644 --- a/tests/unittest/others/test_mapping.py +++ b/tests/unittest/others/test_mapping.py @@ -57,27 +57,27 @@ def test_mapping(self): self.assertEqual(len(m.tp_groups), 4) self.assertEqual(len(m.pp_groups), 4) self.assertEqual(len(m.cp_groups), 4) - self.assertEqual(m.tp_group, [2, 3]) + self.assertEqual(m.tp_group, [1, 3]) self.assertEqual(m.pp_group, [3, 7]) - self.assertEqual(m.cp_group, [1, 3]) + self.assertEqual(m.cp_group, [2, 3]) self.assertTrue(m.is_first_pp_rank()) self.assertFalse(m.is_last_pp_rank()) self.assertFalse(m.is_first_cp_rank()) self.assertTrue(m.is_last_cp_rank()) self.assertEqual(m.prev_pp_rank(), 7) self.assertEqual(m.next_pp_rank(), 7) - self.assertEqual(m.prev_cp_rank(), 1) - self.assertEqual(m.next_cp_rank(), 1) + self.assertEqual(m.prev_cp_rank(), 2) + self.assertEqual(m.next_cp_rank(), 2) m = Mapping(world_size=16, rank=9, tp_size=2, pp_size=2, cp_size=4) - self.assertEqual(m.tp_group, [8, 9]) + self.assertEqual(m.tp_group, [9, 13]) self.assertEqual(m.pp_group, [1, 9]) - self.assertEqual(m.cp_group, [9, 11, 13, 15]) + self.assertEqual(m.cp_group, [8, 9, 10, 11]) self.assertFalse(m.is_first_pp_rank()) self.assertTrue(m.is_last_pp_rank()) - self.assertTrue(m.is_first_cp_rank()) + self.assertFalse(m.is_first_cp_rank()) self.assertFalse(m.is_last_cp_rank()) self.assertEqual(m.prev_pp_rank(), 1) self.assertEqual(m.next_pp_rank(), 1) - self.assertEqual(m.prev_cp_rank(), 15) - self.assertEqual(m.next_cp_rank(), 11) + self.assertEqual(m.prev_cp_rank(), 8) + self.assertEqual(m.next_cp_rank(), 10)