Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions infscale/configs/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,16 @@ def get_worlds_to_configure(
"""Compare two specs and return new and updated worlds."""
helper = ServeConfigHelper()

curr_worlds = helper._get_worlds(curr_spec)
new_worlds = helper._get_worlds(new_spec)
new_world_names = set(new_worlds.keys())

# if current spec is not available,
# return worlds from the new spec.
if curr_spec is None:
return new_world_names

curr_worlds = helper._get_worlds(curr_spec)
curr_world_names = set(curr_worlds.keys())
new_world_names = set(new_worlds.keys())

deploy_worlds = new_world_names - curr_world_names

Expand Down
174 changes: 106 additions & 68 deletions infscale/execution/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,88 +30,126 @@ class ConfigManager:
def __init__(self):
"""Initialize config manager instance."""
self._loop = asyncio.get_event_loop()
self._task: asyncio.Task | None = None
self._event = asyncio.Event()
self._spec: ServeConfig = None
self._event.set()
self._curr_worlds_to_configure: set[str] = set()
self._cancel_cur_cfg = False
self._world_infos: dict[str, WorldInfo] = {}

def handle_new_spec(self, spec: ServeConfig) -> None:
# semaphore event for back-to-back configs
self._config_event = asyncio.Event()
self._config_event.set()
self._world_tasks: dict[str, asyncio.Task] = {}
self._new_spec: ServeConfig = None
self._curr_spec: ServeConfig = None
self._curr_world_infos: dict[str, WorldInfo] = {}
self._new_world_infos: dict[str, WorldInfo] = {}
self.worlds_to_cancel = set()

async def handle_new_spec(self, spec: ServeConfig) -> None:
"""Handle new spec."""
self._cancel_cur_cfg = self._should_cancel_current(spec)
self._spec = spec
new_worlds_to_configure = ServeConfig.get_worlds_to_configure(
self._curr_spec, spec
)

def _should_cancel_current(self, spec: ServeConfig) -> bool:
"""Decide if current configuration should be cancelled."""
if self._spec is None:
return False
# on the first run, both new and cur will be empty sets
new = self._new_world_infos.keys()
cur = self._curr_world_infos.keys()
curr_worlds_to_configure = new - cur

new_worlds_to_configure = ServeConfig.get_worlds_to_configure(self._spec, spec)
self.worlds_to_cancel = new_worlds_to_configure & curr_worlds_to_configure

# cancel if the new config affects worlds currently being configured
# TODO: if there's a overlap between new worlds and curr worlds we cancel
# current configuration. This needs to be fixed, to cancel only the worlds that
# are affected (eg new_worlds & curr_worlds)
return not new_worlds_to_configure.isdisjoint(self._curr_worlds_to_configure)
self._curr_world_infos = (
self._build_world_infos(self._curr_spec) if self._curr_spec else {}
)
self._new_world_infos = self._build_world_infos(spec)
self._new_spec = spec

def set_worlds_to_configure(self, world_names: set[str]) -> None:
"""Set the world names currently being configured."""
self._curr_worlds_to_configure = world_names
if len(self.worlds_to_cancel):
await self._cancel_world_configuration(self.worlds_to_cancel)

def set_world_infos(self, worlds: list[WorldInfo]) -> None:
"""Set new world infos."""
for world_info in worlds:
self._world_infos[world_info.name] = world_info
# wait for current configuration to finish
await self._config_event.wait()

def get_world_infos(self) -> dict[str, WorldInfo]:
"Get world infos."
return self._world_infos
# do cleanup after each successful configuration
self._curr_spec = self._new_spec
self._new_spec = None
self.worlds_to_cancel = set()

# block handling new spec after doing cleanup for the current one
self._config_event.clear()

def unblock_next_config(self) -> None:
    """Mark the current configuration as finished.

    Sets the config event that ``handle_new_spec`` waits on, which lets the
    handling of the next spec proceed.
    """
    self._config_event.set()

def update_world_infos(self, worlds_names: set[str]) -> None:
    """Copy the named worlds from the new world infos into the current ones.

    Each entry is stored in the current world infos under the ``name``
    attribute of the world info itself. A name missing from the new world
    infos raises ``KeyError``.
    """
    fresh = self._new_world_infos
    current = self._curr_world_infos
    for info in (fresh[name] for name in worlds_names):
        current[info.name] = info

def get_curr_world_infos(self) -> dict[str, WorldInfo]:
    """Return the mapping of current world infos, keyed by world name.

    The internal dict itself is returned, not a copy.
    """
    return self._curr_world_infos

def is_first_run(self) -> bool:
"Return boolean if is first run or not."
return not self._world_infos
return not self._curr_world_infos

def remove_world_info(self, world_name: str) -> None:
"""Remove world info by name."""
del self._world_infos[world_name]
del self._curr_world_infos[world_name]

def get_worlds_to_add_and_remove(self) -> tuple[list[WorldInfo], list[WorldInfo]]:
def get_worlds_to_add_and_remove(self) -> tuple[set[str], set[str]]:
"""Return a list of world infos to add and to remove."""
new_world_infos = self._build_world_infos()

new = new_world_infos.keys()
cur = self._world_infos.keys()
new = self._new_world_infos.keys()
cur = self._curr_world_infos.keys()

worlds_to_add = [new_world_infos[name] for name in new - cur]
worlds_to_remove = [new_world_infos[name] for name in cur - new]
worlds_to_add = new - cur
worlds_to_remove = cur - new

return worlds_to_add, worlds_to_remove

async def schedule(self, coro_factory: Callable[[], Awaitable[None]]):
"""Cancel any in-progress configure and schedule a new one."""
# wait for current to finish if we do not want to cancel
if not self._cancel_cur_cfg:
await self._event.wait()

# cancel current if running
if self._task and not self._task.done():
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass

# block again for new run
self._event.clear()
self._task = self._loop.create_task(self._run(coro_factory))

def _build_world_infos(self) -> dict[str, WorldInfo]:
def get_new_world_info(self, world_name) -> WorldInfo:
    """Look up a single world info among the new world infos.

    Raises ``KeyError`` when ``world_name`` is not part of the new world infos.
    """
    return self._new_world_infos[world_name]

def get_world_infos_by_name(
    self, world_names: set[str], to_remove: bool = False
) -> list[WorldInfo]:
    """Return the world infos that belong to the given world names.

    Names are resolved against the current world infos when ``to_remove`` is
    true, otherwise against the new world infos. An unknown name raises
    ``KeyError``.
    """
    source = self._curr_world_infos if to_remove else self._new_world_infos
    return [source[name] for name in world_names]

async def _cancel_world_configuration(self, world_names: set[str]):
    """Cancel the configuration of every world named in ``world_names``.

    The per-world cancellations run concurrently; exceptions raised by any of
    them are collected by ``gather`` instead of being propagated.
    """
    await asyncio.gather(
        *(self._cancel_world(name) for name in world_names),
        return_exceptions=True,
    )

def schedule_world_cfg(
    self,
    world_info: WorldInfo,
    coro_factory: Callable[[WorldInfo], Awaitable[None]],
) -> asyncio.Task:
    """Schedule configuration for a single world.

    Args:
        world_info: world to configure; its ``name`` keys the registry of
            running world tasks.
        coro_factory: callable that is invoked with ``world_info`` and
            returns the awaitable performing the configuration.

    Returns:
        The task that runs the world's configuration.
    """
    task = self._loop.create_task(self._run_world(world_info, coro_factory))
    # remember the task under the world name so it can be cancelled per world
    self._world_tasks[world_info.name] = task
    return task

async def _cancel_world(self, world_name: str) -> None:
    """Cancel the in-progress configuration task of one world, if any.

    The task is removed from the registry of running world tasks. When it is
    still running it is cancelled and then awaited, so its cleanup has
    finished by the time this coroutine returns. Unknown worlds and tasks
    that are already done are a no-op.

    Any exception other than ``CancelledError`` raised by the world task
    while it unwinds propagates to the caller.
    """
    task = self._world_tasks.pop(world_name, None)
    if task is None or task.done():
        return

    task.cancel()
    try:
        # wait until the cancellation has actually been delivered; raising
        # CancelledError from here instead would make the caller look
        # cancelled while the world task may still be unwinding
        await task
    except asyncio.CancelledError:
        pass

def _build_world_infos(self, spec: ServeConfig) -> dict[str, WorldInfo]:
world_infos: dict[str, WorldInfo] = {}

my_id = self._spec.stage.id
for k, v in self._spec.flow_graph.items():
my_id = spec.stage.id
for k, v in spec.flow_graph.items():
for cfg_world_info in v:
# NOTE: no. of peers is always 1 for now
assert len(cfg_world_info.peers) == 1
Expand Down Expand Up @@ -161,13 +199,13 @@ def _build_world_infos(self) -> dict[str, WorldInfo]:

return world_infos

async def _run(self, coro_factory: Callable[[], Awaitable[None]]):
"""Run coroutine factory."""
async def _run_world(
self, world_info: WorldInfo, coro_factory: Callable[[], Awaitable[None]]
):
"""Run and cleanup world configuration."""
try:
await coro_factory()
await coro_factory(world_info)
except asyncio.CancelledError:
pass
raise
finally:
# reset class attributes and events
self._event.set()
self._curr_worlds_to_configure = set()
self._world_tasks.pop(world_info.name, None)
25 changes: 23 additions & 2 deletions infscale/execution/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,31 @@ async def setup(self) -> None:
if self.rank == 0:
self._server_task = asyncio.create_task(self._setup_server(setup_done))
else:
_ = asyncio.create_task(self._setup_client(setup_done))
client_task = asyncio.create_task(self._setup_client(setup_done))

# wait until setting up either server or client is done
await setup_done.wait()
try:
await setup_done.wait()
except asyncio.CancelledError as e:
# since both _setup_server and _setup_client are spawned as separate tasks
# and the setup itself is a task, we need to handle parent task cancellation
# on the awaited line, since cancellation only propagates through awaited calls
# here, await setup_done.wait() is the propagation point from parent task to child tasks
# so we need to cancel child tasks whenever CancelledError is received
if self._server_task and not self._server_task.done():
self._server_task.cancel()
try:
await self._server_task
except asyncio.CancelledError:
pass

if client_task and not client_task.done():
client_task.cancel()
try:
await client_task
except asyncio.CancelledError:
pass
raise

def cleanup(self) -> None:
if self._server_task is not None:
Expand Down
Loading