
Commit 84d107b

[https://nvbugs/5717993][fix] Add execution_stream across PyExecutor, KVCacheManager, PeftCacheManager to ensure proper CUDA stream synchronization between KV cache transfer operations and model forward kernels. (#10060)
Signed-off-by: SimengLiu-nv <[email protected]>
1 parent 0d2e271 commit 84d107b
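
The synchronization introduced here is the standard two-way stream handoff in PyTorch: the execution stream first waits on the default stream, the heavy kernels are launched on the execution stream, and the default stream then waits on the execution stream before consuming the results. A minimal, self-contained sketch of that pattern (illustrative only, not code from this commit):

import torch

# Dedicated stream for forward-like work, analogous to the commit's execution_stream.
execution_stream = torch.cuda.Stream()

x = torch.randn(1024, 1024, device="cuda")

# 1) The execution stream waits for work already queued on the default stream.
execution_stream.wait_stream(torch.cuda.current_stream())

# 2) Kernels are launched on the execution stream.
with torch.cuda.stream(execution_stream):
    y = x @ x

# 3) The default stream waits before anything issued on it reads y.
torch.cuda.current_stream().wait_stream(execution_stream)
print(y.sum().item())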

File tree: 12 files changed, +321 -36 lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 22 additions & 6 deletions
@@ -77,6 +77,7 @@ def __init__(
         speculative_config: SpeculativeConfig,
         sparse_attention_config: SparseAttentionConfig,
         profiling_stage_data: Optional[dict],
+        execution_stream: Optional[torch.cuda.Stream] = None,
     ):
         self._model_engine = model_engine
         self._draft_model_engine = draft_model_engine
@@ -97,6 +98,7 @@ def __init__(
         self._profiling_stage_data = profiling_stage_data
         self._kv_cache_manager_cls = get_kv_cache_manager_cls(
             model_engine.model.model_config)
+        self._execution_stream = execution_stream

     def _get_kv_size_per_token(self):
         model_config = self._model_engine.model.model_config
@@ -474,6 +476,7 @@ def _create_kv_cache_manager(
             max_beam_width=self._max_beam_width,
             kv_connector_manager=self._kv_connector_manager,
             estimating_kv_cache=estimating_kv_cache,
+            execution_stream=self._execution_stream,
         )

         # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to self
@@ -527,14 +530,20 @@ def teardown_managers(self, resources: Dict) -> None:


 def _create_kv_cache_manager(
-        model_engine: PyTorchModelEngine, kv_cache_manager_cls,
-        mapping: Mapping, kv_cache_config: KvCacheConfig, tokens_per_block: int,
-        max_seq_len: int, max_batch_size: int,
+        model_engine: PyTorchModelEngine,
+        kv_cache_manager_cls,
+        mapping: Mapping,
+        kv_cache_config: KvCacheConfig,
+        tokens_per_block: int,
+        max_seq_len: int,
+        max_batch_size: int,
         spec_config: Optional[SpeculativeConfig],
         sparse_attn_config: Optional[SparseAttentionConfig],
-        max_num_tokens: int, max_beam_width: int,
+        max_num_tokens: int,
+        max_beam_width: int,
         kv_connector_manager: Optional[KvCacheConnectorManager],
-        estimating_kv_cache: bool) -> KVCacheManager:
+        estimating_kv_cache: bool,
+        execution_stream: Optional[torch.cuda.Stream] = None) -> KVCacheManager:
     """
     Returns:
         A KVCacheManager instance for the given model_engine
@@ -580,6 +589,7 @@ def _create_kv_cache_manager(
             if not estimating_kv_cache else None,
             sparse_attn_config=sparse_attn_config,
             is_estimating_kv_cache=estimating_kv_cache,
+            execution_stream=execution_stream,
         )
     elif is_nemotron_hybrid(config):
         if max_beam_width > 1:
@@ -623,6 +633,7 @@ def _create_kv_cache_manager(
             dtype=kv_cache_dtype,
             spec_config=spec_config,
             is_estimating_kv_cache=estimating_kv_cache,
+            execution_stream=execution_stream,
         )
     elif is_qwen3_next(config):
         if max_beam_width > 1:
@@ -672,6 +683,7 @@ def _create_kv_cache_manager(
             dtype=kv_cache_dtype,
             spec_config=spec_config,
             is_estimating_kv_cache=estimating_kv_cache,
+            execution_stream=execution_stream,
         )
     else:
         # NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCahceManager
@@ -700,6 +712,7 @@ def _create_kv_cache_manager(
             if not estimating_kv_cache else None,
             sparse_attn_config=sparse_attn_config,
             is_estimating_kv_cache=estimating_kv_cache,
+            execution_stream=execution_stream,
         )
     return kv_cache_manager

@@ -727,6 +740,7 @@ def create_py_executor_instance(
         scheduler_config: Optional[SchedulerConfig] = None,
         cache_transceiver_config: Optional[CacheTransceiverConfig] = None,
         virtual_memory_pools: Optional[dict] = None,
+        execution_stream: Optional[torch.cuda.Stream] = None,
 ) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)

@@ -813,6 +827,7 @@ def create_py_executor_instance(
             lora_config=lora_config,
             model_config=model_binding_config,
             world_config=world_config,
+            execution_stream=execution_stream,
         )
         resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager
         model_engine.set_lora_model_config(
@@ -875,7 +890,8 @@ def create_py_executor_instance(
         kv_connector_manager=kv_connector_manager,
         max_seq_len=max_seq_len,
         peft_cache_config=peft_cache_config,
-        virtual_memory_pools=virtual_memory_pools)
+        virtual_memory_pools=virtual_memory_pools,
+        execution_stream=execution_stream)


 def create_torch_sampler_args(
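
Functionally, the hunks above are plumbing: the new execution_stream keyword is threaded from KvCacheCreator and create_py_executor_instance into each KV cache manager constructor, defaulting to None so existing callers keep the old behavior. A hedged sketch of that optional-stream threading (the function name make_manager is illustrative, not one of the real constructors):

import torch
from typing import Optional

def make_manager(execution_stream: Optional[torch.cuda.Stream] = None) -> torch.cuda.Stream:
    # Fall back to a private stream when the caller does not supply one,
    # which preserves the pre-change behavior.
    return execution_stream if execution_stream is not None else torch.cuda.Stream()

shared = torch.cuda.Stream()
assert make_manager(shared) is shared      # new path: caller-provided shared stream
assert make_manager() is not shared        # old path: private per-manager stream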

tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py

Lines changed: 2 additions & 0 deletions
@@ -197,6 +197,7 @@ def __init__(
         dtype: DataType = DataType.HALF,
         spec_config: Optional["DecodingBaseConfig"] = None,
         is_estimating_kv_cache: bool = False,
+        execution_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:

         # mamba hybrid cache requires block reuse to be disabled in KV cache config
@@ -234,6 +235,7 @@ def __init__(
             spec_config=spec_config,
             layer_mask=layer_mask,
             is_estimating_kv_cache=is_estimating_kv_cache,
+            execution_stream=execution_stream,
         )

     def prepare_resources(self, scheduled_batch: ScheduledRequests):

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 37 additions & 8 deletions
@@ -136,11 +136,22 @@ def __init__(self,
                  kv_connector_manager: Optional[KvCacheConnectorManager] = None,
                  max_seq_len: Optional[int] = None,
                  peft_cache_config: Optional[PeftCacheConfig] = None,
-                 virtual_memory_pools: Optional[dict] = None):
+                 virtual_memory_pools: Optional[dict] = None,
+                 execution_stream: Optional[torch.cuda.Stream] = None):
         super(PyExecutor, self).__init__()
         self.device_id = torch.cuda.current_device()
         self.global_rank = dist.rank

+        # Store the execution stream for model forward operations.
+        # This stream is used for proper synchronization with KVCacheTransferManager.
+        # execution_stream can be provided by create_py_executor
+        # Create a new stream if none provided
+        self.execution_stream = execution_stream if execution_stream is not None else torch.cuda.Stream(
+        )
+        logger.info(
+            f"[PyExecutor] execution_stream initialized: {self.execution_stream}. "
+        )
+
         self.peft_cache_config = peft_cache_config

         self.iter_counter = 0
@@ -245,10 +256,19 @@ def __init__(self,
         self.inflight_req_ids = ReqIdsSet()

         # During warmup, we don't enable the profiler
+        # Run warmup on the execution_stream for proper synchronization with
+        # KVCacheTransferManager's onboard/offload operations.
         self.is_warmup = True
-        self.model_engine.warmup(self.resource_manager)
-        if self.draft_model_engine is not None:
-            self.draft_model_engine.warmup(self.resource_manager)
+
+        self.execution_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(self.execution_stream):
+            self.model_engine.warmup(self.resource_manager)
+            if self.draft_model_engine is not None:
+                self.draft_model_engine.warmup(self.resource_manager)
+
+        # Ensure the default stream waits for execution_stream to complete
+        # before subsequent operations.
+        torch.cuda.current_stream().wait_stream(self.execution_stream)
         self.is_warmup = False

         self.is_shutdown = False
@@ -2231,10 +2251,19 @@ def forward(scheduled_requests, resource_manager, new_tensors_device,
             a.py_return_context_logits
             for a in scheduled_requests.context_requests)
         cache_indirection_buffer = self.sampler.get_cache_indirection()
-        outputs = forward(scheduled_requests, self.resource_manager,
-                          new_tensors_device, gather_context_logits,
-                          cache_indirection_buffer,
-                          num_accepted_tokens_device)
+
+        # Run model forward on the execution stream for proper synchronization
+        # with KVCacheTransferManager's onboard/offload operations.
+        self.execution_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(self.execution_stream):
+            outputs = forward(scheduled_requests, self.resource_manager,
+                              new_tensors_device, gather_context_logits,
+                              cache_indirection_buffer,
+                              num_accepted_tokens_device)
+
+        # Ensure the default stream waits for execution_stream to complete
+        # before downstream operations use the outputs.
+        torch.cuda.current_stream().wait_stream(self.execution_stream)

         self._kv_connector_wait_for_save()
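
The comments in the hunks above explain why warmup and forward move onto execution_stream: KVCacheTransferManager issues its onboard/offload copies on the KV cache manager's stream, and when that stream is the same one the forward kernels run on, CUDA's per-stream ordering serializes the copies and the kernels without extra events. A small illustrative sketch of that ordering guarantee (not code from this commit):

import torch

execution_stream = torch.cuda.Stream()

host_block = torch.randn(4, 1024, pin_memory=True)     # stand-in for an offloaded KV block
device_block = torch.empty(4, 1024, device="cuda")

with torch.cuda.stream(execution_stream):
    # "Onboard" copy, standing in for a KV cache block transfer.
    device_block.copy_(host_block, non_blocking=True)
    # "Forward" kernel on the same stream: it is guaranteed to observe the
    # completed copy because both operations are ordered on execution_stream.
    out = device_block * 2.0

# Hand the result back to the default stream before reading it there.
torch.cuda.current_stream().wait_stream(execution_stream)
print(out.mean().item())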

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 10 additions & 0 deletions
@@ -601,6 +601,13 @@ def drafting_loop_wrapper(model):
     resources = {}
     estimating_kv_cache = False
     kv_cache_creator = None
+
+    # Create the execution stream for model forward operations
+    # for proper synchronization with KVCacheTransferManager's onboard/offload operations.
+    execution_stream = torch.cuda.Stream()
+    logger.info(
+        f"[create_py_executor] Created execution_stream: {execution_stream}")
+
     if model_engine.model.model_config.is_generation:
         #NOTE: non-generation models do not have kv cache
         kv_cache_creator = KvCacheCreator(
@@ -619,6 +626,7 @@ def drafting_loop_wrapper(model):
             speculative_config=spec_config,
             profiling_stage_data=profiling_stage_data,
             sparse_attention_config=sparse_attention_config,
+            execution_stream=execution_stream,
         )
         estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
         with allocation_scope(
@@ -676,6 +684,7 @@ def drafting_loop_wrapper(model):
             scheduler_config=scheduler_config,
             cache_transceiver_config=cache_transceiver_config,
             virtual_memory_pools=vm_pools if not estimating_kv_cache else None,
+            execution_stream=execution_stream,
         )
         # Originally, peft_cache_config might be mutated inside
         # create_py_executor_instance. Restore it here.
@@ -736,6 +745,7 @@ def drafting_loop_wrapper(model):
             scheduler_config=scheduler_config,
             cache_transceiver_config=cache_transceiver_config,
             virtual_memory_pools=vm_pools,
+            execution_stream=execution_stream,
         )

     _adjust_torch_mem_fraction()

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 17 additions & 7 deletions
@@ -176,6 +176,7 @@ def __init__(
         indexer_k_cache_quant_block_size: int = 128,
         indexer_k_cache_index_head_dim: int = 0,
         is_estimating_kv_cache: bool = False,
+        execution_stream: Optional[torch.cuda.Stream] = None,
         **kwargs,
     ) -> None:
         self.mapping = mapping
@@ -351,9 +352,13 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
         # Set up temp_attention_window_inputs
         temp_attention_window_inputs = self._set_temp_attention_window_inputs()

-        # Note that this stream is unused for now. Will be used for copying to host
-        # when that feature is enabled.
-        self._stream = torch.cuda.Stream()
+        # Use the provided execution stream for proper synchronization with KVCacheTransferManager.
+        # The execution stream is the stream where model forward kernels run, and KVCacheTransferManager
+        # needs to synchronize with it for onboard/offload operations.
+        # If no execution stream is provided, create a new one (for backward compatibility).
+        self._stream = execution_stream if execution_stream is not None else torch.cuda.Stream(
+        )
+        logger.info(f"[KVCacheManager] execution_stream: {self._stream}")
         kwargs = {
             'num_kv_heads_per_layer': self.num_kv_heads_per_layer,
             'size_per_head': head_dim,
@@ -365,7 +370,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
             'temp_attention_window_inputs': temp_attention_window_inputs,
             'dtype': dtype,
             'sink_token_length': sink_token_length,
-            'stream': self._stream.cuda_stream,
+            'stream': self._stream.cuda_stream,  # Pass to BufferManager
             'max_sequence_length': max_seq_len,
             'enable_block_reuse': kv_cache_config.enable_block_reuse,
             'onboard_blocks': kv_cache_config.onboard_blocks,
@@ -1442,7 +1447,8 @@ def __init__(self,
                  peft_cache_config: PeftCacheConfig,
                  lora_config: LoraConfig,
                  model_config: ModelConfigCpp,
-                 world_config: WorldConfig | None = None):
+                 world_config: WorldConfig | None = None,
+                 execution_stream: Optional[torch.cuda.Stream] = None):
         import tensorrt_llm.bindings as _tb

         peft_cache_config = peft_cache_config._to_pybind()
@@ -1467,8 +1473,12 @@ def __init__(self,
             world_config = _tb.WorldConfig()

         BufferManager = tensorrt_llm.bindings.internal.runtime.BufferManager
-        buffer_manager = BufferManager(torch.cuda.current_stream().cuda_stream,
-                                       True)
+        buffer_manager_stream = execution_stream.cuda_stream if execution_stream is not None else torch.cuda.current_stream(
+        ).cuda_stream
+        buffer_manager = BufferManager(buffer_manager_stream, True)
+        logger.info(
+            f"[PeftCacheManager] buffer_manager_stream: {buffer_manager_stream}"
+        )
         self.impl = PeftCacheManagerCpp(config=peft_cache_manager_config,
                                         model_config=model_config,
                                         world_config=world_config,
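
Both managers end up driving the C++ side with the same raw stream handle: torch.cuda.Stream exposes the underlying cudaStream_t as an integer via .cuda_stream, which is what the KVCacheManager kwargs and the PeftCacheManager's BufferManager receive above. A minimal sketch of that handle sharing (illustrative only, not the real bindings):

import torch

execution_stream = torch.cuda.Stream()

# The raw cudaStream_t handle, as an int, is what gets passed to the bindings.
handle = execution_stream.cuda_stream

# Falling back to the default stream's handle mirrors the execution_stream-is-None path.
fallback_handle = torch.cuda.current_stream().cuda_stream

print(hex(handle), hex(fallback_handle))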

tensorrt_llm/evaluate/lm_eval.py

Lines changed: 29 additions & 9 deletions
@@ -380,23 +380,43 @@ def _adjust_config(task_dict, random_seed):

     @contextmanager
     def _patch_lm_eval(self):
-        if self.dataset_path is None:
-            yield
-            return
+        from pathlib import Path

         import lm_eval
-        self._task_config_post_init = lm_eval.api.task.TaskConfig.__post_init__
+        import lm_eval.tasks
+
+        # Patch Path.relative_to to handle custom task paths outside lm_eval/tasks
+        # This is needed with lm_eval>=0.4.9.2 with new function pretty_print_task (a local function inside
+        # get_task_dict) calls yaml_path.relative_to(lm_eval_tasks_path) which fails
+        # when the yaml is from tensorrt_llm/evaluate/lm_eval_tasks
+        original_relative_to = Path.relative_to
+
+        def _patched_relative_to(self, other, *args, **kwargs):
+            try:
+                return original_relative_to(self, other, *args, **kwargs)
+            except ValueError:
+                # Return absolute path if relative_to fails (path not under base)
+                return self
+
+        Path.relative_to = _patched_relative_to
+
+        # Optionally patch dataset_path if provided
+        original_post_init = None
+        if self.dataset_path is not None:
+            original_post_init = lm_eval.api.task.TaskConfig.__post_init__

-        def _patched(task_config, *args, **kwargs):
-            task_config.dataset_path = self.dataset_path
-            self._task_config_post_init(task_config, *args, **kwargs)
+            def _patched_post_init(task_config, *args, **kwargs):
+                task_config.dataset_path = self.dataset_path
+                original_post_init(task_config, *args, **kwargs)

-        lm_eval.api.task.TaskConfig.__post_init__ = _patched
+            lm_eval.api.task.TaskConfig.__post_init__ = _patched_post_init

         try:
             yield
         finally:
-            lm_eval.api.task.TaskConfig.__post_init__ = self._task_config_post_init
+            Path.relative_to = original_relative_to
+            if original_post_init is not None:
+                lm_eval.api.task.TaskConfig.__post_init__ = original_post_init

     def generate_samples(self) -> Iterable[tuple]:
         raise NotImplementedError()
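
The Path.relative_to patch above can be read in isolation as a scoped monkeypatch that degrades gracefully when a path is not under the expected base. A standalone sketch of the same idea (the helper name lenient_relative_to is hypothetical, not part of this commit):

from contextlib import contextmanager
from pathlib import Path

@contextmanager
def lenient_relative_to():
    # Temporarily make Path.relative_to return the path unchanged instead of
    # raising ValueError when the path is not located under the given base.
    original = Path.relative_to

    def patched(self, other, *args, **kwargs):
        try:
            return original(self, other, *args, **kwargs)
        except ValueError:
            return self

    Path.relative_to = patched
    try:
        yield
    finally:
        Path.relative_to = original

with lenient_relative_to():
    # A custom task YAML outside lm_eval/tasks no longer raises during pretty-printing.
    print(Path("/opt/custom/task.yaml").relative_to("/usr/lib/lm_eval/tasks"))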

tests/integration/test_lists/test-db/l0_a100.yml

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,8 @@ l0_a100:
   tests:
   - unittest/llmapi/test_llm_pytorch.py
   - unittest/llmapi/test_mpi_session.py ISOLATION
-  - unittest/llmapi/test_memory_profiling.py # profile kvcache for vision encoder
+  - unittest/llmapi/test_memory_profiling.py::test_profile_kvcache # profile kvcache for vision encoder
+  - unittest/llmapi/test_memory_profiling.py::test_pyexecutor_and_kvcache_share_execution_stream # test that PyExecutor and KVCacheManager share the same execution_stream
   - unittest/trt/model_api/test_model_quantization.py
   # executor
   - unittest/executor/test_base_worker.py

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 4 additions & 3 deletions
@@ -76,9 +76,10 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_dummy_load_format
-  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0] TIMEOUT (90)
-  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5] TIMEOUT (90)
-  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] TIMEOUT (90)
+  # Waive known failures in https://nvbugs/5774869
+  # - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0] TIMEOUT (90)
+  # - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5] TIMEOUT (90)
+  # - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] TIMEOUT (90)
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=True-eagle3_one_model=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=True]

tests/integration/test_lists/waives.txt

Lines changed: 1 addition & 1 deletion
@@ -298,7 +298,7 @@ full:L40S/accuracy/test_cli_flow.py::TestGpt2::test_weight_streaming_plugin SKIP
 full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2] SKIP (https://nvbugs/5596337)
 accuracy/test_llm_api.py::TestMixtral8x7BInstruct::test_awq_tp2 SKIP (https://nvbugs/5598847)
 examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3.5-MoE-instruct] SKIP (https://nvbugs/5465143)
-unittest/llmapi/test_memory_profiling.py SKIP (https://nvbugs/5580781)
+unittest/llmapi/test_memory_profiling.py::test_profile_kvcache SKIP (https://nvbugs/5580781)
 triton_server/test_triton.py::test_llava[llava] SKIP (https://nvbugs/5547414)
 full:RTX/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5569696)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass-auto] SKIP (https://nvbugs/5596343)
