tensorrt_llm/_torch/pyexecutor/model_engine.py (24 additions, 0 deletions)
@@ -750,6 +750,7 @@ def _general_warmup(self,
logger.info(
f"Run warmup with {num_tokens} tokens, include {num_gen_tokens} generation tokens"
)
self._sync_kv_cache_for_warmup(resource_manager)
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
@@ -800,6 +801,7 @@ def _run_autotuner_warmup(self, resource_manager: ResourceManager):
spec_resource_manager, Eagle3ResourceManager):
spec_resource_manager.is_first_draft = True

self._sync_kv_cache_for_warmup(resource_manager)
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
@@ -966,6 +968,7 @@ def _capture_generation_cuda_graphs(self,
self._update_draft_inference_state_for_warmup(
batch, draft_len > 0, resource_manager)
self.runtime_draft_len = draft_len
self._sync_kv_cache_for_warmup(resource_manager)
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
@@ -995,6 +998,7 @@ def _capture_piecewise_cuda_graphs(self, resource_manager: ResourceManager):
logger.info(
f"Run piecewise CUDA graph warmup for num tokens={num_tokens}"
)
self._sync_kv_cache_for_warmup(resource_manager)
# Run a few times to ensure capture
for _ in range(3):
self.forward(batch,
@@ -1031,6 +1035,26 @@ def _capture_piecewise_cuda_graphs(self, resource_manager: ResourceManager):

    ### Helper methods promoted from the original warmup method ###

    def _sync_kv_cache_for_warmup(self, resource_manager: ResourceManager):
        """Synchronize KV cache transfer operations before a warmup forward pass.

        When host cache offloading is enabled (host_cache_size > 0), the C++
        KV cache manager schedules asynchronous H2D/D2H transfers and tracks
        block locations via a transfer manager. In normal inference the
        ``prepare_resources()`` call performs this synchronization, but during
        warmup ``prepare_resources()`` is not invoked. Without explicit
        synchronization, the GPU-side block table may contain stale or invalid
        pointers, causing CUDA illegal-memory-access errors in the attention
        kernels.
        """
        kv_cache_manager = resource_manager.get_resource_manager(
            self.kv_cache_manager_key)
        if (isinstance(kv_cache_manager, KVCacheManager)
                and getattr(kv_cache_manager, 'blocks_in_secondary_pool', 0)
                > 0):
            kv_cache_manager.impl.sync_transfer_manager_with_buffer_manager()
            kv_cache_manager.impl.refresh_blocks()

    @contextlib.contextmanager
    def _release_batch_context(self, batch: Optional[ScheduledRequests],
                               resource_manager: ResourceManager):
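For reviewers who want to exercise the new path: `_sync_kv_cache_for_warmup` only does work when the KV cache manager has a secondary (host) pool, i.e. when host cache offloading is configured. The sketch below shows one way to set up such a configuration through the public LLM API; the model name, cache sizes, and the exact `KvCacheConfig` field names (`host_cache_size`, `free_gpu_memory_fraction`) are illustrative assumptions, not taken from this PR.

```python
# Hypothetical usage sketch (not part of this diff): enable host KV-cache
# offloading so that blocks_in_secondary_pool > 0 and the warmup sync path
# added above is exercised. Field names follow the public LLM API as commonly
# documented and may differ between TensorRT-LLM releases.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

kv_cache_config = KvCacheConfig(
    enable_block_reuse=True,       # block reuse pairs naturally with offloading
    host_cache_size=4 * 1024**3,   # 4 GiB host (secondary) pool -> offloading enabled
    free_gpu_memory_fraction=0.8,  # leave headroom for the primary GPU pool
)

# Warmup (including any CUDA graph capture) runs during engine setup; with a
# non-empty secondary pool, the engine now synchronizes the transfer manager
# before each warmup forward pass instead of relying on prepare_resources().
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
          kv_cache_config=kv_cache_config)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)
```

With the default configuration (no `host_cache_size`), the `getattr(..., 'blocks_in_secondary_pool', 0) > 0` guard keeps warmup behavior unchanged.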