diff --git a/infscale/configs/job.py b/infscale/configs/job.py index d7d033e..bb0cc8d 100644 --- a/infscale/configs/job.py +++ b/infscale/configs/job.py @@ -157,11 +157,16 @@ def get_worlds_to_configure( """Compare two specs and return new and updated worlds.""" helper = ServeConfigHelper() - curr_worlds = helper._get_worlds(curr_spec) new_worlds = helper._get_worlds(new_spec) + new_world_names = set(new_worlds.keys()) + + # if current spec is not available, + # return worlds from the new spec. + if curr_spec is None: + return new_world_names + curr_worlds = helper._get_worlds(curr_spec) curr_world_names = set(curr_worlds.keys()) - new_world_names = set(new_worlds.keys()) deploy_worlds = new_world_names - curr_world_names diff --git a/infscale/execution/config_manager.py b/infscale/execution/config_manager.py index db6e805..4fd9c06 100644 --- a/infscale/execution/config_manager.py +++ b/infscale/execution/config_manager.py @@ -30,88 +30,126 @@ class ConfigManager: def __init__(self): """Initialize config manager instance.""" self._loop = asyncio.get_event_loop() - self._task: asyncio.Task | None = None - self._event = asyncio.Event() - self._spec: ServeConfig = None - self._event.set() - self._curr_worlds_to_configure: set[str] = set() - self._cancel_cur_cfg = False - self._world_infos: dict[str, WorldInfo] = {} - - def handle_new_spec(self, spec: ServeConfig) -> None: + # semaphore event for back-to-back configs + self._config_event = asyncio.Event() + self._config_event.set() + self._world_tasks: dict[str, asyncio.Task] = {} + self._new_spec: ServeConfig = None + self._curr_spec: ServeConfig = None + self._curr_world_infos: dict[str, WorldInfo] = {} + self._new_world_infos: dict[str, WorldInfo] = {} + self.worlds_to_cancel = set() + + async def handle_new_spec(self, spec: ServeConfig) -> None: """Handle new spec.""" - self._cancel_cur_cfg = self._should_cancel_current(spec) - self._spec = spec + new_worlds_to_configure = ServeConfig.get_worlds_to_configure( + self._curr_spec, spec + ) - def _should_cancel_current(self, spec: ServeConfig) -> bool: - """Decide if current configuration should be cancelled.""" - if self._spec is None: - return False + # on the first run, both new and cur will be empty sets + new = self._new_world_infos.keys() + cur = self._curr_world_infos.keys() + curr_worlds_to_configure = new - cur - new_worlds_to_configure = ServeConfig.get_worlds_to_configure(self._spec, spec) + self.worlds_to_cancel = new_worlds_to_configure & curr_worlds_to_configure - # cancel if the new config affects worlds currently being configured - # TODO: if there's a overlap between new worlds and curr worlds we cancel - # current configuration. This needs to be fixed, to cancel only the worlds that - # are affected (eg new_worlds & curr_worlds) - return not new_worlds_to_configure.isdisjoint(self._curr_worlds_to_configure) + self._curr_world_infos = ( + self._build_world_infos(self._curr_spec) if self._curr_spec else {} + ) + self._new_world_infos = self._build_world_infos(spec) + self._new_spec = spec - def set_worlds_to_configure(self, world_names: set[str]) -> None: - """Set the world names currently being configured.""" - self._curr_worlds_to_configure = world_names + if len(self.worlds_to_cancel): + await self._cancel_world_configuration(self.worlds_to_cancel) - def set_world_infos(self, worlds: list[WorldInfo]) -> None: - """Set new world infos.""" - for world_info in worlds: - self._world_infos[world_info.name] = world_info + # wait for current configuration to finish + await self._config_event.wait() - def get_world_infos(self) -> dict[str, WorldInfo]: - "Get world infos." - return self._world_infos + # do cleanup after each successful configuration + self._curr_spec = self._new_spec + self._new_spec = None + self.worlds_to_cancel = set() + + # block handling new spec after doing cleanup for the current one + self._config_event.clear() + + def unblock_next_config(self) -> None: + """Set task event and unblock next config process.""" + self._config_event.set() + + def update_world_infos(self, worlds_names: set[str]) -> None: + """Update world infos.""" + for world_name in worlds_names: + world_info = self._new_world_infos[world_name] + self._curr_world_infos[world_info.name] = world_info + + def get_curr_world_infos(self) -> dict[str, WorldInfo]: + "Get current world infos." + return self._curr_world_infos def is_first_run(self) -> bool: "Return boolean if is first run or not." - return not self._world_infos + return not self._curr_world_infos def remove_world_info(self, world_name: str) -> None: """Remove world info by name.""" - del self._world_infos[world_name] + del self._curr_world_infos[world_name] - def get_worlds_to_add_and_remove(self) -> tuple[list[WorldInfo], list[WorldInfo]]: + def get_worlds_to_add_and_remove(self) -> tuple[set[str], set[str]]: """Return a list of world infos to add and to remove.""" - new_world_infos = self._build_world_infos() - - new = new_world_infos.keys() - cur = self._world_infos.keys() + new = self._new_world_infos.keys() + cur = self._curr_world_infos.keys() - worlds_to_add = [new_world_infos[name] for name in new - cur] - worlds_to_remove = [new_world_infos[name] for name in cur - new] + worlds_to_add = new - cur + worlds_to_remove = cur - new return worlds_to_add, worlds_to_remove - async def schedule(self, coro_factory: Callable[[], Awaitable[None]]): - """Cancel any in-progress configure and schedule a new one.""" - # wait for current to finish if we do not want to cancel - if not self._cancel_cur_cfg: - await self._event.wait() - - # cancel current if running - if self._task and not self._task.done(): - self._task.cancel() - try: - await self._task - except asyncio.CancelledError: - pass - - # block again for new run - self._event.clear() - self._task = self._loop.create_task(self._run(coro_factory)) - - def _build_world_infos(self) -> dict[str, WorldInfo]: + def get_new_world_info(self, world_name) -> WorldInfo: + """Return new world info based on world name.""" + return self._new_world_infos[world_name] + + def get_world_infos_by_name( + self, world_names: set[str], to_remove: bool = False + ) -> list[WorldInfo]: + """Return a list of world infos to add based on world names.""" + world_infos = [] + + target = self._new_world_infos + + if to_remove: + target = self._curr_world_infos + + for world_name in world_names: + world_infos.append(target[world_name]) + + return world_infos + + async def _cancel_world_configuration(self, world_names: set[str]): + """Cancel only worlds that are impacted by new spec.""" + coroutines = [self._cancel_world(w) for w in world_names] + await asyncio.gather(*coroutines, return_exceptions=True) + + def schedule_world_cfg( + self, world_info: WorldInfo, coro_factory: Callable[[], Awaitable[None]] + ): + """Schedule configuration for a single world.""" + task = self._loop.create_task(self._run_world(world_info, coro_factory)) + self._world_tasks[world_info.name] = task + return task + + async def _cancel_world(self, world_name: str): + """Cancel an in-progress world config task.""" + task = self._world_tasks.pop(world_name, None) + if task and not task.done(): + task.cancel() + raise asyncio.CancelledError + + def _build_world_infos(self, spec: ServeConfig) -> dict[str, WorldInfo]: world_infos: dict[str, WorldInfo] = {} - my_id = self._spec.stage.id - for k, v in self._spec.flow_graph.items(): + my_id = spec.stage.id + for k, v in spec.flow_graph.items(): for cfg_world_info in v: # NOTE: no. of peers is always 1 for now assert len(cfg_world_info.peers) == 1 @@ -161,13 +199,13 @@ def _build_world_infos(self) -> dict[str, WorldInfo]: return world_infos - async def _run(self, coro_factory: Callable[[], Awaitable[None]]): - """Run coroutine factory.""" + async def _run_world( + self, world_info: WorldInfo, coro_factory: Callable[[], Awaitable[None]] + ): + """Run and cleanup world configuration.""" try: - await coro_factory() + await coro_factory(world_info) except asyncio.CancelledError: - pass + raise finally: - # reset class attributes and events - self._event.set() - self._curr_worlds_to_configure = set() + self._world_tasks.pop(world_info.name, None) diff --git a/infscale/execution/control.py b/infscale/execution/control.py index 1c769f1..938520b 100644 --- a/infscale/execution/control.py +++ b/infscale/execution/control.py @@ -162,10 +162,31 @@ async def setup(self) -> None: if self.rank == 0: self._server_task = asyncio.create_task(self._setup_server(setup_done)) else: - _ = asyncio.create_task(self._setup_client(setup_done)) + client_task = asyncio.create_task(self._setup_client(setup_done)) # wait until setting up either server or client is done - await setup_done.wait() + try: + await setup_done.wait() + except asyncio.CancelledError as e: + # since both _setup_server and _setup_client are spawned as separate tasks + # and the setup itself is a task, we need to handle parent task cancellation + # on the awaited line, since cancellation only propagates through awaited calls + # here, await setup_done.wait() is the propagation point from parent task to child tasks + # so we need to cancel child tasks whenever CancelledError is received + if self._server_task and not self._server_task.done(): + self._server_task.cancel() + try: + await self._server_task + except asyncio.CancelledError: + pass + + if client_task and not client_task.done(): + client_task.cancel() + try: + await client_task + except asyncio.CancelledError: + pass + raise def cleanup(self) -> None: if self._server_task is not None: diff --git a/infscale/execution/pipeline.py b/infscale/execution/pipeline.py index 66663eb..2a30662 100644 --- a/infscale/execution/pipeline.py +++ b/infscale/execution/pipeline.py @@ -104,6 +104,8 @@ async def _configure_multiworld(self, world_info: WorldInfo) -> None: port=port, device=self.device, ) + except asyncio.CancelledError: + logger.warning(f"multiworld configuration cancelled for {world_info.name}") except Exception as e: logger.error(f"failed to initialize a multiworld {name}: {e}") condition = self._status != WorkerStatus.UPDATING @@ -117,7 +119,7 @@ def _set_worker_status(self, status: WorkerStatus) -> None: """Set worker status in pipeline and channel.""" self._status = status - world_infos = self.config_manager.get_world_infos() + world_infos = self.config_manager.get_curr_world_infos() for world_info in world_infos.values(): world_info.channel.set_worker_status(status) @@ -130,12 +132,20 @@ def _set_n_send_worker_status(self, status: WorkerStatus) -> None: self.wcomm.send(msg) async def _configure_control_channel(self, world_info: WorldInfo) -> None: - await world_info.channel.setup() + try: + await world_info.channel.setup() - await world_info.channel.wait_readiness() + await world_info.channel.wait_readiness() + except asyncio.CancelledError: + logger.warning(f"channel configuration cancelled for {world_info}") def _reset_multiworld(self, world_info: WorldInfo) -> None: - self.world_manager.remove_world(world_info.multiworld_name) + try: + self.world_manager.remove_world(world_info.multiworld_name) + except ValueError as e: + logger.warning(f"failed to reset {world_info.multiworld_name}: {e}") + return + logger.info(f"remove world {world_info.multiworld_name} from multiworld") def _reset_control_channel(self, world_info: WorldInfo) -> None: @@ -144,7 +154,7 @@ def _reset_control_channel(self, world_info: WorldInfo) -> None: async def _cleanup_recovered_worlds(self) -> None: """Clean up world infos for recovered worlds.""" - world_infos = self.config_manager.get_world_infos() + world_infos = self.config_manager.get_curr_world_infos() # if I'm the recovered worker, return if len(world_infos) == 0: @@ -179,14 +189,18 @@ async def _configure(self) -> None: if not is_first_run: self._set_worker_status(WorkerStatus.UPDATING) - worlds_to_add, worlds_to_remove = ( + world_names_to_add, world_names_to_remove = ( self.config_manager.get_worlds_to_add_and_remove() ) tasks = [] # 1. set up control channel - for world_info in worlds_to_add: - task = self._configure_control_channel(world_info) + for world_name in world_names_to_add - self.config_manager.worlds_to_cancel: + world_info = self.config_manager.get_new_world_info(world_name) + + task = self.config_manager.schedule_world_cfg( + world_info, self._configure_control_channel + ) tasks.append(task) # TODO: this doesn't handle partial success @@ -195,8 +209,11 @@ async def _configure(self) -> None: tasks = [] # 2. set up multiworld - for world_info in worlds_to_add: - task = self._configure_multiworld(world_info) + for world_name in world_names_to_add - self.config_manager.worlds_to_cancel: + world_info = self.config_manager.get_new_world_info(world_name) + task = self.config_manager.schedule_world_cfg( + world_info, self._configure_multiworld + ) tasks.append(task) # TODO: this doesn't handle partial success @@ -204,7 +221,16 @@ async def _configure(self) -> None: await asyncio.gather(*tasks) # update world_info for added worlds - self.config_manager.set_world_infos(worlds_to_add) + self.config_manager.update_world_infos( + world_names_to_add - self.config_manager.worlds_to_cancel + ) + + worlds_to_add = self.config_manager.get_world_infos_by_name( + world_names_to_add - self.config_manager.worlds_to_cancel + ) + worlds_to_remove = self.config_manager.get_world_infos_by_name( + world_names_to_remove, True + ) # configure router with worlds to add and remove await self.router.configure( @@ -216,16 +242,17 @@ async def _configure(self) -> None: # handle unnecessary world # remove is executed in the reverse order of add - for world_info in worlds_to_remove: + for world_name in worlds_to_remove: # 1. remove unnecessary world from control channel - self._reset_control_channel(world_info) + self._reset_control_channel(world_name) # 2. remove unnecessary world from multiworld - self._reset_multiworld(world_info) + self._reset_multiworld(world_name) - self.config_manager.remove_world_info(world_info.name) + self.config_manager.remove_world_info(world_name.name) worker_status = WorkerStatus.RUNNING if is_first_run else WorkerStatus.UPDATED + self.config_manager.unblock_next_config() self._set_n_send_worker_status(worker_status) self.cfg_event.set() @@ -425,16 +452,16 @@ async def _handle_config(self, spec: ServeConfig) -> None: if spec is None: return - self.config_manager.handle_new_spec(spec) - self._configure_variables(spec) self._inspector.configure(self.spec) self._initialize_once() - # (re)configure the pipeline - await self.config_manager.schedule(self._configure) + await self.config_manager.handle_new_spec(spec) + # run configure as a separate task since we need to unblock receiving + # a new config to be processed when current configuration is finished + _ = asyncio.create_task(self._configure()) def _configure_variables(self, spec: ServeConfig) -> None: """Set variables that need to be updated."""