|
@@ -54,6 +54,6 @@
 from .model_engine import ModelEngine
 from .resource_manager import ResourceManager
 from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState,
-                      SampleStateTensors)
+                      SampleStateTensors, TRTLLMSampler)
 from .scheduler import (RequestScheduler, ScheduledRequests,
                         SerializableSchedulerOutput)
|
@@ -371,9 +371,19 @@ def __init__(self, |
         self.send_schedule_handler = None
         self.pp_scheduler_max_retry_count = int(
             os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10))
+        self.pp_multi_stream_sample = os.environ.get(
+            "TRTLLM_PP_MULTI_STREAM_SAMPLE", "1") == "1"
         self.sample_stream = torch.cuda.Stream()
         self.start_sample_event = torch.cuda.Event()
         self.finish_sample_event = torch.cuda.Event()
+        if (self.dist.pp_size > 1 and self.pp_multi_stream_sample
+                and isinstance(self.sampler, TRTLLMSampler)):
+            # The TRTLLM sampler initializes its store and algorithms on the
+            # default stream. To enable multi-stream sampling, re-initialize
+            # them on the sample stream.
+            with torch.cuda.stream(self.sample_stream):
+                self.sampler._initialize_store()
+                self.sampler._instantiate_algorithms()

         # Set of request IDs that are currently in flight across all micro batches.
         # The scheduler will avoid scheduling requests that are already in flight.
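The re-initialization above matters because CUDA stream affinity is typically captured when state is created. Below is a minimal sketch of that behavior in plain PyTorch; `StreamBoundState` is hypothetical and not part of TensorRT-LLM:

```python
import torch

class StreamBoundState:
    """Hypothetical state that captures the current stream at construction."""

    def __init__(self):
        # Captured once here; every later launch reuses this stream.
        self.stream = torch.cuda.current_stream()

    def launch(self, x: torch.Tensor) -> torch.Tensor:
        # Work always runs on the stream captured in __init__, regardless
        # of which stream is current at call time.
        with torch.cuda.stream(self.stream):
            return x * 2

side_stream = torch.cuda.Stream()
with torch.cuda.stream(side_stream):
    # Constructing under the side stream binds the state to it, mirroring
    # how the diff re-runs _initialize_store() / _instantiate_algorithms()
    # under self.sample_stream.
    state = StreamBoundState()
```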
@@ -1216,8 +1226,7 @@ def _executor_loop_pp(self): |
                 guided_decoder_failed_requests = self.guided_decoder.execute(
                     batch_outputs['logits'])

-                if os.environ.get("TRTLLM_PP_MULTI_STREAM_SAMPLE",
-                                  "1") == "1":
+                if self.pp_multi_stream_sample:
                     # Wait for the previous sample to finish.
                     self.finish_sample_event.wait()
                     # Copy the batch outputs as sampler inputs
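The `finish_sample_event.wait()` above is one half of a two-event handshake between the default stream and `sample_stream`. A self-contained sketch of the same pattern in plain PyTorch (the `forward` and `sample` functions here are stand-ins, not TensorRT-LLM APIs):

```python
import torch

sample_stream = torch.cuda.Stream()
start_sample_event = torch.cuda.Event()
finish_sample_event = torch.cuda.Event()

def forward(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the model forward pass (runs on the default stream).
    return x @ x.t()

def sample(logits: torch.Tensor) -> torch.Tensor:
    # Stand-in for the sampler (greedy argmax).
    return logits.argmax(dim=-1)

x = torch.randn(16, 16, device="cuda")
for _ in range(3):
    # Default stream waits until the previous iteration's sampling is done,
    # mirroring finish_sample_event.wait() in the executor loop.
    finish_sample_event.wait()
    logits = forward(x)
    start_sample_event.record()  # logits are ready past this point
    with torch.cuda.stream(sample_stream):
        start_sample_event.wait()  # order sampling after the forward pass
        logits.record_stream(sample_stream)  # inform the caching allocator
        tokens = sample(logits)
        finish_sample_event.record()
torch.cuda.synchronize()
```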
|