|
37 | 37 | namespace tensorrt_llm::batch_manager::kv_cache_manager |
38 | 38 | { |
39 | 39 |
|
40 | | -int getBlockNumAccountingForCP(int cpRank, int cpSize, int numTotalBlocks, bool strict) |
41 | | -{ |
42 | | - TLLM_CHECK(cpRank >= 0 && cpRank < cpSize); |
43 | | - if (cpSize == 1) |
44 | | - { |
45 | | - return numTotalBlocks; |
46 | | - } |
47 | | - // NOTE: Non-strict mode may over-allocate blocks when numTotalBlocks is not divisible by cpSize. |
48 | | - // This is a known limitation and will be addressed in a future MR. |
49 | | - if (!strict) |
50 | | - { |
51 | | - // Simple ceiling division. |
52 | | - return (numTotalBlocks + cpSize - 1) / cpSize; |
53 | | - } |
54 | | - // In strict mode, blocks are distributed among CP ranks in a round-robin fashion as evenly as possible. |
55 | | - // When the number of blocks is not divisible by cpSize, the remainder shall be distributed evenly among |
56 | | - // lowest-indexed CP ranks (let's call them overflow ranks). |
57 | | - int numBlocksCurrRank = numTotalBlocks / cpSize; |
58 | | - if (numTotalBlocks % cpSize > cpRank) |
59 | | - { |
60 | | - numBlocksCurrRank++; |
61 | | - } |
62 | | - return numBlocksCurrRank; |
63 | | -} |
64 | | - |
65 | 40 | // some context rank in connection |
66 | 41 | std::vector<size_t> MLACacheFormatter::pickRecvConnections( |
67 | 42 | size_t numConnections, CacheState const& selfConfig, SizeType32 selfIdx, CacheState const& destConfig) const |
@@ -145,7 +120,8 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses |
145 | 120 | int blockNum = 0; |
146 | 121 | std::vector<runtime::ITensor::SharedPtr> inputKvCacheBlocks; |
147 | 122 | auto const numPools = mCacheManager->getBlockManager().getNumPools(); |
148 | | - auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, lastBlockKey, indexFromEnd); |
| 123 | + bool const recvSideHasCP = destConfig.getParallelConfig().mContextParallelism > 1; |
| 124 | + auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, lastBlockKey, indexFromEnd, recvSideHasCP); |
149 | 125 | auto const& windowSizes = blockRange.getWindowSizes(); |
150 | 126 | TLLM_CHECK_WITH_INFO( |
151 | 127 | static_cast<int>(windowSizes.size()) == numPools, "window sizes should be the same as numPools"); |
@@ -204,7 +180,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses |
204 | 180 | auto const idx = cpDomainIdx * pPDomainSize + ppDomainIdx; |
205 | 181 | // Note: contextCP is always 1. So, cpDomainSize == genCPSize and cpDomainIdx == genCPRank. |
206 | 182 | auto const peerBlockNum |
207 | | - = getBlockNumAccountingForCP(cpDomainIdx, cPDomainSize, blockNum, /*strict=*/false); |
| 183 | + = executor::kv_cache::getBlockNumAccountingForCP(cpDomainIdx, cPDomainSize, blockNum); |
208 | 184 | bufferSizeForTarget[idx] = blockSizePerLayer * peerAttentionLayerNum * peerBlockNum; |
209 | 185 | } |
210 | 186 | } |
@@ -346,7 +322,9 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s |
346 | 322 | auto const& connections = session.getConnections(); |
347 | 323 | auto& bufferManager = session.getBufferManager(); |
348 | 324 | auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig); |
349 | | - auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse()); |
| 325 | + bool const recvSideHasCP = selfConfig.getParallelConfig().mContextParallelism > 1; |
| 326 | + auto blockRange |
| 327 | + = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse(), recvSideHasCP); |
350 | 328 | std::vector<runtime::ITensor::SharedPtr> recvBufferTmps; |
351 | 329 | std::vector<runtime::ITensor::SharedPtr> outputBuffers; |
352 | 330 | auto const numPools = mCacheManager->getBlockManager().getNumPools(); |
|
0 commit comments