Skip to content

Commit c3a00eb

Browse files
committed
fix: worker info update
Signed-off-by: Hao Lin <linhaomails@gmail.com>
1 parent 469de00 commit c3a00eb

File tree

4 files changed

+53
-58
lines changed

4 files changed

+53
-58
lines changed

rlinf/scheduler/manager/worker_manager.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -335,15 +335,8 @@ def get_worker_info(self, worker_address: WorkerAddress) -> WorkerInfo:
335335
return node._worker_info
336336
return None
337337

338-
def update_worker_info(
339-
self, worker_address: WorkerAddress, worker_info: WorkerInfo
340-
):
341-
"""Update worker information by its address.
342-
343-
Args:
344-
worker_address (WorkerAddress): The address of the worker to update.
345-
worker_info (WorkerInfo): The new worker information.
346-
"""
338+
def update_worker_info(self, worker_address: WorkerAddress, worker_info: WorkerInfo):
339+
"""Update worker information by its address."""
347340
for root in self._root_workers:
348341
node = WorkerNode.find_node(root, worker_address)
349342
if node is not None:
@@ -386,18 +379,21 @@ def get_group_ranks(self, root_group_name: str) -> list[int]:
386379
return []
387380

388381
def set_group_workers(self, root_group_name: str, workers: list[WorkerInfo]):
389-
"""Replace all workers of a root group with provided worker infos."""
390-
root_node = None
391-
for root in self._root_workers:
392-
if (
393-
root._worker_address is not None
394-
and root._worker_address.get_name() == root_group_name
395-
):
396-
root_node = root
397-
break
398-
if root_node is None:
399-
root_node = WorkerNode(WorkerAddress(root_group_name, []))
400-
self._root_workers.append(root_node)
401-
root_node._nodes = []
382+
"""Replace all workers of a root group with provided worker infos.
383+
384+
Rebuilds membership via register/unregister APIs to preserve the WorkerNode
385+
tree invariants instead of mutating internal tree state directly.
386+
"""
387+
for rank in list(self.get_group_ranks(root_group_name)):
388+
self.unregister_worker(WorkerAddress(root_group_name, rank))
389+
402390
for worker_info in sorted(workers, key=lambda info: info.rank):
403-
root_node.add_child(worker_info.rank, worker_info)
391+
if worker_info.address.root_group_name != root_group_name:
392+
raise ValueError(
393+
f"Worker {worker_info.address.get_name()} does not belong to group {root_group_name}."
394+
)
395+
existing = self.get_worker_info(worker_info.address)
396+
if existing is None:
397+
self.register_worker(worker_info.address, worker_info)
398+
else:
399+
self.update_worker_info(worker_info.address, worker_info)

rlinf/scheduler/worker/worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1274,6 +1274,7 @@ def update_state(
12741274
master_addr: Optional[str] = None,
12751275
root_group_names: Optional[list[str]] = None,
12761276
rebuild_collective: bool = False,
1277+
sync_manager: bool = True,
12771278
) -> "WorkerInfo":
12781279
"""Refresh worker state after membership/rank changes."""
12791280
if world_size is not None:
@@ -1314,5 +1315,4 @@ def update_state(
13141315

13151316
self._collective = Collective(self)
13161317

1317-
self._manager_proxy.update_worker_info(self._worker_address, self._worker_info)
13181318
return self._worker_info

rlinf/scheduler/worker/worker_group.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def scale_up(
242242

243243
self._placement_strategy = placement_strategy
244244
self._create_workers_from_placements(placements)
245-
self.update_state(reindex_ranks=True)
245+
self.update_state()
246246
self._is_ready()
247247
return self
248248

@@ -271,7 +271,7 @@ def scale_down(
271271
remaining_workers.append(worker_info)
272272

273273
self._workers = remaining_workers
274-
self.update_state(reindex_ranks=True)
274+
self.update_state()
275275
return self
276276

277277
def _unregister_worker(self, rank: int):
@@ -283,40 +283,39 @@ def _unregister_worker(self, rank: int):
283283
)
284284
worker_manager.unregister_worker(worker_address)
285285

286-
def update_state(self, reindex_ranks: bool = False):
286+
def update_state(self):
287287
"""Refresh worker-group state after membership changes."""
288288
from ..manager import CollectiveManager, WorkerManager
289289

290290
worker_manager = WorkerManager.get_proxy()
291-
if reindex_ranks:
292-
self._sort_workers()
293-
world_size = len(self._workers)
294-
if world_size > 0:
295-
first_worker_info = worker_manager.get_worker_info(
296-
WorkerAddress(self._worker_group_name, self._workers[0].rank)
297-
)
298-
assert first_worker_info is not None, (
299-
f"Cannot get worker info for rank {self._workers[0].rank} in group {self._worker_group_name}."
300-
)
301-
master_addr = first_worker_info.node_ip
302-
refs = []
303-
for new_rank, worker_info in enumerate(self._workers):
304-
refs.append(
305-
worker_info.worker.update_state.remote(
306-
new_rank=new_rank,
307-
world_size=world_size,
308-
master_addr=master_addr,
309-
rebuild_collective=True,
310-
)
291+
self._sort_workers()
292+
world_size = len(self._workers)
293+
if world_size > 0:
294+
first_worker_info = worker_manager.get_worker_info(
295+
WorkerAddress(self._worker_group_name, self._workers[0].rank)
296+
)
297+
assert first_worker_info is not None, (
298+
f"Cannot get worker info for rank {self._workers[0].rank} in group {self._worker_group_name}."
299+
)
300+
master_addr = first_worker_info.node_ip
301+
refs = []
302+
for new_rank, worker_info in enumerate(self._workers):
303+
refs.append(
304+
worker_info.worker.update_state.remote(
305+
new_rank=new_rank,
306+
world_size=world_size,
307+
master_addr=master_addr,
308+
rebuild_collective=True,
311309
)
312-
updated_infos = ray.get(refs)
313-
worker_manager.set_group_workers(self._worker_group_name, updated_infos)
314-
self._workers = [
315-
WorkerGroup.WorkerRank(rank=rank, worker=worker_info.worker)
316-
for rank, worker_info in enumerate(self._workers)
317-
]
318-
self._world_size = world_size
319-
self._group_size = world_size
310+
)
311+
updated_infos = ray.get(refs)
312+
worker_manager.set_group_workers(self._worker_group_name, updated_infos)
313+
self._workers = [
314+
WorkerGroup.WorkerRank(rank=rank, worker=worker_info.worker)
315+
for rank, worker_info in enumerate(self._workers)
316+
]
317+
self._world_size = world_size
318+
self._group_size = world_size
320319

321320
coll_manager = CollectiveManager.get_proxy()
322321
related_groups = coll_manager.get_related_worker_groups(self._worker_group_name)

tests/unit_tests/test_worker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,11 @@ def test_worker_group_scale_up_and_down(self, cluster: Cluster):
198198

199199
worker_group.scale_down([1])
200200
infos = worker_group.get_env_info().wait()
201-
assert sorted(info["rank"] for info in infos) == [0, 2]
201+
assert sorted(info["rank"] for info in infos) == [0, 1]
202202
assert all(info["world_size"] == 2 for info in infos)
203203

204-
subset_infos = worker_group.execute_on(0, 2).get_env_info().wait()
205-
assert sorted(info["rank"] for info in subset_infos) == [0, 2]
204+
subset_infos = worker_group.execute_on(0, 1).get_env_info().wait()
205+
assert sorted(info["rank"] for info in subset_infos) == [0, 1]
206206

207207

208208
class TestLoadUserExtensions:

0 commit comments

Comments (0)