Commit 3fd5f0a

Fully non-blocking pipeline parallelism executor loop.
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent 74832a1 commit 3fd5f0a

File tree

6 files changed, +421 -159 lines

tensorrt_llm/_torch/distributed/communicator.py (2 additions, 2 deletions)

@@ -803,9 +803,9 @@ def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
             self.nccl_comm.send(tensor, dest)
             return

-        self.tensor_ready_event.record()
+        tensor = tensor.clone()
+        self.send_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self.send_stream):
-            self.tensor_ready_event.wait()
             self.nccl_comm.send(tensor, dest)

     def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
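
The send() change above replaces an event handshake with a clone-then-wait pattern: clone() snapshots the tensor on the caller's stream, wait_stream() orders the dedicated send stream after that snapshot, and the NCCL send is enqueued on the side stream, so the caller returns without blocking. A minimal sketch of the same pattern, assuming a torch.cuda environment; the stream setup here is illustrative, and the record_stream() call is an extra lifetime safeguard, not part of the commit:

import torch

send_stream = torch.cuda.Stream()  # dedicated side stream for sends

def non_blocking_send(nccl_comm, tensor: torch.Tensor, dest: int):
    # Snapshot the data on the current stream so the caller may overwrite
    # `tensor` as soon as this function returns.
    staged = tensor.clone()
    # Order the side stream after the clone, then enqueue the send there;
    # the current stream never waits on NCCL.
    send_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(send_stream):
        nccl_comm.send(staged, dest)
        # Keep the clone's memory alive for the caching allocator until the
        # asynchronous send on `send_stream` has consumed it.
        staged.record_stream(send_stream)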

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (9 additions, 2 deletions)

@@ -14,6 +14,7 @@
 from tensorrt_llm.mapping import CpType

 from ..distributed import Distributed
+from .hang_detector import HangDetector
 from .llm_request import (ExecutorRequest, LlmRequest,
                           executor_request_to_llm_request)

@@ -50,7 +51,7 @@ class ExecutorRequestQueue:
     def __init__(self, dist: Distributed, enable_attention_dp: bool,
                  max_batch_size: int, max_beam_width: int,
                  max_num_active_requests: int, enable_iter_perf_stats: bool,
-                 batch_wait_timeout_ms: float):
+                 batch_wait_timeout_ms: float, hang_detector: HangDetector):
         self.dist = dist
         self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
         self.waiting_queue: deque[RequestQueueItem] = deque()

@@ -66,6 +67,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool,
         self.active = True
         self.batch_wait_timeout_ms = batch_wait_timeout_ms
         self.send_requests_handler = None
+        self.hang_detector = hang_detector

         # State tracking
         self.num_fetch_requests = 0

@@ -303,7 +305,9 @@ def _fetch_and_process_requests(
             self.request_accumulated.clear()
             # Reset timeout to 0 to avoid hanging when no new requests are available
             timeout = datetime.timedelta(0)
+        self.hang_detector.pause()
         new_requests.extend(self._get_from_request_queue(timeout))
+        self.hang_detector.checkpoint()

         # Broadcast requests and handle Python objects
         new_requests, py_request_objects = self._handle_request_broadcasting(

@@ -477,8 +481,10 @@ def _handle_request_broadcasting(self,
             # Preserve original `new_requests` on rank 0
             _ = self._broadcast_new_requests(new_requests, py_request_objects)
         else:
+            self.hang_detector.pause()
             new_requests, py_request_objects = self._broadcast_new_requests(
                 new_requests, py_request_objects)
+            self.hang_detector.checkpoint()

         return new_requests, py_request_objects

@@ -589,7 +595,8 @@ def _broadcast_new_requests(

         # Broadcast within first tp group before send/recv chain to other tp groups
         if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
-            payloads = self.dist.tp_broadcast(payloads, root=0)
+            with nvtx_range("tp_broadcast_requests"):
+                payloads = self.dist.tp_broadcast(payloads, root=0)

         # Tag for communication
         tag = self.dist.pp_size  # Use pp_size as tag to avoid conflicts
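
Each pause()/checkpoint() pair above brackets a call that is allowed to block (draining the request queue, joining a broadcast), so the detector only times intervals that must make progress. A hedged sketch of the pattern; broadcast_with_watchdog and blocking_broadcast are stand-ins, not names from the commit:

from tensorrt_llm._torch.pyexecutor.hang_detector import HangDetector

def broadcast_with_watchdog(detector: HangDetector, payloads, blocking_broadcast):
    # `blocking_broadcast` stands in for any collective that waits on peers.
    detector.pause()        # suspend the timer: this wait may legitimately block
    out = blocking_broadcast(payloads)
    detector.checkpoint()   # re-arm: the next timeout window starts now
    return out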
tensorrt_llm/_torch/pyexecutor/hang_detector.py (new file, 85 additions)

import asyncio
import sys
import threading
import traceback
from typing import Callable, Optional

from tensorrt_llm.logger import logger


class HangDetector:
    """Watchdog that dumps every thread's stack and fires a callback when no
    checkpoint() arrives within `timeout` seconds."""

    def __init__(
        self, timeout: Optional[int] = None, on_detected: Optional[Callable[[], None]] = None
    ):
        self.timeout = timeout or 60
        self.on_detected = on_detected or (lambda: None)
        self.task = None
        self.loop = None
        self.loop_thread = None
        self.lock = threading.Lock()
        self.active = False
        self._detected = False

    def start(self):
        """Start monitoring for hangs"""

        def run_loop():
            asyncio.set_event_loop(self.loop)
            self.loop.run_forever()

        self.active = True
        self.loop = asyncio.new_event_loop()
        self.loop_thread = threading.Thread(target=run_loop, daemon=True, name="hang_detector_loop")
        self.loop_thread.start()

    async def _detect_hang(self):
        await asyncio.sleep(self.timeout)
        with self.lock:
            self._detected = True
        logger.error(f"Hang detected after {self.timeout} seconds.")
        self.print_all_stacks()
        self.on_detected()

    def print_all_stacks(self):
        """Print stack traces for all threads"""
        for thread_id, frame in sys._current_frames().items():
            logger.error(
                f"Thread {thread_id} stack trace:\n" + "".join(traceback.format_stack(frame))
            )

    def detected(self):
        """Return True if hang is detected"""
        with self.lock:
            return self._detected

    def checkpoint(self):
        """Call this periodically in your code"""
        self.pause()
        if self.active:
            self.task = asyncio.run_coroutine_threadsafe(self._detect_hang(), self.loop)

    def pause(self):
        """Pause monitoring"""
        if self.task is not None and not self.task.done():
            self.task.cancel()
        self.task = None

    def stop(self):
        """Stop monitoring"""
        self.active = False
        self.pause()
        if self.loop is not None:
            # Cancel all pending tasks before stopping the loop
            def cancel_all_tasks():
                for task in asyncio.all_tasks(self.loop):
                    if not task.done():
                        task.cancel()
                self.loop.call_soon(self.loop.stop)

            self.loop.call_soon_threadsafe(cancel_all_tasks)

        if self.loop_thread is not None and self.loop_thread.is_alive():
            self.loop_thread.join()

        self.loop = None
        self.loop_thread = None
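
Putting it together, a hedged sketch of the detector's lifecycle around an executor loop; run_one_iteration is a hypothetical stand-in, and the real executor wiring may differ:

import time

from tensorrt_llm._torch.pyexecutor.hang_detector import HangDetector

def run_one_iteration():
    # Hypothetical stand-in for one executor-loop step.
    time.sleep(0.1)

detector = HangDetector(timeout=60)  # dump stacks after 60 s of silence
detector.start()                     # spawns the background asyncio loop

for _ in range(100):
    detector.checkpoint()            # re-arm the watchdog for this step
    run_one_iteration()
    if detector.detected():          # a hang was already reported
        break

detector.stop()                      # cancel pending timers, join the loop thread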
