Skip to content

Commit 68deaea

Browse files
authored
fix: unique world names on scale_out (#303)
With a back-to-back scale_out -> scale_in -> scale_out sequence, the generator might return the same config on the second scale_out, creating a world name conflict while configuring multiworld. To avoid this, we keep a per-world conflict count (the field previously named duplicate_count) that is incremented on each scale_out. After the config is generated, we find the new workers and update the conflict count of their worlds and of their peers' worlds, so that world names stay unique across scale_out operations.
1 parent 1363410 commit 68deaea

File tree

4 files changed

+43
-8
lines changed

4 files changed

+43
-8
lines changed

infscale/configs/job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class WorldInfo:
7070
addr: str = "127.0.0.1"
7171
backend: Optional[str] = ""
7272
recover: bool = False
73-
duplicate_count: int = 0
73+
conflict_count: int = 0
7474

7575

7676
@dataclass

infscale/controller/job_context.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,7 @@ def __init__(self, ctrl: Controller, job_id: str):
663663
self._cur_cfg: JobConfig | None = None
664664
self._new_cfg: JobConfig | None = None
665665
self._flow_graph_patched = False
666+
self._worlds_conflict_count: dict[str, int] = {}
666667

667668
# event to update the config after all agents added ports and ip address
668669
self.agents_setup_event = asyncio.Event()
@@ -831,6 +832,8 @@ def process_cfg(self) -> None:
831832

832833
self._reconcile_wrk_status(self._cur_cfg, self._new_cfg)
833834

835+
self._update_worlds_conflict_count(self._cur_cfg, self._new_cfg)
836+
834837
self._new_cfg.reqgen_config = self.ctrl.reqgen_config
835838

836839
agent_resources = self.get_agent_resources_map()
@@ -853,6 +856,33 @@ def process_cfg(self) -> None:
853856

854857
self.job_checker.setup(self._new_cfg)
855858

859+
def _update_worlds_conflict_count(
    self, cur_cfg: JobConfig, new_cfg: JobConfig
) -> None:
    """Refresh the conflict count of worlds that involve new workers.

    A world in ``new_cfg.flow_graph`` is refreshed when the worker that
    owns its entry is new, or when any new worker is listed in the world's
    peers. For each such world the tracked count is advanced through
    ``_set_world_conflict_count`` and the resulting value is written to the
    world's ``conflict_count``.

    Args:
        cur_cfg: currently applied config; when falsy, every worker in
            ``new_cfg`` is treated as new.
        new_cfg: newly generated config whose world infos are updated
            in place.
    """
    # presumably get_workers_diff yields the ids of workers present in
    # new_cfg but not in cur_cfg -- TODO confirm against JobConfig
    new_workers = (
        JobConfig.get_workers_diff(new_cfg, cur_cfg)
        if cur_cfg
        else {worker.id for worker in new_cfg.workers}
    )

    # NOTE(review): worlds not touched by new workers keep whatever
    # conflict_count new_cfg already carries -- confirm that is intended.
    for owner_id, worlds in new_cfg.flow_graph.items():
        owner_is_new = owner_id in new_workers

        for info in worlds:
            has_new_peer = any(new_id in info.peers for new_id in new_workers)
            if not (owner_is_new or has_new_peer):
                continue

            world_name = info.name
            self._set_world_conflict_count(world_name)
            info.conflict_count = self._worlds_conflict_count[world_name]
876+
877+
def _set_world_conflict_count(self, world_name: str) -> None:
878+
"""Set worlds conflict count."""
879+
if world_name in self._worlds_conflict_count:
880+
self._worlds_conflict_count[world_name] += 1
881+
882+
return
883+
884+
self._worlds_conflict_count[world_name] = 0
885+
856886
def reset_cfg_recover_flags(self) -> None:
857887
"""Reset recover flags on config."""
858888
self._cur_cfg.reset_recover_flags()
@@ -878,15 +908,19 @@ def _update_recovery_flow_graph(
878908
recover_flow_graph = cfg.flow_graph[recover_wid]
879909

880910
for world_info in recover_flow_graph:
911+
name = world_info.name
912+
self._set_world_conflict_count(name)
881913
world_info.addr = ip
882914
world_info.recover = True
883-
world_info.duplicate_count += 1
915+
world_info.conflict_count = self._worlds_conflict_count[name]
884916

885917
for world_list in cfg.flow_graph.values():
886918
for world_info in world_list:
887919
if recover_wid in world_info.peers:
920+
name = world_info.name
921+
self._set_world_conflict_count(name)
888922
world_info.recover = True
889-
world_info.duplicate_count += 1
923+
world_info.conflict_count = self._worlds_conflict_count[name]
890924

891925
def _update_recovery_worker_data(
892926
self, cfg: JobConfig, wrk_id: str, gpu_id: int
@@ -1233,6 +1267,7 @@ def cleanup(self) -> None:
12331267
self._cur_cfg = None
12341268
self._new_cfg = None
12351269
self._flow_graph_patched = False
1270+
self._worlds_conflict_count = {}
12361271

12371272
def _release_gpu_resources(self, agent_data: AgentMetaData) -> None:
12381273
resources = self.ctrl.agent_contexts[agent_data.id].resources

infscale/execution/pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -458,14 +458,14 @@ def _build_world_infos(self) -> dict[str, WorldInfo]:
458458
else:
459459
continue
460460

461-
name, backend, addr, data_port, ctrl_port, recover, duplicate_count = (
461+
name, backend, addr, data_port, ctrl_port, recover, conflict_count = (
462462
cfg_world_info.name,
463463
cfg_world_info.backend,
464464
cfg_world_info.addr,
465465
cfg_world_info.data_port,
466466
cfg_world_info.ctrl_port,
467467
cfg_world_info.recover,
468-
cfg_world_info.duplicate_count,
468+
cfg_world_info.conflict_count,
469469
)
470470

471471
world_size = len(cfg_world_info.peers) + 1
@@ -483,8 +483,8 @@ def _build_world_infos(self) -> dict[str, WorldInfo]:
483483
"other_id": other_id,
484484
"other": other_rank,
485485
"recover": recover,
486-
"duplicate_count": duplicate_count,
487-
"multiworld_name": f"{name}-{duplicate_count}",
486+
"conflict_count": conflict_count,
487+
"multiworld_name": f"{name}-{conflict_count}",
488488
}
489489
world_info = WorldInfo(**data)
490490
world_infos[name] = world_info

infscale/execution/world.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,5 @@ class WorldInfo:
4242
other: int # other peer's rank
4343

4444
recover: bool
45-
duplicate_count: int
45+
conflict_count: int
4646
multiworld_name: str

0 commit comments

Comments
 (0)