Skip to content

Commit 5e76c46

Browse files
authored
refactor: pipeline world infos (#310)
Refactored pipeline code and moved world info related stuff into config_runner. With this, config_runner will be responsible of computing worlds to add and to remove and it will be used for when we refactor the code to handle per world configuration.
1 parent 21cc3d6 commit 5e76c46

File tree

2 files changed

+114
-83
lines changed

2 files changed

+114
-83
lines changed

infscale/execution/config_runner.py

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,29 @@
1414
#
1515
# SPDX-License-Identifier: Apache-2.0
1616

17-
"""config_runner.py."""
17+
"""config_manager.py."""
1818

1919
import asyncio
2020
from typing import Awaitable, Callable
2121

2222
from infscale.configs.job import ServeConfig
23+
from infscale.execution.control import Channel as CtrlCh
24+
from infscale.execution.world import WorldInfo
2325

2426

25-
class ConfigRunner:
26-
"""ConfigRunner class."""
27+
class ConfigManager:
28+
"""ConfigManager class."""
2729

2830
def __init__(self):
29-
"""Initialize config runner instance."""
31+
"""Initialize config manager instance."""
3032
self._loop = asyncio.get_event_loop()
3133
self._task: asyncio.Task | None = None
3234
self._event = asyncio.Event()
3335
self._spec: ServeConfig = None
34-
self._event.set() # initially no configure running
36+
self._event.set()
3537
self._curr_worlds_to_configure: set[str] = set()
3638
self._cancel_cur_cfg = False
39+
self._world_infos: dict[str, WorldInfo] = {}
3740

3841
def handle_new_spec(self, spec: ServeConfig) -> None:
3942
"""Handle new spec."""
@@ -57,6 +60,35 @@ def set_worlds_to_configure(self, world_names: set[str]) -> None:
5760
"""Set the world names currently being configured."""
5861
self._curr_worlds_to_configure = world_names
5962

63+
def set_world_infos(self, worlds: list[WorldInfo]) -> None:
64+
"""Set new world infos."""
65+
for world_info in worlds:
66+
self._world_infos[world_info.name] = world_info
67+
68+
def get_world_infos(self) -> dict[str, WorldInfo]:
69+
"Get world infos."
70+
return self._world_infos
71+
72+
def is_first_run(self) -> bool:
73+
"Return boolean if is first run or not."
74+
return not self._world_infos
75+
76+
def remove_world_info(self, world_name: str) -> None:
77+
"""Remove world info by name."""
78+
del self._world_infos[world_name]
79+
80+
def get_worlds_to_add_and_remove(self) -> tuple[list[WorldInfo], list[WorldInfo]]:
81+
"""Return a list of world infos to add and to remove."""
82+
new_world_infos = self._build_world_infos()
83+
84+
new = new_world_infos.keys()
85+
cur = self._world_infos.keys()
86+
87+
worlds_to_add = [new_world_infos[name] for name in new - cur]
88+
worlds_to_remove = [new_world_infos[name] for name in cur - new]
89+
90+
return worlds_to_add, worlds_to_remove
91+
6092
async def schedule(self, coro_factory: Callable[[], Awaitable[None]]):
6193
"""Cancel any in-progress configure and schedule a new one."""
6294
# wait for current to finish if we do not want to cancel
@@ -75,6 +107,60 @@ async def schedule(self, coro_factory: Callable[[], Awaitable[None]]):
75107
self._event.clear()
76108
self._task = self._loop.create_task(self._run(coro_factory))
77109

110+
def _build_world_infos(self) -> dict[str, WorldInfo]:
111+
world_infos: dict[str, WorldInfo] = {}
112+
113+
my_id = self._spec.stage.id
114+
for k, v in self._spec.flow_graph.items():
115+
for cfg_world_info in v:
116+
# NOTE: no. of peers is always 1 for now
117+
assert len(cfg_world_info.peers) == 1
118+
119+
if my_id == k:
120+
my_rank = 0
121+
other_rank = 1
122+
other_id = cfg_world_info.peers[0]
123+
elif my_id in cfg_world_info.peers:
124+
# NOTE: this is always 1 for now
125+
my_rank = cfg_world_info.peers.index(my_id) + 1
126+
other_rank = 0
127+
other_id = k
128+
else:
129+
continue
130+
131+
name, backend, addr, data_port, ctrl_port, recover, conflict_count = (
132+
cfg_world_info.name,
133+
cfg_world_info.backend,
134+
cfg_world_info.addr,
135+
cfg_world_info.data_port,
136+
cfg_world_info.ctrl_port,
137+
cfg_world_info.recover,
138+
cfg_world_info.conflict_count,
139+
)
140+
141+
world_size = len(cfg_world_info.peers) + 1
142+
ctrl_ch = CtrlCh(my_rank, world_size, addr, ctrl_port)
143+
144+
data = {
145+
"name": name,
146+
"size": world_size,
147+
"addr": addr,
148+
"port": data_port,
149+
"backend": backend,
150+
"channel": ctrl_ch,
151+
"my_id": my_id,
152+
"me": my_rank,
153+
"other_id": other_id,
154+
"other": other_rank,
155+
"recover": recover,
156+
"conflict_count": conflict_count,
157+
"multiworld_name": f"{name}-{conflict_count}",
158+
}
159+
world_info = WorldInfo(**data)
160+
world_infos[name] = world_info
161+
162+
return world_infos
163+
78164
async def _run(self, coro_factory: Callable[[], Awaitable[None]]):
79165
"""Run coroutine factory."""
80166
try:

infscale/execution/pipeline.py

Lines changed: 23 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from infscale import get_logger
2828
from infscale.common.job_msg import Message, MessageType, WorkerStatus
2929
from infscale.configs.job import ServeConfig
30-
from infscale.execution.config_runner import ConfigRunner
30+
from infscale.execution.config_runner import ConfigManager
3131
from infscale.execution.control import Channel as CtrlCh
3232
from infscale.execution.metrics_collector import MetricsCollector
3333
from infscale.execution.router import Router
@@ -68,13 +68,12 @@ def __init__(
6868
self.wcomm = wcomm
6969
self.spec: ServeConfig = None
7070
self.device = None
71-
self.world_infos: dict[str, WorldInfo] = {}
7271
self.cfg_event = asyncio.Event()
7372
self._micro_batch_size = 1
7473
self._initialized = False
7574
self._inspector = PipelineInspector()
7675
self._status: WorkerStatus = WorkerStatus.READY
77-
self.config_runner = ConfigRunner()
76+
self.config_manager = ConfigManager()
7877

7978
# TODO: these variables are only for a server (i.e., dispatcher)
8079
# need to consider refactoring pipeline such that server code
@@ -118,7 +117,9 @@ def _set_worker_status(self, status: WorkerStatus) -> None:
118117
"""Set worker status in pipeline and channel."""
119118
self._status = status
120119

121-
for world_info in self.world_infos.values():
120+
world_infos = self.config_manager.get_world_infos()
121+
122+
for world_info in world_infos.values():
122123
world_info.channel.set_worker_status(status)
123124

124125
def _set_n_send_worker_status(self, status: WorkerStatus) -> None:
@@ -143,49 +144,45 @@ def _reset_control_channel(self, world_info: WorldInfo) -> None:
143144

144145
async def _cleanup_recovered_worlds(self) -> None:
145146
"""Clean up world infos for recovered worlds."""
147+
world_infos = self.config_manager.get_world_infos()
148+
146149
# if I'm the recovered worker, return
147-
if len(self.world_infos) == 0:
150+
if len(world_infos) == 0:
148151
return
149152

150153
recover_worlds = [
151154
world_info
152155
for world_list in self.spec.flow_graph.values()
153156
for world_info in world_list
154-
if world_info.recover and world_info.name in self.world_infos
157+
if world_info.recover and world_info.name in world_infos
155158
]
156159

157160
# no worlds to recover
158161
if len(recover_worlds) == 0:
159162
return
160163

161164
for world_info in recover_worlds:
162-
wi = self.world_infos.get(world_info.name, None)
165+
wi = world_infos.get(world_info.name, None)
163166

164167
await self.router.cleanup_world(wi)
165168
self._reset_control_channel(wi)
166169
self._reset_multiworld(wi)
167170

168-
del self.world_infos[wi.name]
171+
self.config_manager.remove_world_info(wi.name)
169172

170173
async def _configure(self) -> None:
171174
"""(Re)configure multiworld, control channel and router."""
172175
await self._cleanup_recovered_worlds()
173176

174-
is_first_run = not self.world_infos
177+
is_first_run = self.config_manager.is_first_run()
175178

176179
if not is_first_run:
177180
self._set_worker_status(WorkerStatus.UPDATING)
178181

179-
new_world_infos = self._build_world_infos()
180-
new = new_world_infos.keys()
181-
cur = self.world_infos.keys()
182-
183-
worlds_to_add = [new_world_infos[name] for name in new - cur]
184-
worlds_to_remove = [self.world_infos[name] for name in cur - new]
185-
186-
self.config_runner.set_worlds_to_configure(new - cur)
182+
worlds_to_add, worlds_to_remove = (
183+
self.config_manager.get_worlds_to_add_and_remove()
184+
)
187185

188-
# handle new worlds
189186
tasks = []
190187
# 1. set up control channel
191188
for world_info in worlds_to_add:
@@ -207,12 +204,14 @@ async def _configure(self) -> None:
207204
await asyncio.gather(*tasks)
208205

209206
# update world_info for added worlds
210-
for world_info in worlds_to_add:
211-
self.world_infos[world_info.name] = world_info
207+
self.config_manager.set_world_infos(worlds_to_add)
212208

213209
# configure router with worlds to add and remove
214210
await self.router.configure(
215-
self.spec, self.device, worlds_to_add, worlds_to_remove
211+
self.spec,
212+
self.device,
213+
worlds_to_add,
214+
worlds_to_remove,
216215
)
217216

218217
# handle unnecessary world
@@ -223,7 +222,7 @@ async def _configure(self) -> None:
223222
# 2. remove unnecessary world from multiworld
224223
self._reset_multiworld(world_info)
225224

226-
del self.world_infos[world_info.name]
225+
self.config_manager.remove_world_info(world_info.name)
227226

228227
worker_status = WorkerStatus.RUNNING if is_first_run else WorkerStatus.UPDATED
229228

@@ -426,7 +425,7 @@ async def _handle_config(self, spec: ServeConfig) -> None:
426425
if spec is None:
427426
return
428427

429-
self.config_runner.handle_new_spec(spec)
428+
self.config_manager.handle_new_spec(spec)
430429

431430
self._configure_variables(spec)
432431

@@ -435,61 +434,7 @@ async def _handle_config(self, spec: ServeConfig) -> None:
435434
self._initialize_once()
436435

437436
# (re)configure the pipeline
438-
await self.config_runner.schedule(self._configure)
439-
440-
def _build_world_infos(self) -> dict[str, WorldInfo]:
441-
world_infos: dict[str, WorldInfo] = {}
442-
443-
my_id = self.spec.stage.id
444-
for k, v in self.spec.flow_graph.items():
445-
for cfg_world_info in v:
446-
# NOTE: no. of peers is always 1 for now
447-
assert len(cfg_world_info.peers) == 1
448-
449-
if my_id == k:
450-
my_rank = 0
451-
other_rank = 1
452-
other_id = cfg_world_info.peers[0]
453-
elif my_id in cfg_world_info.peers:
454-
# NOTE: this is always 1 for now
455-
my_rank = cfg_world_info.peers.index(my_id) + 1
456-
other_rank = 0
457-
other_id = k
458-
else:
459-
continue
460-
461-
name, backend, addr, data_port, ctrl_port, recover, conflict_count = (
462-
cfg_world_info.name,
463-
cfg_world_info.backend,
464-
cfg_world_info.addr,
465-
cfg_world_info.data_port,
466-
cfg_world_info.ctrl_port,
467-
cfg_world_info.recover,
468-
cfg_world_info.conflict_count,
469-
)
470-
471-
world_size = len(cfg_world_info.peers) + 1
472-
ctrl_ch = CtrlCh(my_rank, world_size, addr, ctrl_port)
473-
474-
data = {
475-
"name": name,
476-
"size": world_size,
477-
"addr": addr,
478-
"port": data_port,
479-
"backend": backend,
480-
"channel": ctrl_ch,
481-
"my_id": my_id,
482-
"me": my_rank,
483-
"other_id": other_id,
484-
"other": other_rank,
485-
"recover": recover,
486-
"conflict_count": conflict_count,
487-
"multiworld_name": f"{name}-{conflict_count}",
488-
}
489-
world_info = WorldInfo(**data)
490-
world_infos[name] = world_info
491-
492-
return world_infos
437+
await self.config_manager.schedule(self._configure)
493438

494439
def _configure_variables(self, spec: ServeConfig) -> None:
495440
"""Set variables that need to be updated."""

0 commit comments

Comments
 (0)