2727from infscale import get_logger
2828from infscale .common .job_msg import Message , MessageType , WorkerStatus
2929from infscale .configs .job import ServeConfig
30- from infscale .execution .config_runner import ConfigRunner
30+ from infscale .execution .config_runner import ConfigManager
3131from infscale .execution .control import Channel as CtrlCh
3232from infscale .execution .metrics_collector import MetricsCollector
3333from infscale .execution .router import Router
@@ -68,13 +68,12 @@ def __init__(
6868 self .wcomm = wcomm
6969 self .spec : ServeConfig = None
7070 self .device = None
71- self .world_infos : dict [str , WorldInfo ] = {}
7271 self .cfg_event = asyncio .Event ()
7372 self ._micro_batch_size = 1
7473 self ._initialized = False
7574 self ._inspector = PipelineInspector ()
7675 self ._status : WorkerStatus = WorkerStatus .READY
77- self .config_runner = ConfigRunner ()
76+ self .config_manager = ConfigManager ()
7877
7978 # TODO: these variables are only for a server (i.e., dispatcher)
8079 # need to consider refactoring pipeline such that server code
@@ -118,7 +117,9 @@ def _set_worker_status(self, status: WorkerStatus) -> None:
118117 """Set worker status in pipeline and channel."""
119118 self ._status = status
120119
121- for world_info in self .world_infos .values ():
120+ world_infos = self .config_manager .get_world_infos ()
121+
122+ for world_info in world_infos .values ():
122123 world_info .channel .set_worker_status (status )
123124
124125 def _set_n_send_worker_status (self , status : WorkerStatus ) -> None :
@@ -143,49 +144,45 @@ def _reset_control_channel(self, world_info: WorldInfo) -> None:
143144
144145 async def _cleanup_recovered_worlds (self ) -> None :
145146 """Clean up world infos for recovered worlds."""
147+ world_infos = self .config_manager .get_world_infos ()
148+
146149 # if I'm the recovered worker, return
147- if len (self . world_infos ) == 0 :
150+ if len (world_infos ) == 0 :
148151 return
149152
150153 recover_worlds = [
151154 world_info
152155 for world_list in self .spec .flow_graph .values ()
153156 for world_info in world_list
154- if world_info .recover and world_info .name in self . world_infos
157+ if world_info .recover and world_info .name in world_infos
155158 ]
156159
157160 # no worlds to recover
158161 if len (recover_worlds ) == 0 :
159162 return
160163
161164 for world_info in recover_worlds :
162- wi = self . world_infos .get (world_info .name , None )
165+ wi = world_infos .get (world_info .name , None )
163166
164167 await self .router .cleanup_world (wi )
165168 self ._reset_control_channel (wi )
166169 self ._reset_multiworld (wi )
167170
168- del self .world_infos [ wi .name ]
171+ self .config_manager . remove_world_info ( wi .name )
169172
170173 async def _configure (self ) -> None :
171174 """(Re)configure multiworld, control channel and router."""
172175 await self ._cleanup_recovered_worlds ()
173176
174- is_first_run = not self .world_infos
177+ is_first_run = self .config_manager . is_first_run ()
175178
176179 if not is_first_run :
177180 self ._set_worker_status (WorkerStatus .UPDATING )
178181
179- new_world_infos = self ._build_world_infos ()
180- new = new_world_infos .keys ()
181- cur = self .world_infos .keys ()
182-
183- worlds_to_add = [new_world_infos [name ] for name in new - cur ]
184- worlds_to_remove = [self .world_infos [name ] for name in cur - new ]
185-
186- self .config_runner .set_worlds_to_configure (new - cur )
182+ worlds_to_add , worlds_to_remove = (
183+ self .config_manager .get_worlds_to_add_and_remove ()
184+ )
187185
188- # handle new worlds
189186 tasks = []
190187 # 1. set up control channel
191188 for world_info in worlds_to_add :
@@ -207,12 +204,14 @@ async def _configure(self) -> None:
207204 await asyncio .gather (* tasks )
208205
209206 # update world_info for added worlds
210- for world_info in worlds_to_add :
211- self .world_infos [world_info .name ] = world_info
207+ self .config_manager .set_world_infos (worlds_to_add )
212208
213209 # configure router with worlds to add and remove
214210 await self .router .configure (
215- self .spec , self .device , worlds_to_add , worlds_to_remove
211+ self .spec ,
212+ self .device ,
213+ worlds_to_add ,
214+ worlds_to_remove ,
216215 )
217216
218217 # handle unnecessary world
@@ -223,7 +222,7 @@ async def _configure(self) -> None:
223222 # 2. remove unnecessary world from multiworld
224223 self ._reset_multiworld (world_info )
225224
226- del self .world_infos [ world_info .name ]
225+ self .config_manager . remove_world_info ( world_info .name )
227226
228227 worker_status = WorkerStatus .RUNNING if is_first_run else WorkerStatus .UPDATED
229228
@@ -426,7 +425,7 @@ async def _handle_config(self, spec: ServeConfig) -> None:
426425 if spec is None :
427426 return
428427
429- self .config_runner .handle_new_spec (spec )
428+ self .config_manager .handle_new_spec (spec )
430429
431430 self ._configure_variables (spec )
432431
@@ -435,61 +434,7 @@ async def _handle_config(self, spec: ServeConfig) -> None:
435434 self ._initialize_once ()
436435
437436 # (re)configure the pipeline
438- await self .config_runner .schedule (self ._configure )
439-
440- def _build_world_infos (self ) -> dict [str , WorldInfo ]:
441- world_infos : dict [str , WorldInfo ] = {}
442-
443- my_id = self .spec .stage .id
444- for k , v in self .spec .flow_graph .items ():
445- for cfg_world_info in v :
446- # NOTE: no. of peers is always 1 for now
447- assert len (cfg_world_info .peers ) == 1
448-
449- if my_id == k :
450- my_rank = 0
451- other_rank = 1
452- other_id = cfg_world_info .peers [0 ]
453- elif my_id in cfg_world_info .peers :
454- # NOTE: this is always 1 for now
455- my_rank = cfg_world_info .peers .index (my_id ) + 1
456- other_rank = 0
457- other_id = k
458- else :
459- continue
460-
461- name , backend , addr , data_port , ctrl_port , recover , conflict_count = (
462- cfg_world_info .name ,
463- cfg_world_info .backend ,
464- cfg_world_info .addr ,
465- cfg_world_info .data_port ,
466- cfg_world_info .ctrl_port ,
467- cfg_world_info .recover ,
468- cfg_world_info .conflict_count ,
469- )
470-
471- world_size = len (cfg_world_info .peers ) + 1
472- ctrl_ch = CtrlCh (my_rank , world_size , addr , ctrl_port )
473-
474- data = {
475- "name" : name ,
476- "size" : world_size ,
477- "addr" : addr ,
478- "port" : data_port ,
479- "backend" : backend ,
480- "channel" : ctrl_ch ,
481- "my_id" : my_id ,
482- "me" : my_rank ,
483- "other_id" : other_id ,
484- "other" : other_rank ,
485- "recover" : recover ,
486- "conflict_count" : conflict_count ,
487- "multiworld_name" : f"{ name } -{ conflict_count } " ,
488- }
489- world_info = WorldInfo (** data )
490- world_infos [name ] = world_info
491-
492- return world_infos
437+ await self .config_manager .schedule (self ._configure )
493438
494439 def _configure_variables (self , spec : ServeConfig ) -> None :
495440 """Set variables that need to be updated."""
0 commit comments