Skip to content

Commit 9859c9b

Browse files
committed
feat: handle individual world configuration
Since a single configuration task for the entire config couldn't handle partial world updates, we implemented per-world configuration tasks. The config manager takes care of scheduling and canceling tasks based on the diff between the old and new config. The config manager also orchestrates processing configs based on which tasks it needs to cancel. Updated the code with proper error handlers for canceled tasks so cancellation properly propagates to the parent task. With this, each world's configuration can be tracked individually, making it easier to implement retry-on-failure in the future.
1 parent 627409d commit 9859c9b

File tree

5 files changed

+124
-57
lines changed

5 files changed

+124
-57
lines changed

infscale/configs/job.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,16 @@ def get_worlds_to_configure(
157157
"""Compare two specs and return new and updated worlds."""
158158
helper = ServeConfigHelper()
159159

160-
curr_worlds = helper._get_worlds(curr_spec)
161160
new_worlds = helper._get_worlds(new_spec)
161+
new_world_names = set(new_worlds.keys())
162+
163+
# if current spec is not available,
164+
# return worlds from the new spec.
165+
if curr_spec is None:
166+
return new_world_names
162167

168+
curr_worlds = helper._get_worlds(curr_spec)
163169
curr_world_names = set(curr_worlds.keys())
164-
new_world_names = set(new_worlds.keys())
165170

166171
deploy_worlds = new_world_names - curr_world_names
167172

infscale/execution/config_manager.py

Lines changed: 55 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -30,38 +30,43 @@ class ConfigManager:
3030
def __init__(self):
3131
"""Initialize config manager instance."""
3232
self._loop = asyncio.get_event_loop()
33-
self._task: asyncio.Task | None = None
34-
self._event = asyncio.Event()
33+
# semaphore event for back-to-back configs
34+
self._config_event = asyncio.Event()
35+
self._config_event.set()
36+
self._world_tasks: dict[str, asyncio.Task] = {}
3537
self._spec: ServeConfig = None
36-
self._event.set()
3738
self._curr_worlds_to_configure: set[str] = set()
38-
self._cancel_cur_cfg = False
3939
self._world_infos: dict[str, WorldInfo] = {}
40+
self.worlds_to_cancel = set()
4041

41-
def handle_new_spec(self, spec: ServeConfig) -> None:
42+
async def handle_new_spec(self, spec: ServeConfig) -> None:
4243
"""Handle new spec."""
43-
self._cancel_cur_cfg = self._should_cancel_current(spec)
44+
new_worlds_to_configure = ServeConfig.get_worlds_to_configure(self._spec, spec)
45+
self.worlds_to_cancel = new_worlds_to_configure & self._curr_worlds_to_configure
4446
self._spec = spec
4547

46-
def _should_cancel_current(self, spec: ServeConfig) -> bool:
47-
"""Decide if current configuration should be cancelled."""
48-
if self._spec is None:
49-
return False
48+
if len(self.worlds_to_cancel):
49+
await self._cancel_world_configuration(self.worlds_to_cancel)
5050

51-
new_worlds_to_configure = ServeConfig.get_worlds_to_configure(self._spec, spec)
51+
# wait for current configuration to finish
52+
await self._config_event.wait()
53+
54+
# block configuration until current config is processed
55+
# when a new spec is waiting, we want to block the execution until
56+
# the current one is under configuration.
57+
self._config_event.clear()
5258

53-
# cancel if the new config affects worlds currently being configured
54-
# TODO: if there's a overlap between new worlds and curr worlds we cancel
55-
# current configuration. This needs to be fixed, to cancel only the worlds that
56-
# are affected (eg new_worlds & curr_worlds)
57-
return not new_worlds_to_configure.isdisjoint(self._curr_worlds_to_configure)
59+
def reset_state(self) -> None:
60+
"""Reset any state that is kept between configs."""
61+
self._curr_worlds_to_configure = set()
62+
self.worlds_to_cancel = set()
5863

59-
def set_worlds_to_configure(self, world_names: set[str]) -> None:
60-
"""Set the world names currently being configured."""
61-
self._curr_worlds_to_configure = world_names
64+
def unblock_next_config(self) -> None:
65+
"""Set task event and unblock next config process."""
66+
self._config_event.set()
6267

63-
def set_world_infos(self, worlds: list[WorldInfo]) -> None:
64-
"""Set new world infos."""
68+
def update_world_infos(self, worlds: list[WorldInfo]) -> None:
69+
"""Update world infos."""
6570
for world_info in worlds:
6671
self._world_infos[world_info.name] = world_info
6772

@@ -87,25 +92,32 @@ def get_worlds_to_add_and_remove(self) -> tuple[list[WorldInfo], list[WorldInfo]
8792
worlds_to_add = [new_world_infos[name] for name in new - cur]
8893
worlds_to_remove = [new_world_infos[name] for name in cur - new]
8994

90-
return worlds_to_add, worlds_to_remove
95+
self._curr_worlds_to_configure = new - cur
9196

92-
async def schedule(self, coro_factory: Callable[[], Awaitable[None]]):
93-
"""Cancel any in-progress configure and schedule a new one."""
94-
# wait for current to finish if we do not want to cancel
95-
if not self._cancel_cur_cfg:
96-
await self._event.wait()
97+
return worlds_to_add, worlds_to_remove
9798

98-
# cancel current if running
99-
if self._task and not self._task.done():
100-
self._task.cancel()
99+
async def _cancel_world_configuration(self, world_names: set[str]):
100+
"""Cancel only worlds that are impacted by new spec."""
101+
coroutines = [self._cancel_world(w) for w in world_names]
102+
await asyncio.gather(*coroutines, return_exceptions=True)
103+
104+
def schedule_world_cfg(
105+
self, world_info: WorldInfo, coro_factory: Callable[[], Awaitable[None]]
106+
):
107+
"""Schedule configuration for a single world."""
108+
task = self._loop.create_task(self._run_world(world_info, coro_factory))
109+
self._world_tasks[world_info.name] = task
110+
return task
111+
112+
async def _cancel_world(self, world_name: str):
113+
"""Cancel an in-progress world config task."""
114+
task = self._world_tasks.pop(world_name, None)
115+
if task and not task.done():
116+
task.cancel()
101117
try:
102-
await self._task
118+
await task
103119
except asyncio.CancelledError:
104-
pass
105-
106-
# block again for new run
107-
self._event.clear()
108-
self._task = self._loop.create_task(self._run(coro_factory))
120+
raise
109121

110122
def _build_world_infos(self) -> dict[str, WorldInfo]:
111123
world_infos: dict[str, WorldInfo] = {}
@@ -161,13 +173,13 @@ def _build_world_infos(self) -> dict[str, WorldInfo]:
161173

162174
return world_infos
163175

164-
async def _run(self, coro_factory: Callable[[], Awaitable[None]]):
165-
"""Run coroutine factory."""
176+
async def _run_world(
177+
self, world_info: WorldInfo, coro_factory: Callable[[], Awaitable[None]]
178+
):
179+
"""Run and cleanup world configuration."""
166180
try:
167-
await coro_factory()
181+
await coro_factory(world_info)
168182
except asyncio.CancelledError:
169-
pass
183+
raise
170184
finally:
171-
# reset class attributes and events
172-
self._event.set()
173-
self._curr_worlds_to_configure = set()
185+
self._world_tasks.pop(world_info.name, None)

infscale/execution/control.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,32 @@ async def setup(self) -> None:
162162
if self.rank == 0:
163163
self._server_task = asyncio.create_task(self._setup_server(setup_done))
164164
else:
165-
_ = asyncio.create_task(self._setup_client(setup_done))
165+
client_task = asyncio.create_task(self._setup_client(setup_done))
166166

167167
# wait until setting up either server or client is done
168-
await setup_done.wait()
168+
try:
169+
await setup_done.wait()
170+
except asyncio.CancelledError as e:
171+
# logger.warning(f"[{self.rank}] channel setup cancelled")
172+
# since both _setup_server and _setup_client are spawned as separate tasks
173+
# and the setup itself is a task, we need to handle parent task cancellation
174+
# on the awaited line, since cancellation only propagates through awaited calls
175+
# here, await setup_done.wait() is the propagation point from parent task to child tasks
176+
# so we need to cancel child tasks whenever CancelledError is received
177+
if self._server_task and not self._server_task.done():
178+
self._server_task.cancel()
179+
try:
180+
await self._server_task
181+
except asyncio.CancelledError:
182+
pass
183+
184+
if client_task and not client_task.done():
185+
client_task.cancel()
186+
try:
187+
await client_task
188+
except asyncio.CancelledError:
189+
pass
190+
raise
169191

170192
def cleanup(self) -> None:
171193
if self._server_task is not None:

infscale/execution/pipeline.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ async def _configure_multiworld(self, world_info: WorldInfo) -> None:
104104
port=port,
105105
device=self.device,
106106
)
107+
except asyncio.CancelledError:
108+
logger.warning(f"multiworld configuration cancelled for {world_info.name}")
107109
except Exception as e:
108110
logger.error(f"failed to initialize a multiworld {name}: {e}")
109111
condition = self._status != WorkerStatus.UPDATING
@@ -130,12 +132,20 @@ def _set_n_send_worker_status(self, status: WorkerStatus) -> None:
130132
self.wcomm.send(msg)
131133

132134
async def _configure_control_channel(self, world_info: WorldInfo) -> None:
133-
await world_info.channel.setup()
135+
try:
136+
await world_info.channel.setup()
134137

135-
await world_info.channel.wait_readiness()
138+
await world_info.channel.wait_readiness()
139+
except asyncio.CancelledError:
140+
logger.warning(f"channel configuration cancelled for {world_info}")
136141

137142
def _reset_multiworld(self, world_info: WorldInfo) -> None:
138-
self.world_manager.remove_world(world_info.multiworld_name)
143+
try:
144+
self.world_manager.remove_world(world_info.multiworld_name)
145+
except ValueError as e:
146+
logger.warning(f"failed to reset {world_info.multiworld_name}: {e}")
147+
return
148+
139149
logger.info(f"remove world {world_info.multiworld_name} from multiworld")
140150

141151
def _reset_control_channel(self, world_info: WorldInfo) -> None:
@@ -186,7 +196,12 @@ async def _configure(self) -> None:
186196
tasks = []
187197
# 1. set up control channel
188198
for world_info in worlds_to_add:
189-
task = self._configure_control_channel(world_info)
199+
if world_info.name in self.config_manager.worlds_to_cancel:
200+
continue
201+
202+
task = self.config_manager.schedule_world_cfg(
203+
world_info, self._configure_control_channel
204+
)
190205
tasks.append(task)
191206

192207
# TODO: this doesn't handle partial success
@@ -196,22 +211,28 @@ async def _configure(self) -> None:
196211
tasks = []
197212
# 2. set up multiworld
198213
for world_info in worlds_to_add:
199-
task = self._configure_multiworld(world_info)
214+
if world_info.name in self.config_manager.worlds_to_cancel:
215+
continue
216+
217+
task = self.config_manager.schedule_world_cfg(
218+
world_info, self._configure_multiworld
219+
)
200220
tasks.append(task)
201221

202222
# TODO: this doesn't handle partial success
203223
# a mechanism to handle a failure is left as a todo
204224
await asyncio.gather(*tasks)
205225

206226
# update world_info for added worlds
207-
self.config_manager.set_world_infos(worlds_to_add)
227+
self.config_manager.update_world_infos(worlds_to_add)
208228

209229
# configure router with worlds to add and remove
210230
await self.router.configure(
211231
self.spec,
212232
self.device,
213233
worlds_to_add,
214234
worlds_to_remove,
235+
self.config_manager.worlds_to_cancel,
215236
)
216237

217238
# handle unnecessary world
@@ -226,6 +247,9 @@ async def _configure(self) -> None:
226247

227248
worker_status = WorkerStatus.RUNNING if is_first_run else WorkerStatus.UPDATED
228249

250+
# config is done, do cleanup in config runner
251+
self.config_manager.reset_state()
252+
self.config_manager.unblock_next_config()
229253
self._set_n_send_worker_status(worker_status)
230254

231255
self.cfg_event.set()
@@ -425,16 +449,16 @@ async def _handle_config(self, spec: ServeConfig) -> None:
425449
if spec is None:
426450
return
427451

428-
self.config_manager.handle_new_spec(spec)
429-
430452
self._configure_variables(spec)
431453

432454
self._inspector.configure(self.spec)
433455

434456
self._initialize_once()
435-
436457
# (re)configure the pipeline
437-
await self.config_manager.schedule(self._configure)
458+
await self.config_manager.handle_new_spec(spec)
459+
# run configure as a separate task since we need to unblock receiving
460+
# a new config to be processed when current configuration is finished
461+
self._configure_task = asyncio.create_task(self._configure())
438462

439463
def _configure_variables(self, spec: ServeConfig) -> None:
440464
"""Set variables that need to be updated."""

infscale/execution/router.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ async def configure(
114114
device=torch.device("cpu"),
115115
worlds_to_add: list[WorldInfo] = [],
116116
worlds_to_remove: list[WorldInfo] = [],
117+
worlds_to_cancel: set[str] = set(),
117118
) -> None:
118119
"""(Re)configure router."""
119120
self._is_server = spec.is_server
@@ -131,6 +132,9 @@ async def configure(
131132
self._fwder.set_stickiness(sticky)
132133

133134
for world_info in worlds_to_add:
135+
if world_info.name in worlds_to_cancel:
136+
continue
137+
134138
cancellable = asyncio.Event()
135139
if world_info.me == 0: # I am a receiver from other
136140
task = asyncio.create_task(self._recv(world_info, cancellable))

0 commit comments

Comments
 (0)