@@ -30,88 +30,126 @@ class ConfigManager:
3030 def __init__ (self ):
3131 """Initialize config manager instance."""
3232 self ._loop = asyncio .get_event_loop ()
33- self ._task : asyncio .Task | None = None
34- self ._event = asyncio .Event ()
35- self ._spec : ServeConfig = None
36- self ._event .set ()
37- self ._curr_worlds_to_configure : set [str ] = set ()
38- self ._cancel_cur_cfg = False
39- self ._world_infos : dict [str , WorldInfo ] = {}
40-
41- def handle_new_spec (self , spec : ServeConfig ) -> None :
33+ # semaphore event for back-to-back configs
34+ self ._config_event = asyncio .Event ()
35+ self ._config_event .set ()
36+ self ._world_tasks : dict [str , asyncio .Task ] = {}
37+ self ._new_spec : ServeConfig = None
38+ self ._curr_spec : ServeConfig = None
39+ self ._curr_world_infos : dict [str , WorldInfo ] = {}
40+ self ._new_world_infos : dict [str , WorldInfo ] = {}
41+ self .worlds_to_cancel = set ()
42+
43+ async def handle_new_spec (self , spec : ServeConfig ) -> None :
4244 """Handle new spec."""
43- self ._cancel_cur_cfg = self ._should_cancel_current (spec )
44- self ._spec = spec
45+ new_worlds_to_configure = ServeConfig .get_worlds_to_configure (
46+ self ._curr_spec , spec
47+ )
4548
46- def _should_cancel_current ( self , spec : ServeConfig ) -> bool :
47- """Decide if current configuration should be cancelled."""
48- if self ._spec is None :
49- return False
49+ # on the first run, both new and cur will be empty sets
50+ new = self . _new_world_infos . keys ()
51+ cur = self ._curr_world_infos . keys ()
52+ curr_worlds_to_configure = new - cur
5053
51- new_worlds_to_configure = ServeConfig . get_worlds_to_configure ( self . _spec , spec )
54+ self . worlds_to_cancel = new_worlds_to_configure & curr_worlds_to_configure
5255
53- # cancel if the new config affects worlds currently being configured
54- # TODO: if there's a overlap between new worlds and curr worlds we cancel
55- # current configuration. This needs to be fixed, to cancel only the worlds that
56- # are affected (eg new_worlds & curr_worlds )
57- return not new_worlds_to_configure . isdisjoint ( self ._curr_worlds_to_configure )
56+ self . _curr_world_infos = (
57+ self . _build_world_infos ( self . _curr_spec ) if self . _curr_spec else {}
58+ )
59+ self . _new_world_infos = self . _build_world_infos ( spec )
60+ self ._new_spec = spec
5861
59- def set_worlds_to_configure (self , world_names : set [str ]) -> None :
60- """Set the world names currently being configured."""
61- self ._curr_worlds_to_configure = world_names
62+ if len (self .worlds_to_cancel ):
63+ await self ._cancel_world_configuration (self .worlds_to_cancel )
6264
63- def set_world_infos (self , worlds : list [WorldInfo ]) -> None :
64- """Set new world infos."""
65- for world_info in worlds :
66- self ._world_infos [world_info .name ] = world_info
65+ # wait for current configuration to finish
66+ await self ._config_event .wait ()
6767
68- def get_world_infos (self ) -> dict [str , WorldInfo ]:
69- "Get world infos."
70- return self ._world_infos
68+ # do cleanup after each successful configuration
69+ self ._curr_spec = self ._new_spec
70+ self ._new_spec = None
71+ self .worlds_to_cancel = set ()
72+
73+ # block handling new spec after doing cleanup for the current one
74+ self ._config_event .clear ()
75+
76+ def unblock_next_config (self ) -> None :
77+ """Set task event and unblock next config process."""
78+ self ._config_event .set ()
79+
80+ def update_world_infos (self , worlds_names : set [str ]) -> None :
81+ """Update world infos."""
82+ for world_name in worlds_names :
83+ world_info = self ._new_world_infos [world_name ]
84+ self ._curr_world_infos [world_info .name ] = world_info
85+
86+ def get_curr_world_infos (self ) -> dict [str , WorldInfo ]:
87+ "Get current world infos."
88+ return self ._curr_world_infos
7189
7290 def is_first_run (self ) -> bool :
7391 "Return boolean if is first run or not."
74- return not self ._world_infos
92+ return not self ._curr_world_infos
7593
7694 def remove_world_info (self , world_name : str ) -> None :
7795 """Remove world info by name."""
78- del self ._world_infos [world_name ]
96+ del self ._curr_world_infos [world_name ]
7997
80- def get_worlds_to_add_and_remove (self ) -> tuple [list [ WorldInfo ], list [ WorldInfo ]]:
98+ def get_worlds_to_add_and_remove (self ) -> tuple [set [ str ], set [ str ]]:
8199 """Return a list of world infos to add and to remove."""
82- new_world_infos = self ._build_world_infos ()
83-
84- new = new_world_infos .keys ()
85- cur = self ._world_infos .keys ()
100+ new = self ._new_world_infos .keys ()
101+ cur = self ._curr_world_infos .keys ()
86102
87- worlds_to_add = [ new_world_infos [ name ] for name in new - cur ]
88- worlds_to_remove = [ new_world_infos [ name ] for name in cur - new ]
103+ worlds_to_add = new - cur
104+ worlds_to_remove = cur - new
89105
90106 return worlds_to_add , worlds_to_remove
91107
92- async def schedule (self , coro_factory : Callable [[], Awaitable [None ]]):
93- """Cancel any in-progress configure and schedule a new one."""
94- # wait for current to finish if we do not want to cancel
95- if not self ._cancel_cur_cfg :
96- await self ._event .wait ()
97-
98- # cancel current if running
99- if self ._task and not self ._task .done ():
100- self ._task .cancel ()
101- try :
102- await self ._task
103- except asyncio .CancelledError :
104- pass
105-
106- # block again for new run
107- self ._event .clear ()
108- self ._task = self ._loop .create_task (self ._run (coro_factory ))
109-
110- def _build_world_infos (self ) -> dict [str , WorldInfo ]:
108+ def get_new_world_info (self , world_name ) -> WorldInfo :
109+ """Return new world info based on world name."""
110+ return self ._new_world_infos [world_name ]
111+
112+ def get_world_infos_by_name (
113+ self , world_names : set [str ], to_remove : bool = False
114+ ) -> list [WorldInfo ]:
115+ """Return a list of world infos to add based on world names."""
116+ world_infos = []
117+
118+ target = self ._new_world_infos
119+
120+ if to_remove :
121+ target = self ._curr_world_infos
122+
123+ for world_name in world_names :
124+ world_infos .append (target [world_name ])
125+
126+ return world_infos
127+
128+ async def _cancel_world_configuration (self , world_names : set [str ]):
129+ """Cancel only worlds that are impacted by new spec."""
130+ coroutines = [self ._cancel_world (w ) for w in world_names ]
131+ await asyncio .gather (* coroutines , return_exceptions = True )
132+
133+ def schedule_world_cfg (
134+ self , world_info : WorldInfo , coro_factory : Callable [[], Awaitable [None ]]
135+ ):
136+ """Schedule configuration for a single world."""
137+ task = self ._loop .create_task (self ._run_world (world_info , coro_factory ))
138+ self ._world_tasks [world_info .name ] = task
139+ return task
140+
141+ async def _cancel_world (self , world_name : str ):
142+ """Cancel an in-progress world config task."""
143+ task = self ._world_tasks .pop (world_name , None )
144+ if task and not task .done ():
145+ task .cancel ()
146+ raise asyncio .CancelledError
147+
148+ def _build_world_infos (self , spec : ServeConfig ) -> dict [str , WorldInfo ]:
111149 world_infos : dict [str , WorldInfo ] = {}
112150
113- my_id = self . _spec .stage .id
114- for k , v in self . _spec .flow_graph .items ():
151+ my_id = spec .stage .id
152+ for k , v in spec .flow_graph .items ():
115153 for cfg_world_info in v :
116154 # NOTE: no. of peers is always 1 for now
117155 assert len (cfg_world_info .peers ) == 1
@@ -161,13 +199,13 @@ def _build_world_infos(self) -> dict[str, WorldInfo]:
161199
162200 return world_infos
163201
164- async def _run (self , coro_factory : Callable [[], Awaitable [None ]]):
165- """Run coroutine factory."""
202+ async def _run_world (
203+ self , world_info : WorldInfo , coro_factory : Callable [[], Awaitable [None ]]
204+ ):
205+ """Run and cleanup world configuration."""
166206 try :
167- await coro_factory ()
207+ await coro_factory (world_info )
168208 except asyncio .CancelledError :
169- pass
209+ raise
170210 finally :
171- # reset class attributes and events
172- self ._event .set ()
173- self ._curr_worlds_to_configure = set ()
211+ self ._world_tasks .pop (world_info .name , None )
0 commit comments