@@ -170,13 +170,16 @@ def get_worlds_to_configure(
170170 updated_worlds = {
171171 k
172172 for k in common_keys
173- if (curr_worlds [k ].addr != new_worlds [k ].addr
173+ if (
174+ curr_worlds [k ].addr != new_worlds [k ].addr
174175 or curr_worlds [k ].data_port != new_worlds [k ].data_port
175- or curr_worlds [k ].ctrl_port != new_worlds [k ].ctrl_port )
176+ or curr_worlds [k ].ctrl_port != new_worlds [k ].ctrl_port
177+ )
176178 }
177179
178180 return deploy_worlds | updated_worlds
179-
181+
182+
180183class ServeConfigHelper :
181184 """Class for defining helper methods for serve config."""
182185
@@ -399,7 +402,7 @@ def is_identical(x: JobConfig, y: JobConfig) -> bool:
399402 def world_name (world_id : int ) -> str :
400403 """Return world name given a world id."""
401404 return f"w{ world_id } "
402-
405+
403406 @staticmethod
404407 def get_pipeline_identifiers (new_cfg : JobConfig ) -> set [str ]:
405408 """Get pipeline identifiers based on server id."""
@@ -450,30 +453,38 @@ def categorize_workers(
450453 # select workers that will be affected by workers to be started
451454 for w , world_info_list in new_config .flow_graph .items ():
452455 for new_world_info in world_info_list :
453- curr_world_info = helper .find_matching_world_info (curr_config , w , new_world_info )
454- helper .pick_workers (update_wrkrs , start_wrkrs , w , new_world_info , curr_world_info )
456+ curr_world_info = helper .find_matching_world_info (
457+ curr_config , w , new_world_info
458+ )
459+ helper .pick_workers (
460+ update_wrkrs , start_wrkrs , w , new_world_info , curr_world_info
461+ )
455462
456463 if curr_config is None :
457464 return start_wrkrs , update_wrkrs , stop_wrkrs
458465
459466 # select workers that will be affected by workers to be stopped
460467 for w , world_info_list in curr_config .flow_graph .items ():
461468 for new_world_info in world_info_list :
462- curr_world_info = helper .find_matching_world_info (curr_config , w , new_world_info )
463- helper .pick_workers (update_wrkrs , stop_wrkrs , w , new_world_info , curr_world_info )
469+ curr_world_info = helper .find_matching_world_info (
470+ curr_config , w , new_world_info
471+ )
472+ helper .pick_workers (
473+ update_wrkrs , stop_wrkrs , w , new_world_info , curr_world_info
474+ )
464475
465476 # due to pervious state, recover workers are included in update workers
466477 # therefore, recover workers need to be removed from the updated ones.
467478 update_wrkrs -= recover_wrkrs
468479
469480 return start_wrkrs , update_wrkrs , stop_wrkrs
470-
481+
471482 @staticmethod
472483 def get_workers_diff (a : JobConfig , b : JobConfig ) -> set [str ]:
473484 """Return a set of worker ids diffs based on old and new cfg."""
474485 old_workers = {worker .id for worker in a .workers }
475486 new_workers = {worker .id for worker in b .workers }
476-
487+
477488 return old_workers - new_workers
478489
479490 @staticmethod
@@ -509,10 +520,10 @@ def remove_pipeline(config: JobConfig, workers_to_remove: set[str]) -> JobConfig
509520
510521class JobConfigHelper :
511522 """Class for defining helper methods for job config."""
523+
512524 def get_server_id (self , config : JobConfig ) -> str :
513525 return next ((w .id for w in config .workers if w .is_server ), "" )
514-
515-
526+
516527 def find_pipeline_nodes (
517528 self ,
518529 flow_graph : dict [str , list [WorldInfo ]],
@@ -577,9 +588,7 @@ def find_pipeline_nodes(
577588
578589 # everything else (except server) is removed
579590 to_remove = {
580- wid
581- for wid in flow_graph
582- if wid != server_id and wid not in survivors
591+ wid for wid in flow_graph if wid != server_id and wid not in survivors
583592 }
584593
585594 return to_remove
@@ -600,11 +609,13 @@ def pick_workers(
600609
601610 The needles are workers to start or stop and the haystack is
602611 name and peers.
603-
612+
604613 Also includes peers of `name` if its connection details
605614 (`addr`, `ctrl_port`, `data_port`) differ from the previous config.
606615 """
607- if curr_world_info and self .has_connection_changed (curr_world_info , new_world_info ):
616+ if curr_world_info and self .has_connection_changed (
617+ curr_world_info , new_world_info
618+ ):
608619 for peer in new_world_info .peers :
609620 res_set .add (peer )
610621
@@ -634,15 +645,15 @@ def pick_workers(
634645 # because name is already affected by one peer
635646 # so we come out of the for-loop
636647 break
637-
648+
638649 def has_connection_changed (self , old : WorldInfo , new : WorldInfo ) -> bool :
639650 """Check if worker connection details are changed."""
640651 return (
641- old .addr != new .addr or
642- old .ctrl_port != new .ctrl_port or
643- old .data_port != new .data_port
652+ old .addr != new .addr
653+ or old .ctrl_port != new .ctrl_port
654+ or old .data_port != new .data_port
644655 )
645-
656+
646657 def find_matching_world_info (
647658 self , curr_config : JobConfig | None , w : str , new_world_info : WorldInfo
648659 ) -> WorldInfo | None :
0 commit comments