Skip to content

Commit c4364fb

Browse files
njhillnpanpaliya
authored andcommitted
[BugFix] Harden distributed DP startup (vllm-project#21538)
Signed-off-by: Nick Hill <[email protected]>
1 parent c014e0b commit c4364fb

File tree

3 files changed

+56
-20
lines changed

3 files changed

+56
-20
lines changed

vllm/utils/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2794,6 +2794,9 @@ def make_zmq_socket(
27942794
if linger is not None:
27952795
socket.setsockopt(zmq.LINGER, linger)
27962796

2797+
if socket_type == zmq.XPUB:
2798+
socket.setsockopt(zmq.XPUB_VERBOSE, True)
2799+
27972800
# Determine if the path is a TCP socket with an IPv6 address.
27982801
# Enable IPv6 on the zmq socket if so.
27992802
scheme, host, _ = split_zmq_path(path)

vllm/v1/engine/coordinator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,18 @@ def process_input_socket(self, front_publish_address: str,
172172
bind=True,
173173
) as publish_back:
174174

175+
# Wait until all engines subscribe.
176+
for _ in self.engines:
177+
if publish_back.recv() != b'\x01':
178+
logger.error(
179+
"DP Coordinator received unexpected message while "
180+
"waiting for engines to subscribe")
181+
return
182+
# Send ready message to engines.
183+
publish_back.send(b"READY")
184+
185+
logger.info("All engine subscriptions received by DP coordinator")
186+
175187
poller = zmq.Poller()
176188
poller.register(publish_front, zmq.POLLIN)
177189
poller.register(output_back, zmq.POLLIN)

vllm/v1/engine/core.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -461,8 +461,11 @@ def __init__(
461461
self.has_coordinator = addresses.coordinator_output is not None
462462
self.frontend_stats_publish_address = (
463463
addresses.frontend_stats_publish_address)
464+
logger.debug("Has DP Coordinator: %s, stats publish address: %s",
465+
self.has_coordinator,
466+
self.frontend_stats_publish_address)
464467
# Only publish request queue stats to coordinator for "internal"
465-
# LB mode.
468+
# and "hybrid" LB modes .
466469
self.publish_dp_lb_stats = (
467470
self.has_coordinator
468471
and not vllm_config.parallel_config.data_parallel_external_lb)
@@ -472,25 +475,38 @@ def __init__(
472475
super().__init__(vllm_config, executor_class, log_stats,
473476
executor_fail_callback)
474477

478+
# Background Threads and Queues for IO. These enable us to
479+
# overlap ZMQ socket IO with GPU since they release the GIL,
480+
# and to overlap some serialization/deserialization with the
481+
# model forward pass.
482+
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
483+
ready_event = threading.Event()
484+
input_thread = threading.Thread(target=self.process_input_sockets,
485+
args=(addresses.inputs,
486+
addresses.coordinator_input,
487+
identity, ready_event),
488+
daemon=True)
489+
input_thread.start()
490+
491+
self.output_thread = threading.Thread(
492+
target=self.process_output_sockets,
493+
args=(addresses.outputs, addresses.coordinator_output,
494+
self.engine_index),
495+
daemon=True)
496+
self.output_thread.start()
497+
498+
# Don't complete handshake until DP coordinator ready message is
499+
# received.
500+
while not ready_event.wait(timeout=10):
501+
if not input_thread.is_alive():
502+
raise RuntimeError(
503+
"Input socket thread died during startup")
504+
assert addresses.coordinator_input is not None
505+
logger.info("Waiting for READY message from DP Coordinator...")
506+
475507
self.step_fn = (self.step if self.batch_queue is None else
476508
self.step_with_batch_queue)
477509

478-
# Background Threads and Queues for IO. These enable us to
479-
# overlap ZMQ socket IO with GPU since they release the GIL,
480-
# and to overlap some serialization/deserialization with the
481-
# model forward pass.
482-
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
483-
threading.Thread(target=self.process_input_sockets,
484-
args=(addresses.inputs, addresses.coordinator_input,
485-
identity),
486-
daemon=True).start()
487-
self.output_thread = threading.Thread(
488-
target=self.process_output_sockets,
489-
args=(addresses.outputs, addresses.coordinator_output,
490-
self.engine_index),
491-
daemon=True)
492-
self.output_thread.start()
493-
494510
@contextmanager
495511
def _perform_handshakes(
496512
self,
@@ -505,10 +521,10 @@ def _perform_handshakes(
505521
506522
For DP=1 or offline mode, this is with the colocated front-end process.
507523
508-
For DP>1 with internal loadbalancing this is with the shared front-end
524+
For DP>1 with internal load-balancing this is with the shared front-end
509525
process which may reside on a different node.
510526
511-
For DP>1 with external or hybrid loadbalancing, two handshakes are
527+
For DP>1 with external or hybrid load-balancing, two handshakes are
512528
performed:
513529
- With the rank 0 front-end process which retrieves the
514530
DP Coordinator ZMQ addresses and DP process group address.
@@ -772,7 +788,7 @@ def _send_engine_dead(self):
772788

773789
def process_input_sockets(self, input_addresses: list[str],
774790
coord_input_address: Optional[str],
775-
identity: bytes):
791+
identity: bytes, ready_event: threading.Event):
776792
"""Input socket IO thread."""
777793

778794
# Msgpack serialization decoding.
@@ -809,9 +825,14 @@ def process_input_sockets(self, input_addresses: list[str],
809825
# back to us.
810826
input_socket.send(b'')
811827
poller.register(input_socket, zmq.POLLIN)
828+
812829
if coord_socket is not None:
830+
# Wait for ready message from coordinator.
831+
assert coord_socket.recv() == b"READY"
813832
poller.register(coord_socket, zmq.POLLIN)
814833

834+
ready_event.set()
835+
del ready_event
815836
while True:
816837
for input_socket, _ in poller.poll():
817838
# (RequestType, RequestData)

0 commit comments

Comments
 (0)