From 593f514f882ed362f6d3e263c58489d3564896b8 Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Wed, 31 Dec 2025 12:12:17 +0000 Subject: [PATCH 1/2] fully non-blocking pipeline parallelism executor loop. Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> --- .../_torch/distributed/communicator.py | 4 +- .../pyexecutor/executor_request_queue.py | 24 +- .../_torch/pyexecutor/hang_detector.py | 85 ++++ tensorrt_llm/_torch/pyexecutor/py_executor.py | 440 ++++++++++++------ tensorrt_llm/_torch/utils.py | 26 +- tensorrt_llm/_utils.py | 11 + 6 files changed, 431 insertions(+), 159 deletions(-) create mode 100644 tensorrt_llm/_torch/pyexecutor/hang_detector.py diff --git a/tensorrt_llm/_torch/distributed/communicator.py b/tensorrt_llm/_torch/distributed/communicator.py index 7b6af1188b0..7153f854290 100644 --- a/tensorrt_llm/_torch/distributed/communicator.py +++ b/tensorrt_llm/_torch/distributed/communicator.py @@ -792,9 +792,9 @@ def send(self, tensor: torch.Tensor, dest: Optional[int] = None): self.nccl_comm.send(tensor, dest) return - self.tensor_ready_event.record() + tensor = tensor.clone() + self.send_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.send_stream): - self.tensor_ready_event.wait() self.nccl_comm.send(tensor, dest) def recv(self, tensor: torch.Tensor, src: Optional[int] = None): diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 120c42dbd2c..ba9c05cecbf 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -14,6 +14,7 @@ from tensorrt_llm.mapping import CpType from ..distributed import Distributed +from .hang_detector import HangDetector from .llm_request import (ExecutorRequest, LlmRequest, executor_request_to_llm_request) @@ -47,10 +48,17 @@ def is_control_request(self): class ExecutorRequestQueue: """Handles fetching and processing of new requests from the request queue.""" - def __init__(self, dist: Distributed, enable_attention_dp: bool, - max_batch_size: int, max_beam_width: int, - max_num_active_requests: int, enable_iter_perf_stats: bool, - batch_wait_timeout_ms: float): + def __init__( + self, + dist: Distributed, + enable_attention_dp: bool, + max_batch_size: int, + max_beam_width: int, + max_num_active_requests: int, + enable_iter_perf_stats: bool, + batch_wait_timeout_ms: float, + hang_detector: Optional[HangDetector] = None, + ): self.dist = dist self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue() self.waiting_queue: deque[RequestQueueItem] = deque() @@ -66,6 +74,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool, self.active = True self.batch_wait_timeout_ms = batch_wait_timeout_ms self.send_requests_handler = None + self.hang_detector = hang_detector or HangDetector() # State tracking self.num_fetch_requests = 0 @@ -303,7 +312,9 @@ def _fetch_and_process_requests( self.request_accumulated.clear() # Reset timeout to 0 to avoid hanging when no new requests are available timeout = datetime.timedelta(0) + self.hang_detector.pause() new_requests.extend(self._get_from_request_queue(timeout)) + self.hang_detector.checkpoint() # Broadcast requests and handle Python objects new_requests, py_request_objects = self._handle_request_broadcasting( @@ -477,8 +488,10 @@ def _handle_request_broadcasting(self, # Preserve original `new_requests` on rank 0 _ = 
self._broadcast_new_requests(new_requests, py_request_objects) else: + self.hang_detector.pause() new_requests, py_request_objects = self._broadcast_new_requests( new_requests, py_request_objects) + self.hang_detector.checkpoint() return new_requests, py_request_objects @@ -589,7 +602,8 @@ def _broadcast_new_requests( # Broadcast within first tp group before send/recv chain to other tp groups if self.dist.tp_size > 1 and self.dist.is_first_pp_rank: - payloads = self.dist.tp_broadcast(payloads, root=0) + with nvtx_range("tp_broadcast_requests"): + payloads = self.dist.tp_broadcast(payloads, root=0) # Tag for communication tag = self.dist.pp_size # Use pp_size as tag to avoid conflicts diff --git a/tensorrt_llm/_torch/pyexecutor/hang_detector.py b/tensorrt_llm/_torch/pyexecutor/hang_detector.py new file mode 100644 index 00000000000..42ed727de37 --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/hang_detector.py @@ -0,0 +1,85 @@ +import asyncio +import sys +import threading +import traceback +from typing import Callable, Optional + +from tensorrt_llm.logger import logger + + +class HangDetector: + def __init__( + self, timeout: Optional[int] = None, on_detected: Optional[Callable[[], None]] = None + ): + self.timeout = timeout or 300 + self.on_detected = on_detected or (lambda: None) + self.task = None + self.loop = None + self.loop_thread = None + self.lock = threading.Lock() + self.active = False + self._detected = False + + def start(self): + """Start monitoring for hangs""" + + def run_loop(): + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + + self.active = True + self.loop = asyncio.new_event_loop() + self.loop_thread = threading.Thread(target=run_loop, daemon=True, name="hang_detector_loop") + self.loop_thread.start() + + async def _detect_hang(self): + await asyncio.sleep(self.timeout) + with self.lock: + self._detected = True + logger.error(f"Hang detected after {self.timeout} seconds.") + self.print_all_stacks() + self.on_detected() + + def print_all_stacks(self): + """Print stack traces for all threads""" + for thread_id, frame in sys._current_frames().items(): + logger.error( + f"Thread {thread_id} stack trace:\n" + "".join(traceback.format_stack(frame)) + ) + + def detected(self): + """Return True if hang is detected""" + with self.lock: + return self._detected + + def checkpoint(self): + """Call this periodically in your code""" + self.pause() + if self.active: + self.task = asyncio.run_coroutine_threadsafe(self._detect_hang(), self.loop) + + def pause(self): + """Pause monitoring""" + if self.task is not None and not self.task.done(): + self.task.cancel() + self.task = None + + def stop(self): + """Stop monitoring""" + self.active = False + self.pause() + if self.loop is not None: + # Cancel all pending tasks before stopping the loop + def cancel_all_tasks(): + for task in asyncio.all_tasks(self.loop): + if not task.done(): + task.cancel() + self.loop.call_soon(self.loop.stop) + + self.loop.call_soon_threadsafe(cancel_all_tasks) + + if self.loop_thread is not None and self.loop_thread.is_alive(): + self.loop_thread.join() + + self.loop = None + self.loop_thread = None diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index f9a6d6c3133..5b229b7a778 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -2,16 +2,17 @@ import datetime import functools import os -import pickle # nosec B403 import threading import time import traceback from contextlib 
import contextmanager +from queue import Queue from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import torch +from mpi4py import MPI +from mpi4py.util import pkl5 -from tensorrt_llm._torch.expert_statistic import ExpertStatistic from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds try: @@ -19,10 +20,9 @@ except ImportError: from cuda import cudart -from tensorrt_llm._torch.pyexecutor.resource_manager import ( - ResourceManagerType, request_context) from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled, - nvtx_range, trace_func) + nvtx_range, set_thread_local_mpi_comm, + trace_func) from tensorrt_llm.bindings.executor import (DisServingRequestStats, FinishReason, InflightBatchingStats, IterationStats, KvCacheStats, @@ -37,6 +37,7 @@ from tensorrt_llm.runtime.generation import CUASSERT from ..distributed import Distributed +from ..expert_statistic import ExpertStatistic from ..models.modeling_utils import DecoderModelForCausalLM from ..modules.decoder_layer import DecoderLayer from ..speculative.drafter import Drafter @@ -46,12 +47,14 @@ from .guided_decoder import GuidedDecoder from .handle_additional_outputs import HandleAdditionalOutputs from .handle_logits import HandleLogits +from .hang_detector import HangDetector from .kv_cache_connector import KvCacheConnectorManager from .kv_cache_transceiver import KvCacheTransceiver from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState, LlmResponse, get_draft_token_length) from .model_engine import ModelEngine -from .resource_manager import ResourceManager +from .resource_manager import (ResourceManager, ResourceManagerType, + request_context) from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState, SampleStateTensors) from .scheduler import (RequestScheduler, ScheduledRequests, @@ -66,9 +69,10 @@ PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE" # Unique tag base to avoid collisions with token/logits comms -TERMINATION_COMM_TAG_BASE = 20000 -PP_COMM_TAG_SCHEDULE_RESULT = 21000 -PP_COMM_TAG_SAMPLE_STATE_BASE = 21001 +TERMINATION_COMM_TAG = 20000 +PP_COMM_TAG_SCHEDULE_RESULT = 20001 +PP_COMM_TAG_EXECUTED_BATCH_NUM = 20002 +PP_COMM_TAG_SAMPLE_STATE = 20003 @functools.cache @@ -114,6 +118,7 @@ class BatchStatePP(BatchState): class PyExecutor: + MIN_ASYNC_MICRO_BATCH_NUM = 128 def __init__(self, resource_manager, @@ -136,7 +141,8 @@ def __init__(self, kv_connector_manager: Optional[KvCacheConnectorManager] = None, max_seq_len: Optional[int] = None, peft_cache_config: Optional[PeftCacheConfig] = None, - virtual_memory_pools: Optional[dict] = None): + virtual_memory_pools: Optional[dict] = None, + hang_detection_timeout: Optional[int] = None): super(PyExecutor, self).__init__() self.device_id = torch.cuda.current_device() self.global_rank = dist.rank @@ -231,14 +237,19 @@ def __init__(self, os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0)) # list of requests in each PP micro batch - self.num_micro_batches = self.dist.pp_size + self.num_micro_batches = max(self.dist.pp_size, + self.MIN_ASYNC_MICRO_BATCH_NUM) self.micro_batches: List[BatchStatePP | None] = [None] * self.num_micro_batches self.send_handles = [None] * self.num_micro_batches # schedule handle for PP to propagate the first PP rank's schedule result - self.send_schedule_handler = None + self.send_schedule_handles = [None] * self.num_micro_batches + self.send_expected_batch_num_handles = [None] * self.num_micro_batches + self.unhandled_batch_counter = 0 
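# A minimal, self-contained sketch of the HangDetector watchdog added by this
# patch (hang_detector.py above): checkpoint() re-arms the timer, pause()
# disarms it around a legitimately blocking wait, and on_detected fires if
# `timeout` seconds pass after a checkpoint() with no further checkpoint() or
# pause(). The timeout value and the sleep standing in for a blocking
# queue.get() / recv_object() are placeholders, not values used by the executor.
import time

from tensorrt_llm._torch.pyexecutor.hang_detector import HangDetector

detector = HangDetector(timeout=60, on_detected=lambda: print("possible hang"))
detector.start()
for _ in range(3):
    detector.checkpoint()   # watchdog armed for this iteration
    detector.pause()        # about to block on an external event
    time.sleep(0.1)         # placeholder for a blocking wait
    detector.checkpoint()   # back under watchdog coverage
detector.stop()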
self.pp_scheduler_max_retry_count = int( os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10)) + self.sample_stream = torch.cuda.Stream() + self.finish_sample_event = torch.cuda.Event() # Set of request IDs that are currently in flight across all micro batches. # The scheduler will avoid scheduling requests that are already in flight. @@ -257,6 +268,15 @@ def __init__(self, self.adp_ctx_batching_wait_iters_count = 0 self.batch_wait_iters_count = 0 + def on_detected(): + self._handle_errors( + f"Hang detected on rank {self.global_rank} in PyExecutor.") + self.shutdown_event.set() + self.is_shutdown = True + + self.hang_detector = HangDetector(timeout=hang_detection_timeout, + on_detected=on_detected) + # request fetcher initialization self._set_global_steady_clock_offset() self.executor_request_queue = ExecutorRequestQueue( @@ -267,6 +287,7 @@ def __init__(self, max_num_active_requests=self.max_num_active_requests, enable_iter_perf_stats=self.enable_iter_perf_stats, batch_wait_timeout_ms=self.batch_wait_timeout_ms, + hang_detector=self.hang_detector, ) self.executor_request_queue.set_exclude_last_generation_logits( self.disable_overlap_scheduler, self.dist.pp_size) @@ -364,7 +385,25 @@ def is_warmup(self, value: bool): def start_worker(self): with self.worker_lock: - if self.worker_started == False: + if not self.worker_started: + if self.dist.pp_size > 1: + self.broadcast_sample_state_comm = pkl5.Intracomm( + MPI.COMM_WORLD.Dup()) + self.executed_batch_queue: Queue[BatchStatePP] = Queue( + maxsize=self.num_micro_batches) + self.executed_batch_response_queue: Queue[ + BatchStatePP] = Queue(maxsize=-1) + broadcast_sample_state_loop = self._broadcast_sample_state_loop + if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): + broadcast_sample_state_loop = trace_func( + broadcast_sample_state_loop) + self.broadcast_sample_state_handler = threading.Thread( + target=broadcast_sample_state_loop, + daemon=True, + name="broadcast_sample_state_handler", + ) + self.broadcast_sample_state_handler.start() + self.hang_detector.start() self.worker_thread = threading.Thread( target=self._event_loop_wrapper, daemon=True) self.worker_thread.start() @@ -453,7 +492,12 @@ def shutdown(self): """ self.executor_request_queue.enqueue_shutdown_request() self.shutdown_event.wait() + if self.hang_detector.detected(): + return self.worker_thread.join() + if self.dist.pp_size > 1: + self.executed_batch_queue.put(None) + self.broadcast_sample_state_handler.join() self.worker_started = False for manager in self.resource_manager.resource_managers.values(): if manager: @@ -851,41 +895,58 @@ def _process_iter_stats( def _executor_loop_cleanup(self): - for h in self.send_handles: - if h is not None: - h.wait() + for i in range(self.num_micro_batches): + self.wait_on_pp_send_handles(self.send_handles, i) + self.wait_on_pp_send_handles(self.send_schedule_handles, i) + self.wait_on_pp_send_handles(self.send_expected_batch_num_handles, + i) with self.response_cv: self.is_shutdown = True self.response_cv.notify_all() self.shutdown_event.set() - def _pp_schedule_and_propagate(self): + def _pp_schedule_and_propagate(self, microbatch_id: int): """The first PP rank schedules the requests and propagates the result to all other PP ranks.""" - # The first PP rank schedules the requests, other ranks receive the schedule result from the previous PP rank. - if self.dist.is_first_pp_rank: + # For TP cases, the first rank schedules the requests. + # For DP cases, the first PP rank schedules the requests. 
+ scheduled_batch = None + serializable_schedule = None + is_tp_broadcast = self.dist.tp_size > 1 and not self.enable_attention_dp + if self.dist.rank == 0 or (self.dist.is_first_pp_rank + and not is_tp_broadcast): scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( ) serializable_schedule = SerializableSchedulerOutput.from_scheduler_result( scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs) - else: + + # Broadcast within first tp group before send/recv chain to other tp groups + if self.dist.is_first_pp_rank and is_tp_broadcast: + with nvtx_range("tp_broadcast_schedule"): + serializable_schedule = self.dist.tp_broadcast( + serializable_schedule, root=0) + + # Other ranks receive the schedule result from the previous PP rank. + if not self.dist.is_first_pp_rank: with nvtx_range("recv_schedule_from_prev_pp"): serializable_schedule = self.dist.recv_object( self.dist.prev_pp_rank, PP_COMM_TAG_SCHEDULE_RESULT) - scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = serializable_schedule.to_scheduler_result( - self.active_requests) # Propagate the schedule result to the next PP rank except the last PP rank. if not self.dist.is_last_pp_rank: - if self.send_schedule_handler is not None: - with nvtx_range("wait_send_schedule_handler"): - self.send_schedule_handler.wait() + self.wait_on_pp_send_handles(self.send_schedule_handles, + microbatch_id) with nvtx_range("send_schedule_to_next_pp"): - self.send_schedule_handler = self.dist.isend_object( - serializable_schedule, self.dist.next_pp_rank, - PP_COMM_TAG_SCHEDULE_RESULT) + self.send_schedule_handles[ + microbatch_id] = self.dist.isend_object( + serializable_schedule, self.dist.next_pp_rank, + PP_COMM_TAG_SCHEDULE_RESULT) + + if scheduled_batch is None: + scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = serializable_schedule.to_scheduler_result( + self.active_requests) return scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs def _pp_retry_until_can_schedule(self, scheduled_batch): @@ -941,6 +1002,7 @@ def _executor_loop_pp(self): iter_start_time = time.time() iter_stats = None while True: + self.hang_detector.checkpoint() profile_step() if self.enable_iter_perf_stats: iter_start_time = time.time() @@ -948,6 +1010,14 @@ def _executor_loop_pp(self): # Fetch new requests from request queue new_requests = self._fetch_and_activate_new_requests() if self.should_stop_processing: + while self.unhandled_batch_counter > 0: + with nvtx_range("get_executed_batch"): + executed_batch = self.executed_batch_response_queue.get( + ) + self._handle_executed_batch(executed_batch) + self.unhandled_batch_counter -= 1 + + self.hang_detector.stop() break self._handle_control_request() @@ -965,8 +1035,8 @@ def _executor_loop_pp(self): # Stage 0: first PP rank schedules requests and propagates the result to all other PP ranks. scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._pp_schedule_and_propagate( - ) - if not self.dist.is_first_pp_rank: + microbatch_id) + if self.dist.rank != 0: # Retry until current rank can run first PP's schedule result. self._pp_retry_until_can_schedule(scheduled_batch) # Run scheduler locally because scheduler may change llm requests' state. 
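# Condensed sketch of the schedule relay implemented above, for the plain
# TP+PP case (attention DP disabled): rank 0 schedules, the result fans out
# across the first pipeline stage via tp_broadcast, then hops stage to stage
# over isend_object/recv_object. `dist` is assumed to expose the same helpers
# used in this file; `schedule_locally` and `tag` are placeholders.
def relay_schedule(dist, schedule_locally, tag):
    result = schedule_locally() if dist.rank == 0 else None
    if dist.is_first_pp_rank and dist.tp_size > 1:
        result = dist.tp_broadcast(result, root=0)          # within the first PP stage
    if not dist.is_first_pp_rank:
        result = dist.recv_object(dist.prev_pp_rank, tag)   # from the upstream stage
    send_handle = None
    if not dist.is_last_pp_rank:
        send_handle = dist.isend_object(result, dist.next_pp_rank, tag)
    return result, send_handle   # wait on send_handle before reusing its micro-batch slot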
@@ -1048,8 +1118,25 @@ def _executor_loop_pp(self): guided_decoder_failed_requests = self.guided_decoder.execute( batch_outputs['logits']) - sample_state = self._sample_async( - scheduled_batch, batch_outputs) + if os.environ.get("TRTLLM_PP_MULTI_STREAM_SAMPLE", + "1") == "1": + # Wait for the previous sample to finish. + self.finish_sample_event.wait() + # Copy the batch outputs as sampler inputs + # to avoid next forward step overwriting them. + batch_outputs_copy = { + name: tensor.clone() + for name, tensor in batch_outputs.items() + } + self.sample_stream.wait_stream( + torch.cuda.current_stream()) + with torch.cuda.stream(self.sample_stream): + sample_state = self._sample_async( + scheduled_batch, batch_outputs_copy) + self.finish_sample_event.record() + else: + sample_state = self._sample_async( + scheduled_batch, batch_outputs) assert sample_state is not None, "Sampling failed" # Handle guided decoder errors after _sample_async to avoid state conflicts. @@ -1075,9 +1162,14 @@ def _executor_loop_pp(self): self.micro_batches[microbatch_id] = batch_state # sync sampler for previous microbatch to start new sample state comm chain. - prev_microbatch_id = (microbatch_id - - 1) % self.num_micro_batches - previous_batch = self.micro_batches[prev_microbatch_id] + if self.dist.is_last_pp_rank: + prev_microbatch_id = (microbatch_id - + 1) % self.num_micro_batches + previous_batch = self.micro_batches[prev_microbatch_id] + elif can_queue: + previous_batch = self.previous_batch + else: + previous_batch = None if previous_batch is not None: with nvtx_range("sync_previous_sampler_event"): previous_batch.sample_state.sampler_event.synchronize() @@ -1085,104 +1177,181 @@ def _executor_loop_pp(self): # Stage 2: Communicate sample state for previous batch between ranks # send/recv chain: (pp_size - 1) -> 0 -> 1 -> ... -> (pp_size - 2) # intermediate ranks: send/recv sample state for next microbatch to allow overlap - offset = -1 if self.dist.is_last_pp_rank else 1 + offset = -1 if self.dist.is_last_pp_rank else ( + 1 - self.dist.pp_size) prev_microbatch_id = (microbatch_id + offset) % self.num_micro_batches previous_batch = self.micro_batches[prev_microbatch_id] - tag = PP_COMM_TAG_SAMPLE_STATE_BASE + prev_microbatch_id if previous_batch is not None: - sample_state = previous_batch.sample_state - if not self.dist.is_last_pp_rank: - # Receive tokens from previous pp rank (w.r.t model forward direction) - with nvtx_range("recv_sample_state"): - sample_state.host = self.dist.recv_object( - src=self.dist.prev_pp_rank, - tag=tag, + self.executed_batch_queue.put(previous_batch) + self.unhandled_batch_counter += 1 + + dequeue_counter = 0 + executed_batch_num = 0 + + # The first rank determines the number of executed batches. + if self.dist.rank == 0: + executed_batches = [] + # Wait for at least one batch to finish if no new request is available. + must_get = not can_queue + while not self.executed_batch_response_queue.empty() or ( + must_get and self.unhandled_batch_counter > 0): + with nvtx_range("get_executed_batch"): + executed_batches.append( + self.executed_batch_response_queue.get()) + dequeue_counter += 1 + must_get = False + executed_batch_num = dequeue_counter + + # Broadcast the number of executed batches to other ranks. 
+ if self.dist.is_first_pp_rank and self.dist.tp_size > 1: + with nvtx_range("tp_broadcast_executed_batch_num"): + executed_batch_num = self.dist.tp_broadcast( + executed_batch_num, + root=0, + ) + if not self.dist.is_first_pp_rank: + with nvtx_range("recv_expected_batch_num"): + executed_batch_num = self.dist.recv_object( + src=self.dist.prev_pp_rank, + tag=PP_COMM_TAG_EXECUTED_BATCH_NUM, + ) + if not self.dist.is_last_pp_rank: + self.wait_on_pp_send_handles( + self.send_expected_batch_num_handles, microbatch_id) + with nvtx_range("send_expected_batch_num"): + self.send_expected_batch_num_handles[ + microbatch_id] = self.dist.isend_object( + executed_batch_num, + dest=self.dist.next_pp_rank, + tag=PP_COMM_TAG_EXECUTED_BATCH_NUM, ) - # Send tokens to next pp rank (w.r.t model forward direction) - # Second last rank does not need to since last rank has original decoded tokens - if not self.dist.is_second_last_pp_rank: - self.wait_on_pp_send_handles(prev_microbatch_id) - with nvtx_range("send_sample_state"): - self.send_handles[ - prev_microbatch_id] = self.dist.isend_object( - sample_state.host, - dest=self.dist.next_pp_rank, - tag=tag) - - # Stage 3: Finalize previous batch that finished sample state communication - # In last pp rank, stage 2 and 3 process different previous batches - prev_microbatch_id = (microbatch_id + - 1) % self.num_micro_batches - previous_batch = self.micro_batches[prev_microbatch_id] - finished_requests = [] - if previous_batch is not None: - with torch.cuda.nvtx.range("_handle_previous_batch_pp"): - sample_state = previous_batch.sample_state - sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs - self._update_requests(previous_batch.sample_state) - - if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: - for req in previous_batch.scheduled_ctx_reqs: - if req.is_context_only_request and ( - req.is_context_finished - or req.is_finished_due_to_length): - block_id = self.kv_cache_manager.store_blocks_for_reuse( - req, True) - self.ctx_in_transmission_requests[ - req.py_request_id] = ( - (req, block_id, - self.ctx_in_transmission_counter)) + # Handle executed batches. + if self.dist.rank != 0: + while dequeue_counter < executed_batch_num: + with nvtx_range("get_executed_batch"): + executed_batch = self.executed_batch_response_queue.get( + ) + self._handle_executed_batch(executed_batch) + dequeue_counter += 1 + else: + for executed_batch in executed_batches: + self._handle_executed_batch(executed_batch) + self.unhandled_batch_counter -= executed_batch_num - if self.kv_cache_transceiver: - self._send_disagg_ctx_cache( - previous_batch.scheduled_ctx_reqs) - self._handle_canceled_requests() + # march forward in microbatch slots + if can_queue: + self.previous_batch = batch_state + self.micro_batches[prev_microbatch_id] = None + microbatch_id = (microbatch_id + 1) % self.num_micro_batches + self.iter_counter += 1 - self._handle_logits_communication( - previous_batch, prev_microbatch_id) + def _broadcast_sample_state_loop(self): + logger.debug( + f"Starting broadcast sample state loop for pp_rank {self.dist.pp_rank}" + ) + torch.cuda.set_device(self.device_id) + # ensure the context is created, otherwise, some MPI calls will fail. + CUASSERT(cudart.cudaSetDevice(self.device_id)) + # Acquiring pkl5's send/recv locks from both executor loop thread + # and this thread will cause perf drop and even deadlock. + # We create new MPI comm to avoid these issues. 
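# For illustration, the per-thread override added in _utils.py works roughly
# like this (`run_comm_loop` is a hypothetical stand-in for the loop below):
#
#     dedicated = pkl5.Intracomm(MPI.COMM_WORLD.Dup())
#     def _comm_thread():
#         set_thread_local_mpi_comm(dedicated)  # mpi_comm() now resolves to `dedicated`
#         run_comm_loop()                       # pkl5 traffic stays off the main communicator
#     threading.Thread(target=_comm_thread, daemon=True).start()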
+ set_thread_local_mpi_comm(self.broadcast_sample_state_comm) + while True: + previous_batch = self.executed_batch_queue.get() + if previous_batch is None: + break + self._broadcast_sample_state(previous_batch) - finished_requests = self._handle_responses() - previous_scheduled_batch = previous_batch.sample_state.scheduled_requests - attn_metadata = getattr(self.model_engine, - 'attn_metadata', None) - kv_cache_dtype_byte_size = getattr( - self.model_engine, 'kv_cache_dtype_byte_size', None) - self.resource_manager.update_resources( - previous_scheduled_batch, attn_metadata, - kv_cache_dtype_byte_size) + def _broadcast_sample_state(self, previous_batch: Optional[BatchStatePP]): + if previous_batch is None: + return - self._remove_inflight_ids(previous_batch) + tag = PP_COMM_TAG_SAMPLE_STATE + microbatch_id = previous_batch.microbatch_id + sample_state = previous_batch.sample_state - self.wait_on_pp_send_handles(prev_microbatch_id) - self.micro_batches[prev_microbatch_id] = None + if not self.dist.is_last_pp_rank: + # Receive tokens from previous pp rank (w.r.t model forward direction) + with nvtx_range("recv_sample_state"): + sample_state.host = self.dist.recv_object( + src=self.dist.prev_pp_rank, + tag=tag, + ) - if self.kv_cache_transceiver and self.ctx_in_transmission_requests: - self._check_kv_transfer_timeout() - self._terminate_disagg_ctx_finished_requests() + self.executed_batch_response_queue.put(previous_batch) + + # Send tokens to next pp rank (w.r.t model forward direction) + # Second last rank does not need to since last rank has original decoded tokens + if not self.dist.is_second_last_pp_rank: + self.wait_on_pp_send_handles(self.send_handles, microbatch_id) + with nvtx_range("send_sample_state"): + self.send_handles[microbatch_id] = self.dist.isend_object( + sample_state.host, + dest=self.dist.next_pp_rank, + tag=tag, + ) - if self._disagg_pp_termination_handler is not None: - self._disagg_pp_termination_handler.terminate_pending_requests( - ) + def _handle_executed_batch(self, previous_batch: Optional[BatchStatePP]): + # Stage 3: Finalize previous batch that finished sample state communication + # In last pp rank, stage 2 and 3 process different previous batches + finished_requests = [] + if previous_batch is not None: + with torch.cuda.nvtx.range("_handle_previous_batch_pp"): + sample_state = previous_batch.sample_state + sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs + self._update_requests(previous_batch.sample_state) + + if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: + for req in previous_batch.scheduled_ctx_reqs: + if req.is_context_only_request and ( + req.is_context_finished + or req.is_finished_due_to_length): + block_id = self.kv_cache_manager.store_blocks_for_reuse( + req, True) + self.ctx_in_transmission_requests[ + req.py_request_id] = ( + (req, block_id, + self.ctx_in_transmission_counter)) - # march forward in microbatch slots - microbatch_id = (microbatch_id + 1) % self.num_micro_batches + if self.kv_cache_transceiver: + self._send_disagg_ctx_cache( + previous_batch.scheduled_ctx_reqs) + self._handle_canceled_requests() + + finished_requests = self._handle_responses() + previous_scheduled_batch = previous_batch.sample_state.scheduled_requests + attn_metadata = getattr(self.model_engine, 'attn_metadata', + None) + kv_cache_dtype_byte_size = getattr(self.model_engine, + 'kv_cache_dtype_byte_size', + None) + self.resource_manager.update_resources( + previous_scheduled_batch, 
attn_metadata, + kv_cache_dtype_byte_size) + + self._remove_inflight_ids(previous_batch) + + if self.kv_cache_transceiver and self.ctx_in_transmission_requests: + self._check_kv_transfer_timeout() + self._terminate_disagg_ctx_finished_requests() - if self.enable_iter_perf_stats and previous_batch is not None: - sample_state = previous_batch.sample_state - sample_state.scheduled_requests.context_requests = previous_batch.scheduled_ctx_reqs - self._process_iter_stats(finished_requests, - self.active_requests, - previous_batch, microbatch_id) + if self._disagg_pp_termination_handler is not None: + self._disagg_pp_termination_handler.terminate_pending_requests() - self.iter_counter += 1 + if self.enable_iter_perf_stats and previous_batch is not None: + sample_state = previous_batch.sample_state + sample_state.scheduled_requests.context_requests = previous_batch.scheduled_ctx_reqs + self._process_iter_stats(finished_requests, self.active_requests, + previous_batch, + previous_batch.microbatch_id) @nvtx_range("wait_on_pp_send_handles") - def wait_on_pp_send_handles(self, microbatch_id): - if self.send_handles[microbatch_id] is not None: - self.send_handles[microbatch_id].wait() - self.send_handles[microbatch_id] = None + def wait_on_pp_send_handles(self, send_handles, microbatch_id): + if send_handles[microbatch_id] is not None: + send_handles[microbatch_id].wait() + send_handles[microbatch_id] = None def _can_queue(self, scheduled_batch): @@ -1313,6 +1482,7 @@ def _executor_loop(self): iter_start_time = time.time() iter_stats = None while True: + self.hang_detector.checkpoint() profile_step() if self.enable_iter_perf_stats: iter_start_time = time.time() @@ -1321,6 +1491,7 @@ def _executor_loop(self): self._handle_control_request() if scheduled_batch is None: + self.hang_detector.stop() break self._pause_requests(scheduled_batch.paused_requests) @@ -1516,6 +1687,7 @@ def _executor_loop_overlap(self): previous_tensors_device = None can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True while True: + self.hang_detector.checkpoint() profile_step() if self.enable_iter_perf_stats: iter_start_time = time.time() @@ -1524,6 +1696,7 @@ def _executor_loop_overlap(self): self._handle_control_request() if scheduled_batch is None: + self.hang_detector.stop() break # In gen-only benchmarking mode, wait until the number of scheduled generation # requests reaches the required threshold before starting forward pass, @@ -2645,41 +2818,6 @@ def _terminate_disagg_ctx_finished_requests(self): counter - 1)) - def _handle_logits_communication(self, previous_batch, prev_microbatch_id): - """Handle logits communication between pipeline parallel ranks. - - If logits were requested, the last PP rank sends to the first PP rank (who sends responses) - the logits of the requests that have finished. - - Args: - previous_batch: The previous batch state - prev_microbatch_id: The microbatch ID for the previous batch - """ - # NOTE: If the rank processing the logits ever becomes the same as - # the rank sending the responses, this code can be removed. 
- finished_reqs = [ - r for r in - previous_batch.sample_state.scheduled_requests.all_requests() - if r.state == LlmRequestState.GENERATION_COMPLETE and ( - r.py_return_context_logits or r.py_return_generation_logits - or r.py_additional_outputs is not None) - ] - if self.dist.is_first_pp_rank and len(finished_reqs): - finished_reqs_py_results = [r.py_result for r in finished_reqs] - finished_reqs_py_results = self.dist.recv_object( - src=self.dist.prev_pp_rank, - tag=prev_microbatch_id, - ) - for req, py_result in zip(finished_reqs, finished_reqs_py_results): - req.py_result = py_result - - elif self.dist.is_last_pp_rank and len(finished_reqs): - self.wait_on_pp_send_handles(prev_microbatch_id) - self.send_handles[prev_microbatch_id] = self.dist.isend_object( - [r.py_result for r in finished_reqs], - dest=self.dist.next_pp_rank, - tag=prev_microbatch_id) - def _await_any_response(self, timeout: Optional[float] = None ) -> List[LlmResponse]: @@ -2819,7 +2957,7 @@ def __init__(self, dist, terminator_func: Callable[[LlmRequest], None]): self._pending_termination = {} self._terminating_iteration = 0 self._send_handle = None - self._comm_tag = TERMINATION_COMM_TAG_BASE + self._comm_tag = TERMINATION_COMM_TAG def terminate(self, request: LlmRequest): self._pending_termination[request.py_request_id] = request diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index 1c3c02ca346..17b77810e1e 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -8,7 +8,8 @@ import torch from torch.nn import functional as F -from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor +from tensorrt_llm._utils import (TensorWrapper, convert_to_torch_tensor, + torch_dtype_to_str) from tensorrt_llm.mapping import Mapping from tensorrt_llm.math_utils import ceil_div, pad_up from tensorrt_llm.quantization.utils import fp4_utils @@ -404,3 +405,26 @@ def split(x: torch.Tensor, def relu2(x: torch.Tensor) -> torch.Tensor: return torch.square(F.relu(x)) + + +def tensor_to_str(x: torch.Tensor, num_elements: int = 10) -> str: + # Pass num_elements=-1 will print the whole tensor + if num_elements < 0: + num_elements = torch.numel(x) + if x.dtype in (torch.int32, torch.int64): + float_x = x.to(dtype=float) + else: + float_x = x + return ("Tensor(" + f"shape={tuple(x.shape)}, " + f"dtype={torch_dtype_to_str(x.dtype)}, " + f"device={x.device}, " + f"stats=(" + f"abs_mean={float_x.abs().mean().item():.3f}, " + f"mean={float_x.mean().item():.3f}, " + f"std={float_x.std().item():.3f}, " + f"max={x.max().item():.3f}, " + f"min={x.min().item():.3f}" + "), " + f"values={x.flatten()[:num_elements].tolist()}" + ")") diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index 86ebaef371d..58942043976 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -22,6 +22,7 @@ import socket import struct import tempfile +import threading import trace import weakref from contextlib import contextmanager @@ -501,7 +502,17 @@ def set_mpi_comm(new_comm): comm = new_comm +thread_local_comm = threading.local() + + +def set_thread_local_mpi_comm(new_comm): + thread_local_comm.value = new_comm + + def mpi_comm(): + if hasattr(thread_local_comm, + "value") and thread_local_comm.value is not None: + return thread_local_comm.value return comm From d7b0ea80dd5679ba1117a795244cb91b2bb19a8b Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Thu, 1 Jan 2026 16:16:36 +0000 Subject: [PATCH 2/2] Fix perf issue. 
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 5b229b7a778..b0e209b1b75 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -118,7 +118,7 @@ class BatchStatePP(BatchState): class PyExecutor: - MIN_ASYNC_MICRO_BATCH_NUM = 128 + MIN_ASYNC_MICRO_BATCH_NUM = 1024 def __init__(self, resource_manager,
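Both the communicator send change and the TRTLLM_PP_MULTI_STREAM_SAMPLE branch in patch 1 rely on the same CUDA side-stream handoff: snapshot the producer's tensor, make the side stream wait on the producing stream, run the slow operation on the side stream, and publish completion through an event so the next call cannot race. The sketch below is a minimal stand-alone illustration of that pattern; `offload` and the placeholder multiplication are hypothetical, not TensorRT-LLM code.

import torch

assert torch.cuda.is_available()

main = torch.cuda.current_stream()
side = torch.cuda.Stream()
done = torch.cuda.Event()

def offload(x: torch.Tensor) -> torch.Tensor:
    done.wait()              # main stream waits for the previous side-stream job
    snapshot = x.clone()     # snapshot so later writes to x cannot race with the side stream
    side.wait_stream(main)   # side stream orders itself after the producer (and the clone)
    with torch.cuda.stream(side):
        y = snapshot * 2     # placeholder for the real work (NCCL send, sampling)
        done.record()        # lets the next offload() call wait for this job
    return y

out = offload(torch.randn(4, device="cuda"))
torch.cuda.synchronize()
print(out)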