
[TRTLLM-11513][fix] fix iterator undefined behavior in WindowBlockManager::getFreeBlock offload path#12297

Open
thorjohnsen wants to merge 18 commits intoNVIDIA:mainfrom
thorjohnsen:thorjohnsen/fix_block_leak_bugs

@thorjohnsen
Collaborator

@thorjohnsen thorjohnsen commented Mar 17, 2026

Summary

  • In WindowBlockManager::getFreeBlock, the KV-cache offload path called swapMemoryPoolBlockOffset() before calling claimBlock() on either block. After the swap, getCacheLevel() returns the post-swap cache level, but mFreeBlockIterators[id] still points into the pre-swap queue. claimBlock() then calls .erase() on the wrong std::list, which is undefined behaviour per C++17 §26.3.10.4.
  • In practice this silently corrupts mNumFreeBlocksPerLevel and the free-list structure in LRUEvictionPolicy, causing block counts to diverge from reality and eventually producing "No free block found" aborts.
  • Fix: claim both block (primary) and offloadBlock (secondary) before swapMemoryPoolBlockOffset(), so each claimBlock() call sees the cache level that matches the iterator. This is identical to the ordering already used correctly by WindowBlockManager::offloadBlock() (line 1121).
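To make the fixed ordering concrete, here is a minimal, self-contained sketch under stated assumptions. ToyBlock, ToyLruPolicy, swapToy and offloadAndReuse are hypothetical stand-ins, not TensorRT-LLM APIs. The toy policy deliberately keeps the pre-fix behaviour of re-deriving the cache level from the block's primary flag inside claimBlock(), and the data copy to secondary memory is omitted. Because both blocks are claimed before the swap, each erase() targets the queue its iterator actually belongs to.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <list>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy model, not the TensorRT-LLM types.
// Cache level 0 = primary (GPU), 1 = secondary (CPU).
struct ToyBlock
{
    int id;
    bool primary;   // what the toy cacheLevel() reads
    int poolOffset; // stand-in for the memory pool block offset
};

using BlockPtr = std::shared_ptr<ToyBlock>;

class ToyLruPolicy
{
public:
    explicit ToyLruPolicy(std::size_t numBlocks)
        : mIters(numBlocks)
    {
    }

    static int cacheLevel(ToyBlock const& block) { return block.primary ? 0 : 1; }

    // Append a block to the free queue for its current level and remember its iterator.
    void releaseBlock(BlockPtr const& block)
    {
        int const level = cacheLevel(*block);
        mQueues[level].push_back(block);
        mIters[block->id] = std::prev(mQueues[level].end());
        ++mNumFree[level];
    }

    // Pre-fix style: the level is re-derived at claim time, so this is only correct
    // while the block's level still matches the queue it was released into.
    void claimBlock(BlockPtr const& block)
    {
        int const level = cacheLevel(*block);
        mQueues[level].erase(mIters[block->id]);
        --mNumFree[level];
    }

    int numFree(int level) const { return mNumFree[level]; }
    std::size_t queueSize(int level) const { return mQueues[level].size(); }

private:
    std::array<std::list<BlockPtr>, 2> mQueues;
    std::vector<std::list<BlockPtr>::iterator> mIters;
    std::array<int, 2> mNumFree{0, 0};
};

// Stand-in for swapMemoryPoolBlockOffset(): exchange the memory and, with it, the level.
void swapToy(ToyBlock& a, ToyBlock& b)
{
    std::swap(a.poolOffset, b.poolOffset);
    std::swap(a.primary, b.primary);
}

// Fixed ordering: claim both blocks, swap, then return the block that now lives in
// secondary memory to the secondary free queue. Returns the block the caller reuses.
BlockPtr offloadAndReuse(ToyLruPolicy& policy, BlockPtr const& block, BlockPtr const& offloadBlock)
{
    policy.claimBlock(block);        // still primary: iterator and level agree
    policy.claimBlock(offloadBlock); // still secondary: iterator and level agree
    // (copying block's contents into offloadBlock's memory is omitted here)
    swapToy(*block, *offloadBlock);  // block -> secondary, offloadBlock -> primary
    policy.releaseBlock(block);      // offloaded contents stay in the secondary queue
    return offloadBlock;             // primary memory handed back to the caller
}
```

Reversing the first steps (swap, then claim) would make cacheLevel() pick the other queue for both erase() calls, which is the undefined behaviour described above.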

Root cause (detailed)

LRUEvictionPolicy::claimBlock() (evictionPolicy.cpp):

SizeType32 const cacheLevel = getCacheLevel(block);  // reads isPrimary() — reflects post-swap level
mFreeQueues[cacheLevel][getPriorityIdx(...)].erase(*mFreeBlockIterators[id]);  // iterator is from pre-swap level

Both block and offloadBlock are obtained from the eviction policy's free queues and therefore have valid mFreeBlockIterators entries. After the swap, the iterator/level mismatch affects both of them.
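Because the mismatch corrupts state silently, one way to surface it in tests is a debug-only membership check before the erase(). The helpers below are a hypothetical sketch, not part of TensorRT-LLM: iteratorBelongsTo walks the list and compares element addresses rather than comparing iterators, since comparing iterators from different containers is outside what the standard defines. It is O(n) per call, so it only suits assertions in debug or test builds.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <memory>

// Hypothetical debug helper: returns true if `it` refers to an element stored in
// `queue`. Precondition: `it` is dereferenceable (not end()) in some std::list<T>.
template <typename T>
bool iteratorBelongsTo(std::list<T> const& queue, typename std::list<T>::const_iterator it)
{
    T const* target = std::addressof(*it);
    for (T const& element : queue)
    {
        if (std::addressof(element) == target)
        {
            return true;
        }
    }
    return false;
}

// Hypothetical guarded erase: in debug builds, abort instead of erasing through an
// iterator that belongs to a different list.
template <typename T>
void checkedErase(std::list<T>& queue, typename std::list<T>::iterator it)
{
    assert(iteratorBelongsTo(queue, it) && "iterator does not belong to this free queue");
    queue.erase(it);
}
```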

API cleanup

I have also updated the API so that it no longer matters whether claimBlock is called before or after swapMemoryPoolBlockOffset: instead of a bare iterator, the eviction policy now stores the cache level together with each free-block iterator. With the cleaned-up API the fix above is no longer strictly necessary, but I still think it is the correct ordering: blocks should be claimed before swapMemoryPoolBlockOffset changes their state.
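The description does not show the cleaned-up API itself; the later test commit characterises it as a stored cacheLevel tuple. A minimal sketch of that idea, assuming the level is captured once when a block is released and reused when it is claimed (ToyLruPolicyV2 and FreeEntry are hypothetical names, not the TensorRT-LLM classes):

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <list>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical toy block: level 0 = primary, 1 = secondary.
struct ToyBlock
{
    int id;
    bool primary;
};

using BlockPtr = std::shared_ptr<ToyBlock>;

class ToyLruPolicyV2
{
public:
    explicit ToyLruPolicyV2(std::size_t numBlocks)
        : mEntries(numBlocks)
    {
    }

    void releaseBlock(BlockPtr const& block)
    {
        int const level = block->primary ? 0 : 1; // sampled once, at release time
        mQueues[level].push_back(block);
        mEntries[block->id] = FreeEntry{level, std::prev(mQueues[level].end())};
        ++mNumFree[level];
    }

    // Uses the stored level, so the result does not depend on whether the block's
    // primary flag changed after it was released.
    void claimBlock(BlockPtr const& block)
    {
        std::optional<FreeEntry>& entry = mEntries[block->id];
        assert(entry.has_value() && "claimBlock called on a block that is not free");
        mQueues[entry->cacheLevel].erase(entry->it);
        --mNumFree[entry->cacheLevel];
        entry.reset();
    }

    int numFree(int level) const { return mNumFree[level]; }
    std::size_t queueSize(int level) const { return mQueues[level].size(); }

private:
    struct FreeEntry
    {
        int cacheLevel;                   // level of the queue `it` points into
        std::list<BlockPtr>::iterator it;
    };

    std::array<std::list<BlockPtr>, 2> mQueues;
    std::vector<std::optional<FreeEntry>> mEntries;
    std::array<int, 2> mNumFree{0, 0};
};
```

Because claimBlock() no longer consults the primary flag, claiming before or after a swap erases from the same queue; the cost is one extra int per free-list entry.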

Test plan

  • Enable KV-cache offloading (--kv_cache_offload_gpu_to_cpu_frac or equivalent) and run multi-request inference. Previously this path would eventually corrupt the free-list counters, leading to a crash or a spurious "No free block found" error.
  • Run the existing tests/unittest/batch_manager/test_kv_cache_manager.* suite.
  • Verify with ASan/UBSan that no std::list iterator misuse is reported in LRUEvictionPolicy::claimBlock.

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Refactor
    • Improved KV cache memory block management efficiency during offloading and swapping operations, enhancing resource allocation and ownership tracking.

@thorjohnsen thorjohnsen requested a review from a team as a code owner March 17, 2026 21:47
@coderabbitai
Contributor

coderabbitai bot commented Mar 17, 2026

Walkthrough

The getFreeBlock function in the KV cache manager is modified to claim both blocks from the eviction policy before performing the swap operation, then release the secondary block directly into its queue post-swap. This ensures cache level ownership accurately reflects block states throughout the operation.

Changes

Cohort: KV Cache Manager Block Claiming Logic
File(s): cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Summary: Modified getFreeBlock to claim blocks before swapping and release the secondary block directly after swap without intermediate claim/release cycles, ensuring correct eviction policy state during offload and swap operations.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, which is insufficient. The required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Title check: ✅ Passed. The PR title directly describes the main fix: addressing undefined behavior in the WindowBlockManager::getFreeBlock offload path, which is the core change in the kvCacheManager.cpp file.
  • Description check: ✅ Passed. PR description comprehensively covers the issue, root cause, fix, and test plan, aligning well with the template requirements.


@thorjohnsen thorjohnsen changed the title batch_manager: fix iterator UB in WindowBlockManager::getFreeBlock offload path [TRTLLM-11513][fix] fix iterator undefined behavior in WindowBlockManager::getFreeBlock offload path Mar 18, 2026
@nvpohanh
Collaborator

@eopXD could you review this? thanks

@thorjohnsen thorjohnsen force-pushed the thorjohnsen/fix_block_leak_bugs branch from fa19cda to 8d1d0fd March 19, 2026 15:44
…fload path

claimBlock() reads getCacheLevel(block) which is based on isPrimary(). In the
offload path, swapMemoryPoolBlockOffset() was called before claimBlock(), so
getCacheLevel() returned the post-swap level while mFreeBlockIterators[id] still
pointed into the pre-swap queue. Erasing via a std::list::iterator into a
different std::list instance is undefined behaviour (C++17 §26.3.10.4); in
practice it silently corrupts mNumFreeBlocksPerLevel and the free-list structure.

Fix: claim both blocks before the swap, matching the correct ordering already
used by WindowBlockManager::offloadBlock() (line 1121).

Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@thorjohnsen thorjohnsen marked this pull request as draft March 23, 2026 15:03
@thorjohnsen
Collaborator Author

/bot run --disable-fail-fast

@thorjohnsen thorjohnsen marked this pull request as ready for review March 23, 2026 15:10
@tensorrt-cicd
Collaborator

PR_Github #39952 [ run ] triggered by Bot. Commit: 8d1d0fd

@thorjohnsen thorjohnsen marked this pull request as draft March 23, 2026 15:16
Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
@thorjohnsen
Collaborator Author

/bot kill

@tensorrt-cicd
Collaborator

PR_Github #39957 [ kill ] triggered by Bot. Commit: 98abb50

@tensorrt-cicd
Collaborator

PR_Github #39957 [ kill ] completed with state SUCCESS. Commit: 98abb50
Successfully killed previous jobs for commit 98abb50

Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
…BlockOffset

Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
@thorjohnsen
Collaborator Author

/bot run --disable-fail-fast

@thorjohnsen thorjohnsen marked this pull request as ready for review March 23, 2026 21:56
@tensorrt-cicd
Collaborator

PR_Github #39983 [ run ] triggered by Bot. Commit: dcf2408

@tensorrt-cicd
Collaborator

PR_Github #39983 [ run ] completed with state FAILURE. Commit: dcf2408
/LLM/main/L0_MergeRequest_PR pipeline #31143 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
@thorjohnsen
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40026 [ run ] triggered by Bot. Commit: a5f5326

@tensorrt-cicd
Collaborator

PR_Github #40026 [ run ] completed with state SUCCESS. Commit: a5f5326
/LLM/main/L0_MergeRequest_PR pipeline #31183 completed with status: 'FAILURE'

@thorjohnsen
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40134 [ run ] triggered by Bot. Commit: c1425aa

@tensorrt-cicd
Collaborator

PR_Github #40134 [ run ] completed with state FAILURE. Commit: c1425aa

@thorjohnsen
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40137 [ run ] triggered by Bot. Commit: 06813ab

@thorjohnsen thorjohnsen added the KV-Cache Management kv-cache management for efficient LLM inference label Mar 24, 2026
Collaborator

@eopXD eopXD left a comment

Looks good, but some test coverage can help illustrate the change and extra security.

…n LRUEvictionPolicy

Add ClaimAfterSwapDoesNotCorruptQueues to LRUPolicyTest to guard against
reintroduction of the iterator UB fixed in PR NVIDIA#12297. The test releases a
primary and a secondary block into their respective free queues, then calls
swapMemoryPoolBlockOffset() on both (flipping isPrimary()) before calling
claimBlock(). With the old bare-iterator approach this would erase from the
wrong std::list; with the fix (stored cacheLevel tuple) the counts and queue
integrity remain correct.

Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
@thorjohnsen thorjohnsen requested a review from eopXD March 24, 2026 19:09
@tensorrt-cicd
Collaborator

PR_Github #40137 [ run ] completed with state SUCCESS. Commit: 06813ab
/LLM/main/L0_MergeRequest_PR pipeline #31283 completed with status: 'FAILURE'

@eopXD
Collaborator

eopXD commented Mar 25, 2026

/bot run --disable-fail-fast

@thorjohnsen thorjohnsen enabled auto-merge (squash) March 25, 2026 20:47
@thorjohnsen
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #40372 [ run ] triggered by Bot. Commit: e32f060

@tensorrt-cicd
Collaborator

PR_Github #40372 [ run ] completed with state FAILURE. Commit: e32f060
/LLM/main/L0_MergeRequest_PR pipeline #31471 completed with status: 'FAILURE'

@thorjohnsen
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #40551 [ run ] triggered by Bot. Commit: e32f060

@tensorrt-cicd
Collaborator

PR_Github #40551 [ run ] completed with state DISABLED
CI server is currently disabled for scheduled maintenance. Estimated completion time: 9 PM PST on 3/28.

@thorjohnsen
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #40732 [ run ] triggered by Bot. Commit: e32f060


Labels

KV-Cache Management kv-cache management for efficient LLM inference

4 participants