Skip to content

Commit 593f514

Browse files
committed
fully non-blocking pipeline parallelism executor loop.
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent d944430 commit 593f514

File tree

6 files changed

+431
-159
lines changed

6 files changed

+431
-159
lines changed

tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -792,9 +792,9 @@ def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
792792
self.nccl_comm.send(tensor, dest)
793793
return
794794

795-
self.tensor_ready_event.record()
795+
tensor = tensor.clone()
796+
self.send_stream.wait_stream(torch.cuda.current_stream())
796797
with torch.cuda.stream(self.send_stream):
797-
self.tensor_ready_event.wait()
798798
self.nccl_comm.send(tensor, dest)
799799

800800
def recv(self, tensor: torch.Tensor, src: Optional[int] = None):

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from tensorrt_llm.mapping import CpType
1515

1616
from ..distributed import Distributed
17+
from .hang_detector import HangDetector
1718
from .llm_request import (ExecutorRequest, LlmRequest,
1819
executor_request_to_llm_request)
1920

@@ -47,10 +48,17 @@ def is_control_request(self):
4748
class ExecutorRequestQueue:
4849
"""Handles fetching and processing of new requests from the request queue."""
4950

50-
def __init__(self, dist: Distributed, enable_attention_dp: bool,
51-
max_batch_size: int, max_beam_width: int,
52-
max_num_active_requests: int, enable_iter_perf_stats: bool,
53-
batch_wait_timeout_ms: float):
51+
def __init__(
52+
self,
53+
dist: Distributed,
54+
enable_attention_dp: bool,
55+
max_batch_size: int,
56+
max_beam_width: int,
57+
max_num_active_requests: int,
58+
enable_iter_perf_stats: bool,
59+
batch_wait_timeout_ms: float,
60+
hang_detector: Optional[HangDetector] = None,
61+
):
5462
self.dist = dist
5563
self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
5664
self.waiting_queue: deque[RequestQueueItem] = deque()
@@ -66,6 +74,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool,
6674
self.active = True
6775
self.batch_wait_timeout_ms = batch_wait_timeout_ms
6876
self.send_requests_handler = None
77+
self.hang_detector = hang_detector or HangDetector()
6978

7079
# State tracking
7180
self.num_fetch_requests = 0
@@ -303,7 +312,9 @@ def _fetch_and_process_requests(
303312
self.request_accumulated.clear()
304313
# Reset timeout to 0 to avoid hanging when no new requests are available
305314
timeout = datetime.timedelta(0)
315+
self.hang_detector.pause()
306316
new_requests.extend(self._get_from_request_queue(timeout))
317+
self.hang_detector.checkpoint()
307318

308319
# Broadcast requests and handle Python objects
309320
new_requests, py_request_objects = self._handle_request_broadcasting(
@@ -477,8 +488,10 @@ def _handle_request_broadcasting(self,
477488
# Preserve original `new_requests` on rank 0
478489
_ = self._broadcast_new_requests(new_requests, py_request_objects)
479490
else:
491+
self.hang_detector.pause()
480492
new_requests, py_request_objects = self._broadcast_new_requests(
481493
new_requests, py_request_objects)
494+
self.hang_detector.checkpoint()
482495

483496
return new_requests, py_request_objects
484497

@@ -589,7 +602,8 @@ def _broadcast_new_requests(
589602

590603
# Broadcast within first tp group before send/recv chain to other tp groups
591604
if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
592-
payloads = self.dist.tp_broadcast(payloads, root=0)
605+
with nvtx_range("tp_broadcast_requests"):
606+
payloads = self.dist.tp_broadcast(payloads, root=0)
593607

594608
# Tag for communication
595609
tag = self.dist.pp_size # Use pp_size as tag to avoid conflicts
tensorrt_llm/_torch/pyexecutor/hang_detector.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import asyncio
2+
import sys
3+
import threading
4+
import traceback
5+
from typing import Callable, Optional
6+
7+
from tensorrt_llm.logger import logger
8+
9+
10+
class HangDetector:
    """Watchdog that reports a hang when no checkpoint arrives within a timeout.

    A private asyncio event loop runs in a daemon thread. Each ``checkpoint()``
    cancels the pending timer (if any) and schedules a new one on that loop;
    ``pause()`` cancels the timer without scheduling a new one. When a timer
    expires, the detector sets its ``detected`` flag, logs the stack of every
    thread, and invokes the ``on_detected`` callback on the loop thread.
    """

    def __init__(
        self, timeout: Optional[int] = None, on_detected: Optional[Callable[[], None]] = None
    ):
        """Create an idle detector; call ``start()`` to begin monitoring.

        Args:
            timeout: Seconds without a checkpoint before a hang is reported.
                Any falsy value (``None`` or ``0``) selects 300 seconds.
            on_detected: Zero-argument callback invoked after a hang has been
                logged. Defaults to a no-op.
        """
        self.timeout = timeout or 300
        self.on_detected = on_detected or (lambda: None)
        self.task = None  # concurrent future of the currently armed timer
        self.loop = None  # event loop owned by the detector thread
        self.loop_thread = None
        self.lock = threading.Lock()  # guards _detected only
        self.active = False
        self._detected = False

    def start(self):
        """Start monitoring for hangs.

        Calling ``start()`` while the detector thread is already running is a
        no-op, so the existing loop and thread are never orphaned.
        """
        if self.loop_thread is not None and self.loop_thread.is_alive():
            return

        loop = asyncio.new_event_loop()

        def run_loop():
            # Bind the local loop rather than self.loop so the thread keeps a
            # valid reference even if the attribute is reassigned later.
            asyncio.set_event_loop(loop)
            loop.run_forever()

        # Publish the loop before flipping `active`, so checkpoint() never
        # observes active=True together with loop=None.
        self.loop = loop
        self.active = True
        self.loop_thread = threading.Thread(target=run_loop, daemon=True, name="hang_detector_loop")
        self.loop_thread.start()

    async def _detect_hang(self):
        """Sleep for the timeout, then report a hang unless cancelled first."""
        await asyncio.sleep(self.timeout)
        # Hold the lock only while flipping the flag. threading.Lock is not
        # reentrant, so a callback calling detected() would otherwise deadlock.
        with self.lock:
            self._detected = True
        logger.error(f"Hang detected after {self.timeout} seconds.")
        self.print_all_stacks()
        try:
            self.on_detected()
        except Exception:
            # Nobody reads this coroutine's future, so log the failure here
            # before re-raising; otherwise it would vanish silently.
            logger.error("HangDetector on_detected callback failed:\n" + traceback.format_exc())
            raise

    def print_all_stacks(self):
        """Print stack traces for all threads"""
        for thread_id, frame in sys._current_frames().items():
            logger.error(
                f"Thread {thread_id} stack trace:\n" + "".join(traceback.format_stack(frame))
            )

    def detected(self):
        """Return True if hang is detected"""
        with self.lock:
            return self._detected

    def checkpoint(self):
        """Re-arm the timer; call this periodically from the monitored code.

        Cancels the pending timer, then schedules a fresh one if the detector
        is active. Does nothing beyond the cancel when the detector is stopped.
        """
        self.pause()
        loop = self.loop
        if self.active and loop is not None:
            self.task = asyncio.run_coroutine_threadsafe(self._detect_hang(), loop)

    def pause(self):
        """Pause monitoring by cancelling the pending timer, if any."""
        if self.task is not None and not self.task.done():
            self.task.cancel()
        self.task = None

    def stop(self):
        """Stop monitoring, shut down the detector thread and close its loop.

        Safe to call when the detector was never started or was already
        stopped.
        """
        self.active = False
        self.pause()
        loop = self.loop
        if loop is None:
            return
        thread = self.loop_thread

        def cancel_all_tasks():
            # Runs on the loop thread: cancel whatever is still pending, then
            # stop the loop once those cancellations have been processed.
            for task in asyncio.all_tasks(loop):
                if not task.done():
                    task.cancel()
            loop.call_soon(loop.stop)

        loop.call_soon_threadsafe(cancel_all_tasks)

        if thread is not None and thread.is_alive():
            thread.join()

        # The loop is no longer running after the join; close it to release
        # its selector and self-pipe resources.
        loop.close()

        self.loop = None
        self.loop_thread = None

0 commit comments

Comments
 (0)