@@ -30,38 +30,43 @@ class ConfigManager:
3030 def __init__ (self ):
3131 """Initialize config manager instance."""
3232 self ._loop = asyncio .get_event_loop ()
33- self ._task : asyncio .Task | None = None
34- self ._event = asyncio .Event ()
33+ # event gating back-to-back configs: set = next config may start, cleared = a config is in progress
34+ self ._config_event = asyncio .Event ()
35+ self ._config_event .set ()
36+ self ._world_tasks : dict [str , asyncio .Task ] = {}
3537 self ._spec : ServeConfig = None
36- self ._event .set ()
3738 self ._curr_worlds_to_configure : set [str ] = set ()
38- self ._cancel_cur_cfg = False
3939 self ._world_infos : dict [str , WorldInfo ] = {}
40+ self .worlds_to_cancel = set ()
4041
41- def handle_new_spec (self , spec : ServeConfig ) -> None :
42+ async def handle_new_spec (self , spec : ServeConfig ) -> None :
4243 """Handle new spec."""
43- self ._cancel_cur_cfg = self ._should_cancel_current (spec )
44+ new_worlds_to_configure = ServeConfig .get_worlds_to_configure (self ._spec , spec )
45+ self .worlds_to_cancel = new_worlds_to_configure & self ._curr_worlds_to_configure
4446 self ._spec = spec
4547
46- def _should_cancel_current (self , spec : ServeConfig ) -> bool :
47- """Decide if current configuration should be cancelled."""
48- if self ._spec is None :
49- return False
48+ if len (self .worlds_to_cancel ):
49+ await self ._cancel_world_configuration (self .worlds_to_cancel )
5050
51- new_worlds_to_configure = ServeConfig .get_worlds_to_configure (self ._spec , spec )
51+ # wait for current configuration to finish
52+ await self ._config_event .wait ()
53+
54+ # mark a configuration as in progress: clearing the event makes any
55+ # later handle_new_spec() call wait at the line above until
56+ # unblock_next_config() sets the event again.
57+ self ._config_event .clear ()
5258
53- # cancel if the new config affects worlds currently being configured
54- # TODO: if there's a overlap between new worlds and curr worlds we cancel
55- # current configuration. This needs to be fixed, to cancel only the worlds that
56- # are affected (eg new_worlds & curr_worlds)
57- return not new_worlds_to_configure .isdisjoint (self ._curr_worlds_to_configure )
59+ def reset_state (self ) -> None :
60+ """Reset any state that is kept between configs."""
61+ self ._curr_worlds_to_configure = set ()
62+ self .worlds_to_cancel = set ()
5863
59- def set_worlds_to_configure (self , world_names : set [ str ] ) -> None :
60- """Set the world names currently being configured ."""
61- self ._curr_worlds_to_configure = world_names
64+ def unblock_next_config (self ) -> None :
65+ """Set the config event and unblock the next config process."""
66+ self ._config_event . set ()
6267
63- def set_world_infos (self , worlds : list [WorldInfo ]) -> None :
64- """Set new world infos."""
68+ def update_world_infos (self , worlds : list [WorldInfo ]) -> None :
69+ """Update world infos."""
6570 for world_info in worlds :
6671 self ._world_infos [world_info .name ] = world_info
6772
@@ -87,25 +92,29 @@ def get_worlds_to_add_and_remove(self) -> tuple[list[WorldInfo], list[WorldInfo]
8792 worlds_to_add = [new_world_infos [name ] for name in new - cur ]
8893 worlds_to_remove = [new_world_infos [name ] for name in cur - new ]
8994
90- return worlds_to_add , worlds_to_remove
91-
92- async def schedule (self , coro_factory : Callable [[], Awaitable [None ]]):
93- """Cancel any in-progress configure and schedule a new one."""
94- # wait for current to finish if we do not want to cancel
95- if not self ._cancel_cur_cfg :
96- await self ._event .wait ()
95+ self ._curr_worlds_to_configure = new - cur
9796
98- # cancel current if running
99- if self ._task and not self ._task .done ():
100- self ._task .cancel ()
101- try :
102- await self ._task
103- except asyncio .CancelledError :
104- pass
97+ return worlds_to_add , worlds_to_remove
10598
106- # block again for new run
107- self ._event .clear ()
108- self ._task = self ._loop .create_task (self ._run (coro_factory ))
99+ async def _cancel_world_configuration (self , world_names : set [str ]):
100+ """Cancel only worlds that are impacted by new spec."""
101+ coroutines = [self ._cancel_world (w ) for w in world_names ]
102+ await asyncio .gather (* coroutines , return_exceptions = True )
103+
104+ def schedule_world_cfg (
105+ self , world_info : WorldInfo , coro_factory : Callable [[], Awaitable [None ]]
106+ ):
107+ """Schedule configuration for a single world."""
108+ task = self ._loop .create_task (self ._run_world (world_info , coro_factory ))
109+ self ._world_tasks [world_info .name ] = task
110+ return task
111+
112+ async def _cancel_world (self , world_name : str ):
113+ """Cancel an in-progress world config task."""
114+ task = self ._world_tasks .pop (world_name , None )
115+ if task and not task .done ():
116+ task .cancel ()
117+ raise asyncio .CancelledError
109118
110119 def _build_world_infos (self ) -> dict [str , WorldInfo ]:
111120 world_infos : dict [str , WorldInfo ] = {}
@@ -161,13 +170,13 @@ def _build_world_infos(self) -> dict[str, WorldInfo]:
161170
162171 return world_infos
163172
164- async def _run (self , coro_factory : Callable [[], Awaitable [None ]]):
165- """Run coroutine factory."""
173+ async def _run_world (
174+ self , world_info : WorldInfo , coro_factory : Callable [[], Awaitable [None ]]
175+ ):
176+ """Run and cleanup world configuration."""
166177 try :
167- await coro_factory ()
178+ await coro_factory (world_info )
168179 except asyncio .CancelledError :
169- pass
180+ raise
170181 finally :
171- # reset class attributes and events
172- self ._event .set ()
173- self ._curr_worlds_to_configure = set ()
182+ self ._world_tasks .pop (world_info .name , None )
0 commit comments