This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit e2b2aa5

[TPU] Align worker index with node boundary (vllm-project#7932)
1 parent e6a26ed

File tree

1 file changed (+28 -0 lines changed)

vllm/executor/ray_tpu_executor.py

Lines changed: 28 additions & 0 deletions
@@ -111,12 +111,40 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 # Else, added to the list of workers.
                 self.workers.append(worker)
 
+        logger.debug("workers: %s", self.workers)
+        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
         if self.driver_dummy_worker is None:
             raise ValueError(
                 "Ray does not allocate any TPUs on the driver node. Consider "
                 "adjusting the Ray placement group or running the driver on a "
                 "TPU node.")
 
+        worker_ips = [
+            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
+            for worker in self.workers
+        ]
+        ip_counts: Dict[str, int] = {}
+        for ip in worker_ips:
+            ip_counts[ip] = ip_counts.get(ip, 0) + 1
+
+        def sort_by_driver_then_worker_ip(worker):
+            """
+            Sort the workers based on 3 properties:
+            1. If the worker is on the same node as the driver (vllm engine),
+               it should be placed first.
+            2. Then, if the worker is on a node with fewer workers, it should
+               be placed first.
+            3. Finally, if the worker is on a node with a smaller IP address,
+               it should be placed first.
+            """
+            ip = ray.get(worker.get_node_ip.remote())
+            return (ip != driver_ip, ip_counts[ip], ip)
+
+        # After sorting, the workers on the same node will be
+        # close to each other, and the workers on the driver
+        # node will be placed first.
+        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
+
         # Get the set of TPU IDs used on each node.
         worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                     use_dummy_driver=True)
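
For context, the ordering introduced by this commit can be reproduced outside of Ray. Below is a minimal, self-contained sketch of the same sort key: workers on the driver node first, then nodes hosting fewer workers, then smaller IP addresses, so that worker indices end up aligned with node boundaries. The FakeWorker class, the sort_workers helper, and the sample IPs are hypothetical stand-ins for the Ray actor handles and the ray.get(worker.get_node_ip.remote()) calls used in the actual executor.

from collections import Counter
from dataclasses import dataclass
from typing import List


@dataclass
class FakeWorker:
    # Hypothetical stand-in for a Ray worker actor; the real code obtains
    # this value via ray.get(worker.get_node_ip.remote()).
    node_ip: str


def sort_workers(workers: List[FakeWorker], driver_ip: str) -> List[FakeWorker]:
    # Count how many workers live on each node.
    ip_counts = Counter(w.node_ip for w in workers)

    def sort_by_driver_then_worker_ip(worker: FakeWorker):
        ip = worker.node_ip
        # False sorts before True, so driver-node workers come first,
        # then nodes with fewer workers, then lexicographically smaller IPs.
        return (ip != driver_ip, ip_counts[ip], ip)

    return sorted(workers, key=sort_by_driver_then_worker_ip)


if __name__ == "__main__":
    workers = [FakeWorker("10.0.0.3"), FakeWorker("10.0.0.2"),
               FakeWorker("10.0.0.1"), FakeWorker("10.0.0.2")]
    for w in sort_workers(workers, driver_ip="10.0.0.2"):
        print(w.node_ip)
    # Prints 10.0.0.2, 10.0.0.2, 10.0.0.1, 10.0.0.3: the driver-node
    # workers lead, and the remaining workers are grouped by node.

One caveat of this key, as in the commit itself, is that IPs compare as strings rather than numerically, which is sufficient here because the comparison only needs to be a stable, deterministic tie-breaker.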
