@@ -207,47 +207,140 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
207207 }
208208}
209209
210- void KVCacheTransferManager::onboard (BlockPtr const & offloadBlock, BlockPtr const & block,
210+ //
211+ // Note about recording events to wait for cudaMemcpyAsync calls between blocks:
212+ // The memory copy involves raw memory blocks, which are pointed to by the
213+ // memory pool block index. When recording events, you must use getMemoryPoolBlockIndex()
214+ // as the raw memory block identifier. Using getBlockId() when recording events is wrong.
215+ // getBlockId() returns the logical block id, which has nothing to do with the raw memory
216+ // block pointers involved in a cudaMemcpy.
217+ //
218+
219+ //
220+ // Notes about need for synchronization:
221+ //
222+ // Relying on decoder syncing GPU with CPU to ensure that blocks are ready
223+ // for offload/onboard/partial copy is dangerous. We have an asynchronous decoder
224+ // that may not synchronize or synchronize at a later point in the execution stream.
225+ // To avoid synchronization issues caused by changes to decoder design we rely on
226+ // KVCacheTransferManager::syncWithBufferManager() that ensures that internal copy streams
227+ // will wait for prefill and decode kernels that have already been scheduled.
228+ //
229+ // Earlier versions of this code did not account for all possible cases where a new block copy
230+ // needed to wait for a previously scheduled copy to finish. For instance, it is possible
231+ // that two primary blocks are offloaded to the same secondary block in a single step;
232+ // scheduling the second offloading without waiting for the first one to finish leads to
233+ // a corrupted block after offloading. It is possible that partial reuse will copy
234+ // from a block that is currently being onboarded; scheduling the partial copy without
235+ // waiting for the onboarding to finish will lead to a corrupted block. To handle all
236+ // possible cases needing synchronization we record separate events for reads and writes
237+ // to a block. When a new block copy is scheduled, we wait for all writes to the source
238+ // block and all reads and writes to a destination block.
239+ //
240+ // As before, syncTransfers() must be called after the last call to KVCacheManager::addSequence.
241+ // Failing to do so will lead to corrupted blocks eventually.
242+ //
243+
244+ void KVCacheTransferManager::onboard (BlockPtr const & offloadedBlock, BlockPtr const & block,
211245 std::vector<KVCacheBlockPool> const & pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
212246 std::string const & directory)
213247{
214- if (mode != executor::KvCacheTransferMode::DRAM
215- && mPendingOffloads .find (offloadBlock->getBlockId ()) == mPendingOffloads .end ())
248+ // Wait for any pending writes before reading from offloadedBlock
249+ auto offloadedBlockPendingWriteItr = mPendingWrites .find (offloadedBlock->getMemoryPoolBlockIndex ());
250+ if (offloadedBlockPendingWriteItr != mPendingWrites .end ())
216251 {
217- TLLM_LOG_DEBUG (" Skipping onboard for block %d because it was never previously offloaded to disk" ,
218- offloadBlock->getBlockId ());
219- return ;
252+ mOnboardManager .getStream ().wait (offloadedBlockPendingWriteItr->second );
253+ // Don't erase, we are not changing the state of offloadedBlock
220254 }
221-
222- if (mPendingOffloads .find (offloadBlock->getBlockId ()) != mPendingOffloads .end ())
255+ // Wait for any pending reads before overwriting block
256+ auto blockPendingReadItr = mPendingReads .find (block->getMemoryPoolBlockIndex ());
257+ if (blockPendingReadItr != mPendingReads .end ())
258+ {
259+ mOnboardManager .getStream ().wait (blockPendingReadItr->second );
260+ mPendingReads .erase (blockPendingReadItr);
261+ }
262+ // Wait for any pending writes before overwriting block
263+ auto blockPendingWriteItr = mPendingWrites .find (block->getMemoryPoolBlockIndex ());
264+ if (blockPendingWriteItr != mPendingWrites .end ())
223265 {
224- mOnboardManager .getStream ().wait (mPendingOffloads [offloadBlock->getBlockId ()]);
266+ mOnboardManager .getStream ().wait (blockPendingWriteItr->second );
267+ mPendingWrites .erase (blockPendingWriteItr);
225268 }
226- copyBlock (offloadBlock, block, pools, false , numTokensToCopy, mode, directory);
269+
270+ copyBlock (offloadedBlock, block, pools, false , numTokensToCopy, mode, directory);
271+
272+ // Record new pending read from offloadedBlock
273+ mPendingReads [offloadedBlock->getMemoryPoolBlockIndex ()] = tr::CudaEvent ();
274+ mOnboardManager .getStream ().record (mPendingReads [offloadedBlock->getMemoryPoolBlockIndex ()]);
275+ // Record new pending write to block
276+ mPendingWrites [block->getMemoryPoolBlockIndex ()] = tr::CudaEvent ();
277+ mOnboardManager .getStream ().record (mPendingWrites [block->getMemoryPoolBlockIndex ()]);
227278}
228279
229280void KVCacheTransferManager::offload (BlockPtr const & block, BlockPtr const & offloadBlock,
230281 std::vector<KVCacheBlockPool> const & pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
231282 std::string const & directory)
232283{
233- mPendingOffloads [block->getBlockId ()] = tr::CudaEvent ();
284+ // Wait for any pending writes before reading from block
285+ auto blockPendingWriteItr = mPendingWrites .find (block->getMemoryPoolBlockIndex ());
286+ if (blockPendingWriteItr != mPendingWrites .end ())
287+ {
288+ mOffloadManager .getStream ().wait (blockPendingWriteItr->second );
289+ // Don't erase, we are not changing the state of block
290+ }
291+ // Wait for any pending reads before overwriting offloadBlock
292+ auto offloadBlockPendingReadItr = mPendingReads .find (offloadBlock->getMemoryPoolBlockIndex ());
293+ if (offloadBlockPendingReadItr != mPendingReads .end ())
294+ {
295+ mOffloadManager .getStream ().wait (offloadBlockPendingReadItr->second );
296+ mPendingReads .erase (offloadBlockPendingReadItr);
297+ }
298+ // Wait for any pending writes before overwriting offloadBlock
299+ auto offloadBlockPendingWriteItr = mPendingWrites .find (offloadBlock->getMemoryPoolBlockIndex ());
300+ if (offloadBlockPendingWriteItr != mPendingWrites .end ())
301+ {
302+ mOffloadManager .getStream ().wait (offloadBlockPendingWriteItr->second );
303+ mPendingWrites .erase (offloadBlockPendingWriteItr);
304+ }
305+
234306 copyBlock (block, offloadBlock, pools, true , numTokensToCopy, mode, directory);
235- mOffloadManager .getStream ().record (mPendingOffloads [block->getBlockId ()]);
307+
308+ // Record new pending read from block
309+ mPendingReads [block->getMemoryPoolBlockIndex ()] = tr::CudaEvent ();
310+ mOffloadManager .getStream ().record (mPendingReads [block->getMemoryPoolBlockIndex ()]);
311+ // Record new pending write to offloadBlock
312+ mPendingWrites [offloadBlock->getMemoryPoolBlockIndex ()] = tr::CudaEvent ();
313+ mOffloadManager .getStream ().record (mPendingWrites [offloadBlock->getMemoryPoolBlockIndex ()]);
314+ }
315+
316+ void KVCacheTransferManager::syncWithBufferManager ()
317+ {
318+ tr::CudaEvent readyForOffloadEvent;
319+ mBufferManager .getStream ().record (readyForOffloadEvent);
320+ mOffloadManager .getStream ().wait (readyForOffloadEvent);
321+
322+ tr::CudaEvent readyForOnboardEvent;
323+ mBufferManager .getStream ().record (readyForOnboardEvent);
324+ mOnboardManager .getStream ().wait (readyForOnboardEvent);
325+
326+ // Once we synchronize, clear our list of pending transfers.
327+ mPendingReads .clear ();
328+ mPendingWrites .clear ();
236329}
237330
238331void KVCacheTransferManager::syncTransfers ()
239332{
240333 tr::CudaEvent offloadEvent;
241334 mOffloadManager .getStream ().record (offloadEvent);
335+ mBufferManager .getStream ().wait (offloadEvent);
242336
243337 tr::CudaEvent onboardEvent;
244338 mOnboardManager .getStream ().record (onboardEvent);
245-
246- mBufferManager .getStream ().wait (offloadEvent);
247339 mBufferManager .getStream ().wait (onboardEvent);
248340
249341 // Once we synchronize, clear our list of pending transfers.
250- mPendingOffloads .clear ();
342+ mPendingReads .clear ();
343+ mPendingWrites .clear ();
251344}
252345
253346} // namespace tensorrt_llm::batch_manager::kv_cache_manager
0 commit comments