Skip to content

Commit 11bf1ed

Browse files
committed
feat: handle individual world configuration
Because handling the entire configuration as a single unit could not support partial world updates, we implemented per-world configuration tasks. The config manager schedules and cancels these tasks based on the diff between the old and new configs, and it orchestrates config processing according to which worlds it needs to cancel. The code now has proper error handlers for task cancellation, so that cancellation propagates correctly to the parent task. With this change, each world's configuration can be tracked individually, making it easier to implement retry on failure in the future.
1 parent 627409d commit 11bf1ed

File tree

4 files changed

+182
-91
lines changed

4 files changed

+182
-91
lines changed

infscale/configs/job.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,16 @@ def get_worlds_to_configure(
157157
"""Compare two specs and return new and updated worlds."""
158158
helper = ServeConfigHelper()
159159

160-
curr_worlds = helper._get_worlds(curr_spec)
161160
new_worlds = helper._get_worlds(new_spec)
161+
new_world_names = set(new_worlds.keys())
162+
163+
# if current spec is not available,
164+
# return worlds from the new spec.
165+
if curr_spec is None:
166+
return new_world_names
162167

168+
curr_worlds = helper._get_worlds(curr_spec)
163169
curr_world_names = set(curr_worlds.keys())
164-
new_world_names = set(new_worlds.keys())
165170

166171
deploy_worlds = new_world_names - curr_world_names
167172

infscale/execution/config_manager.py

Lines changed: 106 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -30,88 +30,126 @@ class ConfigManager:
3030
def __init__(self):
3131
"""Initialize config manager instance."""
3232
self._loop = asyncio.get_event_loop()
33-
self._task: asyncio.Task | None = None
34-
self._event = asyncio.Event()
35-
self._spec: ServeConfig = None
36-
self._event.set()
37-
self._curr_worlds_to_configure: set[str] = set()
38-
self._cancel_cur_cfg = False
39-
self._world_infos: dict[str, WorldInfo] = {}
40-
41-
def handle_new_spec(self, spec: ServeConfig) -> None:
33+
# semaphore event for back-to-back configs
34+
self._config_event = asyncio.Event()
35+
self._config_event.set()
36+
self._world_tasks: dict[str, asyncio.Task] = {}
37+
self._new_spec: ServeConfig = None
38+
self._curr_spec: ServeConfig = None
39+
self._curr_world_infos: dict[str, WorldInfo] = {}
40+
self._new_world_infos: dict[str, WorldInfo] = {}
41+
self.worlds_to_cancel = set()
42+
43+
async def handle_new_spec(self, spec: ServeConfig) -> None:
4244
"""Handle new spec."""
43-
self._cancel_cur_cfg = self._should_cancel_current(spec)
44-
self._spec = spec
45+
new_worlds_to_configure = ServeConfig.get_worlds_to_configure(
46+
self._curr_spec, spec
47+
)
4548

46-
def _should_cancel_current(self, spec: ServeConfig) -> bool:
47-
"""Decide if current configuration should be cancelled."""
48-
if self._spec is None:
49-
return False
49+
# on the first run, both new and cur will be empty sets
50+
new = self._new_world_infos.keys()
51+
cur = self._curr_world_infos.keys()
52+
curr_worlds_to_configure = new - cur
5053

51-
new_worlds_to_configure = ServeConfig.get_worlds_to_configure(self._spec, spec)
54+
self.worlds_to_cancel = new_worlds_to_configure & curr_worlds_to_configure
5255

53-
# cancel if the new config affects worlds currently being configured
54-
# TODO: if there's a overlap between new worlds and curr worlds we cancel
55-
# current configuration. This needs to be fixed, to cancel only the worlds that
56-
# are affected (eg new_worlds & curr_worlds)
57-
return not new_worlds_to_configure.isdisjoint(self._curr_worlds_to_configure)
56+
self._curr_world_infos = (
57+
self._build_world_infos(self._curr_spec) if self._curr_spec else {}
58+
)
59+
self._new_world_infos = self._build_world_infos(spec)
60+
self._new_spec = spec
5861

59-
def set_worlds_to_configure(self, world_names: set[str]) -> None:
60-
"""Set the world names currently being configured."""
61-
self._curr_worlds_to_configure = world_names
62+
if len(self.worlds_to_cancel):
63+
await self._cancel_world_configuration(self.worlds_to_cancel)
6264

63-
def set_world_infos(self, worlds: list[WorldInfo]) -> None:
64-
"""Set new world infos."""
65-
for world_info in worlds:
66-
self._world_infos[world_info.name] = world_info
65+
# wait for current configuration to finish
66+
await self._config_event.wait()
6767

68-
def get_world_infos(self) -> dict[str, WorldInfo]:
69-
"Get world infos."
70-
return self._world_infos
68+
# do cleanup after each successful configuration
69+
self._curr_spec = self._new_spec
70+
self._new_spec = None
71+
self.worlds_to_cancel = set()
72+
73+
# block handling new spec after doing cleanup for the current one
74+
self._config_event.clear()
75+
76+
def unblock_next_config(self) -> None:
77+
"""Set task event and unblock next config process."""
78+
self._config_event.set()
79+
80+
def update_world_infos(self, worlds_names: set[str]) -> None:
81+
"""Update world infos."""
82+
for world_name in worlds_names:
83+
world_info = self._new_world_infos[world_name]
84+
self._curr_world_infos[world_info.name] = world_info
85+
86+
def get_curr_world_infos(self) -> dict[str, WorldInfo]:
87+
"Get current world infos."
88+
return self._curr_world_infos
7189

7290
def is_first_run(self) -> bool:
7391
"Return boolean if is first run or not."
74-
return not self._world_infos
92+
return not self._curr_world_infos
7593

7694
def remove_world_info(self, world_name: str) -> None:
7795
"""Remove world info by name."""
78-
del self._world_infos[world_name]
96+
del self._curr_world_infos[world_name]
7997

80-
def get_worlds_to_add_and_remove(self) -> tuple[list[WorldInfo], list[WorldInfo]]:
98+
def get_worlds_to_add_and_remove(self) -> tuple[set[str], set[str]]:
8199
"""Return the set of world names to add and the set of world names to remove."""
82-
new_world_infos = self._build_world_infos()
83-
84-
new = new_world_infos.keys()
85-
cur = self._world_infos.keys()
100+
new = self._new_world_infos.keys()
101+
cur = self._curr_world_infos.keys()
86102

87-
worlds_to_add = [new_world_infos[name] for name in new - cur]
88-
worlds_to_remove = [new_world_infos[name] for name in cur - new]
103+
worlds_to_add = new - cur
104+
worlds_to_remove = cur - new
89105

90106
return worlds_to_add, worlds_to_remove
91107

92-
async def schedule(self, coro_factory: Callable[[], Awaitable[None]]):
93-
"""Cancel any in-progress configure and schedule a new one."""
94-
# wait for current to finish if we do not want to cancel
95-
if not self._cancel_cur_cfg:
96-
await self._event.wait()
97-
98-
# cancel current if running
99-
if self._task and not self._task.done():
100-
self._task.cancel()
101-
try:
102-
await self._task
103-
except asyncio.CancelledError:
104-
pass
105-
106-
# block again for new run
107-
self._event.clear()
108-
self._task = self._loop.create_task(self._run(coro_factory))
109-
110-
def _build_world_infos(self) -> dict[str, WorldInfo]:
108+
def get_new_world_info(self, world_name) -> WorldInfo:
109+
"""Return new world info based on world name."""
110+
return self._new_world_infos[world_name]
111+
112+
def get_world_infos_by_name(
113+
self, world_names: set[str], to_remove: bool = False
114+
) -> list[WorldInfo]:
115+
"""Return a list of world infos for the given world names (new infos by default, current infos if to_remove is set)."""
116+
world_infos = []
117+
118+
target = self._new_world_infos
119+
120+
if to_remove:
121+
target = self._curr_world_infos
122+
123+
for world_name in world_names:
124+
world_infos.append(target[world_name])
125+
126+
return world_infos
127+
128+
async def _cancel_world_configuration(self, world_names: set[str]):
129+
"""Cancel only worlds that are impacted by new spec."""
130+
coroutines = [self._cancel_world(w) for w in world_names]
131+
await asyncio.gather(*coroutines, return_exceptions=True)
132+
133+
def schedule_world_cfg(
134+
self, world_info: WorldInfo, coro_factory: Callable[[], Awaitable[None]]
135+
):
136+
"""Schedule configuration for a single world."""
137+
task = self._loop.create_task(self._run_world(world_info, coro_factory))
138+
self._world_tasks[world_info.name] = task
139+
return task
140+
141+
async def _cancel_world(self, world_name: str):
142+
"""Cancel an in-progress world config task."""
143+
task = self._world_tasks.pop(world_name, None)
144+
if task and not task.done():
145+
task.cancel()
146+
raise asyncio.CancelledError
147+
148+
def _build_world_infos(self, spec: ServeConfig) -> dict[str, WorldInfo]:
111149
world_infos: dict[str, WorldInfo] = {}
112150

113-
my_id = self._spec.stage.id
114-
for k, v in self._spec.flow_graph.items():
151+
my_id = spec.stage.id
152+
for k, v in spec.flow_graph.items():
115153
for cfg_world_info in v:
116154
# NOTE: no. of peers is always 1 for now
117155
assert len(cfg_world_info.peers) == 1
@@ -161,13 +199,13 @@ def _build_world_infos(self) -> dict[str, WorldInfo]:
161199

162200
return world_infos
163201

164-
async def _run(self, coro_factory: Callable[[], Awaitable[None]]):
165-
"""Run coroutine factory."""
202+
async def _run_world(
203+
self, world_info: WorldInfo, coro_factory: Callable[[], Awaitable[None]]
204+
):
205+
"""Run and cleanup world configuration."""
166206
try:
167-
await coro_factory()
207+
await coro_factory(world_info)
168208
except asyncio.CancelledError:
169-
pass
209+
raise
170210
finally:
171-
# reset class attributes and events
172-
self._event.set()
173-
self._curr_worlds_to_configure = set()
211+
self._world_tasks.pop(world_info.name, None)

infscale/execution/control.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,31 @@ async def setup(self) -> None:
162162
if self.rank == 0:
163163
self._server_task = asyncio.create_task(self._setup_server(setup_done))
164164
else:
165-
_ = asyncio.create_task(self._setup_client(setup_done))
165+
client_task = asyncio.create_task(self._setup_client(setup_done))
166166

167167
# wait until setting up either server or client is done
168-
await setup_done.wait()
168+
try:
169+
await setup_done.wait()
170+
except asyncio.CancelledError as e:
171+
# since both _setup_server and _setup_client are spawned as separate tasks
172+
# and the setup itself is a task, we need to handle parent task cancellation
173+
# on the awaited line, since cancellation only propagates through awaited calls
174+
# here, await setup_done.wait() is the propagation point from parent task to child tasks
175+
# so we need to cancel child tasks whenever CancelledError is received
176+
if self._server_task and not self._server_task.done():
177+
self._server_task.cancel()
178+
try:
179+
await self._server_task
180+
except asyncio.CancelledError:
181+
pass
182+
183+
if client_task and not client_task.done():
184+
client_task.cancel()
185+
try:
186+
await client_task
187+
except asyncio.CancelledError:
188+
pass
189+
raise
169190

170191
def cleanup(self) -> None:
171192
if self._server_task is not None:

0 commit comments

Comments
 (0)