4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/distributed/communicator.py
@@ -792,9 +792,9 @@ def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
             self.nccl_comm.send(tensor, dest)
             return

-        self.tensor_ready_event.record()
+        tensor = tensor.clone()
+        self.send_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self.send_stream):
-            self.tensor_ready_event.wait()
             self.nccl_comm.send(tensor, dest)

     def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
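The new ordering replaces the event-based handoff: the tensor is cloned on the producing stream, the dedicated send stream is made to wait on that stream, and only then is the NCCL send enqueued on the send stream, so later in-place writes by the caller cannot race with the asynchronous send. A minimal standalone sketch of the same pattern, with comm_send as a placeholder for nccl_comm.send:

import torch

send_stream = torch.cuda.Stream()

def send_on_side_stream(tensor: torch.Tensor, dest: int, comm_send) -> None:
    # Copy the tensor on the producing (current) stream so the caller can
    # reuse or overwrite the original buffer without racing the send.
    tensor = tensor.clone()
    # Make the side stream wait for all work already queued on the current
    # stream, including the clone above.
    send_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(send_stream):
        # comm_send stands in for nccl_comm.send; it is enqueued on send_stream.
        comm_send(tensor, dest)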
24 changes: 19 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
@@ -14,6 +14,7 @@
 from tensorrt_llm.mapping import CpType

 from ..distributed import Distributed
+from .hang_detector import HangDetector
 from .llm_request import (ExecutorRequest, LlmRequest,
                           executor_request_to_llm_request)

@@ -47,10 +48,17 @@ def is_control_request(self):
 class ExecutorRequestQueue:
     """Handles fetching and processing of new requests from the request queue."""

-    def __init__(self, dist: Distributed, enable_attention_dp: bool,
-                 max_batch_size: int, max_beam_width: int,
-                 max_num_active_requests: int, enable_iter_perf_stats: bool,
-                 batch_wait_timeout_ms: float):
+    def __init__(
+        self,
+        dist: Distributed,
+        enable_attention_dp: bool,
+        max_batch_size: int,
+        max_beam_width: int,
+        max_num_active_requests: int,
+        enable_iter_perf_stats: bool,
+        batch_wait_timeout_ms: float,
+        hang_detector: Optional[HangDetector] = None,
+    ):
         self.dist = dist
         self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
         self.waiting_queue: deque[RequestQueueItem] = deque()
@@ -66,6 +74,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool,
         self.active = True
         self.batch_wait_timeout_ms = batch_wait_timeout_ms
         self.send_requests_handler = None
+        self.hang_detector = hang_detector or HangDetector()

         # State tracking
         self.num_fetch_requests = 0
@@ -303,7 +312,9 @@ def _fetch_and_process_requests(
                 self.request_accumulated.clear()
                 # Reset timeout to 0 to avoid hanging when no new requests are available
                 timeout = datetime.timedelta(0)
+        self.hang_detector.pause()
         new_requests.extend(self._get_from_request_queue(timeout))
+        self.hang_detector.checkpoint()

         # Broadcast requests and handle Python objects
         new_requests, py_request_objects = self._handle_request_broadcasting(
@@ -477,8 +488,10 @@ def _handle_request_broadcasting(self,
             # Preserve original `new_requests` on rank 0
             _ = self._broadcast_new_requests(new_requests, py_request_objects)
         else:
+            self.hang_detector.pause()
             new_requests, py_request_objects = self._broadcast_new_requests(
                 new_requests, py_request_objects)
+            self.hang_detector.checkpoint()

         return new_requests, py_request_objects

@@ -589,7 +602,8 @@ def _broadcast_new_requests(

         # Broadcast within first tp group before send/recv chain to other tp groups
         if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
-            payloads = self.dist.tp_broadcast(payloads, root=0)
+            with nvtx_range("tp_broadcast_requests"):
+                payloads = self.dist.tp_broadcast(payloads, root=0)

         # Tag for communication
         tag = self.dist.pp_size  # Use pp_size as tag to avoid conflicts
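The hang_detector argument is optional; when it is omitted, the queue falls back to a default HangDetector(). A hypothetical construction sketch, with illustrative argument values that are not taken from this change:

from tensorrt_llm._torch.pyexecutor.executor_request_queue import ExecutorRequestQueue
from tensorrt_llm._torch.pyexecutor.hang_detector import HangDetector

detector = HangDetector(timeout=120)   # seconds of inactivity before a hang is reported
detector.start()

request_queue = ExecutorRequestQueue(
    dist=dist,                         # an existing Distributed instance (assumed available)
    enable_attention_dp=False,
    max_batch_size=8,
    max_beam_width=1,
    max_num_active_requests=64,
    enable_iter_perf_stats=False,
    batch_wait_timeout_ms=0.0,
    hang_detector=detector,
)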
85 changes: 85 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/hang_detector.py
@@ -0,0 +1,85 @@
+import asyncio
+import sys
+import threading
+import traceback
+from typing import Callable, Optional
+
+from tensorrt_llm.logger import logger
+
+
+class HangDetector:
+    def __init__(
+        self, timeout: Optional[int] = None, on_detected: Optional[Callable[[], None]] = None
+    ):
+        self.timeout = timeout or 300
+        self.on_detected = on_detected or (lambda: None)
+        self.task = None
+        self.loop = None
+        self.loop_thread = None
+        self.lock = threading.Lock()
+        self.active = False
+        self._detected = False
+
+    def start(self):
+        """Start monitoring for hangs"""
+
+        def run_loop():
+            asyncio.set_event_loop(self.loop)
+            self.loop.run_forever()
+
+        self.active = True
+        self.loop = asyncio.new_event_loop()
+        self.loop_thread = threading.Thread(target=run_loop, daemon=True, name="hang_detector_loop")
+        self.loop_thread.start()
+
+    async def _detect_hang(self):
+        await asyncio.sleep(self.timeout)
+        with self.lock:
+            self._detected = True
+        logger.error(f"Hang detected after {self.timeout} seconds.")
+        self.print_all_stacks()
+        self.on_detected()
+
+    def print_all_stacks(self):
+        """Print stack traces for all threads"""
+        for thread_id, frame in sys._current_frames().items():
+            logger.error(
+                f"Thread {thread_id} stack trace:\n" + "".join(traceback.format_stack(frame))
+            )
+
+    def detected(self):
+        """Return True if hang is detected"""
+        with self.lock:
+            return self._detected
+
+    def checkpoint(self):
+        """Call this periodically in your code"""
+        self.pause()
+        if self.active:
+            self.task = asyncio.run_coroutine_threadsafe(self._detect_hang(), self.loop)
+
+    def pause(self):
+        """Pause monitoring"""
+        if self.task is not None and not self.task.done():
+            self.task.cancel()
+        self.task = None
+
+    def stop(self):
+        """Stop monitoring"""
+        self.active = False
+        self.pause()
+        if self.loop is not None:
+            # Cancel all pending tasks before stopping the loop
+            def cancel_all_tasks():
+                for task in asyncio.all_tasks(self.loop):
+                    if not task.done():
+                        task.cancel()
+                self.loop.call_soon(self.loop.stop)
+
+            self.loop.call_soon_threadsafe(cancel_all_tasks)
+
+        if self.loop_thread is not None and self.loop_thread.is_alive():
+            self.loop_thread.join()
+
+        self.loop = None
+        self.loop_thread = None
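checkpoint() re-arms a timer on the background event loop, pause() cancels it, and if the timer expires _detect_hang logs every thread's stack and invokes on_detected. A usage sketch mirroring how the executor wraps its blocking calls, with work_queue and process as made-up stand-ins for the real fetch and per-iteration work:

import queue

from tensorrt_llm._torch.pyexecutor.hang_detector import HangDetector

work_queue: queue.Queue = queue.Queue()
for i in range(3):
    work_queue.put(i)

def process(item) -> None:
    # Hypothetical per-iteration work; this is the region the detector times.
    pass

detector = HangDetector(timeout=300)  # report a hang after 300 s without a checkpoint
detector.start()

while not work_queue.empty():
    detector.pause()         # a long blocking wait for input is not treated as a hang
    item = work_queue.get()  # stand-in for the blocking fetch/broadcast in the executor
    detector.checkpoint()    # re-arm: the timer now covers the work below
    process(item)

detector.stop()
assert not detector.detected()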