@@ -552,9 +552,10 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
552552 mCpSize = genCp;
553553 }
554554
555- mTpRank = mRankInInstance % mTpSize ;
555+ // Rank formula must match targetIRanks: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
556+ mCpRank = mRankInInstance % mCpSize ;
557+ mTpRank = (mRankInInstance % (mTpSize * mCpSize )) / mCpSize ;
556558 mPpRank = mRankInInstance / (mTpSize * mCpSize );
557- mCpRank = (mRankInInstance % (mTpSize * mCpSize )) / mTpSize ;
558559 mContextRankSize = contextRanks;
559560 mGenRankSize = genRanks;
560561 mContextTpSize = contextTp;
@@ -887,7 +888,16 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
887888 auto makeLlmRequestWithDP (SizeType32 length, LlmRequest::RequestIdType requestId, int contextDpRank)
888889 {
889890 constexpr SizeType32 maxNewTokens{1 };
890- texec::Request request{VecTokens (length), maxNewTokens};
891+ auto const tokensPerBlock = mContextCacheState ->getModelConfig ().mTokensPerBlock ;
892+
893+ std::optional<CPMetaData> cpMetaData;
894+ int seqLen = length;
895+ if (mCpSize > 1 )
896+ {
897+ cpMetaData.emplace (length, tokensPerBlock, mCpRank , mCpSize );
898+ seqLen = cpMetaData.value ().mSeqLenOnThisCPRank ;
899+ }
900+ texec::Request request{VecTokens (seqLen, seqLen), maxNewTokens};
891901
892902 auto state = std::make_unique<texec::DataTransceiverState>();
893903 state->setCommState (texec::kv_cache::CommState{*mContextCommState });
@@ -905,7 +915,6 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
905915 request.setContextPhaseParams (std::move (stats));
906916 auto llmRequestPtr = std::make_unique<LlmRequest>(requestId, std::move (request));
907917
908- std::optional<CPMetaData> cpMetaData;
909918 return std::make_unique<WrappedLlmRequest>(std::move (llmRequestPtr), cpMetaData);
910919 }
911920
@@ -1428,6 +1437,27 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
14281437 {
14291438 GTEST_SKIP () << " Temporarily skipping cache transceiver tests with NIXL and MOONCAKE backend for CP." ;
14301439 }
1440+ // Filter request lengths based on CP requirements.
1441+ // Each request must have at least one block per CP rank to be valid for CP tests.
1442+ std::vector<int > lenList = {60 , 30 , 60 , 10 };
1443+ if (genCp > 1 )
1444+ {
1445+ std::vector<int > updatedLenList;
1446+ for (auto len : lenList)
1447+ {
1448+ if (len > tokensPerBlock * (genCp - 1 ))
1449+ {
1450+ updatedLenList.push_back (len);
1451+ }
1452+ }
1453+ if (updatedLenList.empty ())
1454+ {
1455+ GTEST_SKIP () << " Skipping test because not even one request has one block per genCP rank. tokensPerBlock="
1456+ << tokensPerBlock << " , genCp=" << genCp;
1457+ }
1458+ lenList = updatedLenList;
1459+ }
1460+
14311461 setUpCommunicator (contextTp, contextPp, contextCp, genTp, genPp, genCp, isMLA, contextDP, generationDP);
14321462
14331463 if (mIsContext || mIsGeneration )
@@ -1438,7 +1468,7 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
14381468 setUpCacheTransceiver ();
14391469 std::vector<std::shared_ptr<WrappedLlmRequest>> requests;
14401470 int requestId = 0 ;
1441- for (auto len : { 60 , 30 , 60 , 10 } )
1471+ for (auto len : lenList )
14421472 {
14431473 requests.emplace_back (makeLlmRequestWithDP (len, requestId, requestId % contextTp));
14441474 requestId++;
@@ -1814,6 +1844,44 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1WithCPForMLA, AsymmetricalCacheTest,
18141844 /* generationDP*/ testing::Values(false ),
18151845 /* isWindow*/ testing::Values(false ), testing::Values(false ), testing::Values(0 ), testing::Values(128 )));
18161846
1847+ // Test cases with non-trivial TP and PP on the context side and non-trivial CP and DP on the generation side.
1848+ INSTANTIATE_TEST_CASE_P (AsymmetricCaseTestWithCPAndDPForMLA0, AsymmetricalCacheTestWithDP,
1849+ testing::Combine (/* contextTp*/ testing::Values(1 , 2 ),
1850+ /* contextPp*/ testing::Values(1 , 2 ),
1851+ /* contextCp*/ testing::Values(1 ),
1852+ /* genTp*/ testing::Values(2 ),
1853+ /* genPp*/ testing::Values(1 ),
1854+ /* genCp*/ testing::Values(2 ),
1855+ /* numLayers*/ testing::Values(4 ),
1856+ /* numHeads*/ testing::Values(1 ),
1857+ /* sizePerHead*/ testing::Values(4 ),
1858+ /* tokensPerBlock*/ testing::Values(8 ),
1859+ /* dataType*/ testing::Values(nvinfer1::DataType::kFLOAT , nvinfer1::DataType::kINT8 ),
1860+ /* kvFactor*/ testing::Values(1 ),
1861+ /* isMLA*/ testing::Values(true ),
1862+ /* contextDP*/ testing::Values(false ),
1863+ /* generationDP*/ testing::Values(true ),
1864+ /* isWindow*/ testing::Values(false ), testing::Values(false ), testing::Values(0 ), testing::Values(128 )));
1865+
1866+ // Test cases with non-trivial DP on the context side and non-trivial CP and DP on the generation side.
1867+ INSTANTIATE_TEST_CASE_P (AsymmetricCaseTestWithCPAndDPForMLA1, AsymmetricalCacheTestWithDP,
1868+ testing::Combine (/* contextTp*/ testing::Values(2 , 4 ),
1869+ /* contextPp*/ testing::Values(1 ),
1870+ /* contextCp*/ testing::Values(1 ),
1871+ /* genTp*/ testing::Values(2 ),
1872+ /* genPp*/ testing::Values(1 ),
1873+ /* genCp*/ testing::Values(2 ),
1874+ /* numLayers*/ testing::Values(4 ),
1875+ /* numHeads*/ testing::Values(1 ),
1876+ /* sizePerHead*/ testing::Values(4 ),
1877+ /* tokensPerBlock*/ testing::Values(8 ),
1878+ /* dataType*/ testing::Values(nvinfer1::DataType::kFLOAT , nvinfer1::DataType::kINT8 ),
1879+ /* kvFactor*/ testing::Values(1 ),
1880+ /* isMLA*/ testing::Values(true ),
1881+ /* contextDP*/ testing::Values(true ),
1882+ /* generationDP*/ testing::Values(true ),
1883+ /* isWindow*/ testing::Values(false ), testing::Values(false ), testing::Values(0 ), testing::Values(128 )));
1884+
18171885INSTANTIATE_TEST_CASE_P (AsymmetricCaseTestWithDPForMLA1, AsymmetricalCacheTestWithDP,
18181886 testing::Combine (testing::Values(1 , 2 ), testing::Values(1 , 2 ), testing::Values(1 ), testing::Values(1 , 2 ),
18191887 testing::Values(1 , 2 ), testing::Values(1 ), testing::Values(4 ), testing::Values(1 ), testing::Values(4 ),
@@ -2226,8 +2294,8 @@ TEST(targetTest, CacheStateContextDP)
22262294 auto const verifyContext = [&](int contextRank, int generationRank, std::vector<int > const & expectRanks,
22272295 int expectPPDomain, int expectTPDomain, bool expectNeedSend)
22282296 {
2229- int contextDPRank = contextRank % contextTP;
2230- int generationDPRank = generationRank % genTP;
2297+ int contextDPRank = ( contextRank % ( contextTP * contextCP)) / contextCP ;
2298+ int generationDPRank = ( generationRank % ( genTP * genCP)) / genCP ;
22312299 auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
22322300 : texec::kv_cache::CacheState::AttentionType::kDEFAULT ;
22332301
@@ -2239,12 +2307,12 @@ TEST(targetTest, CacheStateContextDP)
22392307 tokensPerBlock, genTP, genPP, genCP, genAttentionLayerNumPerPP, dataType, attentionType, kvFactor,
22402308 genEnableDP, generationDPRank, genTP};
22412309
2242- auto const contextTragetInfo
2310+ auto const contextTargetInfo
22432311 = tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP (genCache, contextCache, contextRank);
22442312
2245- EXPECT_EQ (expectRanks, contextTragetInfo .mIRanks );
2246- EXPECT_EQ (expectPPDomain, contextTragetInfo .mDomainPPSize );
2247- EXPECT_EQ (expectTPDomain, contextTragetInfo .mDomainTPSize );
2313+ EXPECT_EQ (expectRanks, contextTargetInfo .mIRanks );
2314+ EXPECT_EQ (expectPPDomain, contextTargetInfo .mDomainPPSize );
2315+ EXPECT_EQ (expectTPDomain, contextTargetInfo .mDomainTPSize );
22482316 EXPECT_EQ (expectNeedSend, MLACacheFormatter::needSendCache (contextCache, genCache, contextRank));
22492317 };
22502318
@@ -2330,11 +2398,11 @@ TEST(targetTest, CacheStateContextDP)
23302398 contextTP = 1 ;
23312399 genTP = 2 ;
23322400
2333- auto const verfiyGeneration = [&](int contextRank, int generationRank, std::vector<int > const & expectRanks,
2401+ auto const verifyGeneration = [&](int contextRank, int generationRank, std::vector<int > const & expectRanks,
23342402 int expectPPDomain, int expectTPDomain)
23352403 {
2336- int contextDPRank = contextRank % contextTP;
2337- int generationDPRank = generationRank % genTP;
2404+ int contextDPRank = ( contextRank % ( contextTP * contextCP)) / contextCP ;
2405+ int generationDPRank = ( generationRank % ( genTP * genCP)) / genCP ;
23382406 auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
23392407 : texec::kv_cache::CacheState::AttentionType::kDEFAULT ;
23402408
@@ -2346,17 +2414,17 @@ TEST(targetTest, CacheStateContextDP)
23462414 tokensPerBlock, genTP, genPP, genCP, genAttentionLayerNumPerPP, dataType, attentionType, kvFactor,
23472415 genEnableDP, generationDPRank, genTP};
23482416
2349- auto const contextTragetInfo
2417+ auto const contextTargetInfo
23502418 = tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP (contextCache, genCache, generationRank);
23512419
2352- EXPECT_EQ (expectRanks, contextTragetInfo .mIRanks );
2353- EXPECT_EQ (expectPPDomain, contextTragetInfo .mDomainPPSize );
2354- EXPECT_EQ (expectTPDomain, contextTragetInfo .mDomainTPSize );
2420+ EXPECT_EQ (expectRanks, contextTargetInfo .mIRanks );
2421+ EXPECT_EQ (expectPPDomain, contextTargetInfo .mDomainPPSize );
2422+ EXPECT_EQ (expectTPDomain, contextTargetInfo .mDomainTPSize );
23552423 };
23562424
2357- verfiyGeneration (
2425+ verifyGeneration (
23582426 /* contextRank*/ 0 , /* generationRank*/ 0 , /* expectRanks*/ {0 }, /* expectPPDomain*/ 1 , /* expectTPDomain*/ 1 );
2359- verfiyGeneration (
2427+ verifyGeneration (
23602428 /* contextRank*/ 0 , /* generationRank*/ 1 , /* expectRanks*/ {0 }, /* expectPPDomain*/ 1 , /* expectTPDomain*/ 1 );
23612429
23622430 contextTP = 1 ;
@@ -2366,9 +2434,9 @@ TEST(targetTest, CacheStateContextDP)
23662434 contextAttentionLayerNumPerPP = std::vector<SizeType32>(contextPP, numLayers / contextPP);
23672435 genAttentionLayerNumPerPP = std::vector<SizeType32>(genPP, numLayers / genPP);
23682436
2369- verfiyGeneration (
2437+ verifyGeneration (
23702438 /* contextRank*/ 0 , /* generationRank*/ 0 , /* expectRanks*/ {0 }, /* expectPPDomain*/ 1 , /* expectTPDomain*/ 1 );
2371- verfiyGeneration (
2439+ verifyGeneration (
23722440 /* contextRank*/ 0 , /* generationRank*/ 1 , /* expectRanks*/ {0 }, /* expectPPDomain*/ 1 , /* expectTPDomain*/ 1 );
23732441
23742442 genEnableDP = false ;
@@ -2381,8 +2449,8 @@ TEST(targetTest, CacheStateContextDP)
23812449 contextAttentionLayerNumPerPP = std::vector<SizeType32>(contextPP, numLayers / contextPP);
23822450 genAttentionLayerNumPerPP = std::vector<SizeType32>(genPP, numLayers / genPP);
23832451
2384- verfiyGeneration (
2452+ verifyGeneration (
23852453 /* contextRank*/ 0 , /* generationRank*/ 0 , /* expectRanks*/ {0 }, /* expectPPDomain*/ 1 , /* expectTPDomain*/ 1 );
2386- verfiyGeneration (
2454+ verifyGeneration (
23872455 /* contextRank*/ 1 , /* generationRank*/ 0 , /* expectRanks*/ {1 }, /* expectPPDomain*/ 1 , /* expectTPDomain*/ 1 );
23882456}
0 commit comments