Skip to content

Commit 9fcc93e

Browse files
authored
[https://nvbugs/5829097][fix] Re-init TRTLLM sampler to use sample stream in multi-stream cases. (#10918)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
1 parent 9d65b8b commit 9fcc93e

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from .model_engine import ModelEngine
5555
from .resource_manager import ResourceManager
5656
from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState,
57-
SampleStateTensors)
57+
SampleStateTensors, TRTLLMSampler)
5858
from .scheduler import (RequestScheduler, ScheduledRequests,
5959
SerializableSchedulerOutput)
6060

@@ -371,9 +371,19 @@ def __init__(self,
371371
self.send_schedule_handler = None
372372
self.pp_scheduler_max_retry_count = int(
373373
os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10))
374+
self.pp_multi_stream_sample = os.environ.get(
375+
"TRTLLM_PP_MULTI_STREAM_SAMPLE", "1") == "1"
374376
self.sample_stream = torch.cuda.Stream()
375377
self.start_sample_event = torch.cuda.Event()
376378
self.finish_sample_event = torch.cuda.Event()
379+
if (self.dist.pp_size > 1 and self.pp_multi_stream_sample
380+
and isinstance(self.sampler, TRTLLMSampler)):
381+
# TRTLLM sampler uses default stream for store and algorithms.
382+
# To enable multi-stream sampling, we need to re-initialize
383+
# the sampler store and algorithms on the sample stream.
384+
with torch.cuda.stream(self.sample_stream):
385+
self.sampler._initialize_store()
386+
self.sampler._instantiate_algorithms()
377387

378388
# Set of request IDs that are currently in flight across all micro batches.
379389
# The scheduler will avoid scheduling requests that are already in flight.
@@ -1216,8 +1226,7 @@ def _executor_loop_pp(self):
12161226
guided_decoder_failed_requests = self.guided_decoder.execute(
12171227
batch_outputs['logits'])
12181228

1219-
if os.environ.get("TRTLLM_PP_MULTI_STREAM_SAMPLE",
1220-
"1") == "1":
1229+
if self.pp_multi_stream_sample:
12211230
# Wait for the previous sample to finish.
12221231
self.finish_sample_event.wait()
12231232
# Copy the batch outputs as sampler inputs

0 commit comments

Comments (0)