 logger = init_logger(__name__)


-def save_eplb_state(tensor: torch.Tensor, save_dir: Path, dir_iter: int) -> None:
-    try:
-        file_path = f"{save_dir}/global_expert_load_window_i{dir_iter}.safetensors"  # noqa: E501
-        torch.save(
-            {
-                "global_expert_load_window": tensor,
-            },
-            file_path,
-        )
-        logger.info("Successfully saved to %s.", file_path)
-    except Exception as e:
-        logger.error("An error occurred while saving the tensor: %s.", e)
-
-
-def load_eplb_state(eplb_load_path: Path) -> torch.Tensor:
-    loaded_tensors = torch.load(eplb_load_path)
-    logger.info("Successfully loaded %s.", eplb_load_path)
-    return loaded_tensors["global_expert_load_window"]
-
-
 @dataclass
 class EplbModelState:
     """EPLB metrics."""
@@ -151,21 +131,44 @@ class EplbModelState:
     See:
     https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
     """
-    load_path: Path | None = None
-    """
-    Path for loading eplb initial state.
-    """
-    save_dir: Path | None = None
-    """
-    Path where eplb states will be saved.
-    """
-    expert_load_window_step: int = 0
-    """
-    Current step in the sliding window.
     model_name: str
     model: MixtureOfExperts


+def save_eplb_state(
+    global_expert_load_windows: list[torch.Tensor],
+    save_dir: Path,
+    dir_iter: int,
+    model_states: list[EplbModelState],
+) -> None:
+    tensors = {}
+    for eplb_model_state, global_expert_load_window in zip(
+        model_states, global_expert_load_windows
+    ):
+        name = type(eplb_model_state.model).__name__
+        tensors[f"global_expert_load_window_{name}"] = global_expert_load_window
+    try:
+        file_path = f"{save_dir}/global_expert_load_window_i{dir_iter}.safetensors"  # noqa: E501
+        torch.save(tensors, file_path)
+        logger.info("Successfully saved to %s.", file_path)
+    except Exception as e:
+        logger.error("An error occurred while saving the tensor: %s.", e)
+
+
+def load_eplb_state(
+    eplb_load_path: Path, model_states: list[EplbModelState]
+) -> list[torch.Tensor]:
+    loaded_tensors = torch.load(eplb_load_path)
+    global_load_windows = []
+    for eplb_model_state in model_states:
+        name = type(eplb_model_state.model).__name__
+        tensor = loaded_tensors[f"global_expert_load_window_{name}"]
+        tensor = tensor.to(eplb_model_state.expert_load_window.device)
+        global_load_windows.append(tensor)
+    logger.info("Successfully loaded %s.", eplb_load_path)
+    return global_load_windows
+
+
 class EplbState:
     """
     EplbState of each expert parallel model. Key is the model config hash.
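Side note on the keying scheme introduced above: each model's load window is stored under a key derived from the model's class name, so several MoE models can share one snapshot file, and the whole dict goes through torch.save / torch.load (despite the .safetensors suffix, the file is a regular PyTorch pickle, not safetensors format). A minimal round-trip sketch; the class name, path, and tensor shape below are illustrative placeholders, not values from this PR:

import torch

name = "DeepseekV2ForCausalLM"  # stands in for type(eplb_model_state.model).__name__
window = torch.zeros(100, 64, dtype=torch.int64)  # placeholder shape for a load window
tensors = {f"global_expert_load_window_{name}": window}

torch.save(tensors, "/tmp/global_expert_load_window_i0.safetensors")
loaded = torch.load("/tmp/global_expert_load_window_i0.safetensors")
assert torch.equal(loaded[f"global_expert_load_window_{name}"], window)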
@@ -177,21 +180,6 @@ def __init__(self, parallel_config: ParallelConfig, device: torch.device):
         self.model_states: dict[str, EplbModelState] = {}
         """
         Current step in the sliding window.
-
-        NOTE: Keep in mind that all EP ranks need to have the same
-        `expert_rearrangement_step` value to ensure synchronization.
-        Otherwise, the rearrangement will hang at collective
-        communication calls.
-        """
-        expert_rearrangement_step_interval: int = 0
-        """
-        Interval for expert rearrangement steps.
-        This is a constant and is taken from the config.
-        """
-        save_dir_iter = 0
-        """
-        Saving directory iteration, used to save the expert load window.
-        """
         Different from `expert_rearrangement_step`,
         each EP rank may have its own `expert_load_window_step`.
         """
@@ -216,6 +204,7 @@ def __init__(self, parallel_config: ParallelConfig, device: torch.device):
         This is a constant and is taken from the config.
         """
         self.expert_rearrangement_step_interval: int = 0
+        self.save_dir_iter = 0

     @staticmethod
     def build_initial_global_physical_to_logical_map(
@@ -363,22 +352,16 @@ def add_model(
             device=self.device,
         )

-        eplb_load_path = parallel_config.eplb_config.load_path
-        eplb_save_dir = parallel_config.eplb_config.save_dir
-
-        eplb_step_interval = parallel_config.eplb_config.step_interval
+        eplb_load_path = self.parallel_config.eplb_config.load_path
+        eplb_save_dir = self.parallel_config.eplb_config.save_dir
+        eplb_step_interval = self.parallel_config.eplb_config.step_interval
         if eplb_load_path is not None or eplb_save_dir is not None:
-            expert_rearrangement_step = 0
+            self.expert_rearrangement_step = 0
         else:
             # Set the initial progress of rearrangement to 3/4
-            expert_rearrangement_step = max(
+            self.expert_rearrangement_step = max(
                 0, eplb_step_interval - eplb_step_interval // 4
             )
-        # Set the initial progress of rearrangement to 3/4
-        eplb_step_interval = self.parallel_config.eplb_config.step_interval
-        self.expert_rearrangement_step = max(
-            0, eplb_step_interval - eplb_step_interval // 4
-        )
         self.expert_rearrangement_step_interval = eplb_step_interval

         if global_expert_load is not None:
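The branch above controls how far along the rearrangement counter starts: with save_dir or load_path configured it starts at zero (a full interval must elapse before the first rearrangement), otherwise it starts at 3/4 of the interval so the first rearrangement comes after roughly a quarter of it. A quick check of the arithmetic, assuming a hypothetical step_interval of 3000:

eplb_step_interval = 3000  # hypothetical value, not taken from this PR
expert_rearrangement_step = max(0, eplb_step_interval - eplb_step_interval // 4)
assert expert_rearrangement_step == 2250
# steps remaining until the first rearrangement is triggered
assert eplb_step_interval - expert_rearrangement_step == 750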
@@ -448,11 +431,6 @@ def add_model(
             logical_replica_count,
             expert_load_pass,
             expert_load_window,
-            eplb_load_path,
-            eplb_save_dir,
-            expert_load_window_size=expert_load_window_size,
-            expert_rearrangement_step=expert_rearrangement_step,
-            expert_rearrangement_step_interval=eplb_step_interval,
             model_config.model,
             model,
         )
@@ -524,7 +502,7 @@ def step(

         if ep_group.rank() == 0:
             logger.info(
-                "EPLBS step: %d for model %s: avg_tokens=%.2f, "
+                "EPLB step: %d for model %s: avg_tokens=%.2f, "
                 "max_tokens=%d, balancedness=%.4f",
                 self.expert_rearrangement_step,
                 eplb_model_state.model_name,
@@ -551,10 +529,8 @@ def step(
         # performing collective communication.
         self.expert_rearrangement_step += 1
         if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
-            self.rearrange(model)
-            self.expert_rearrangement_step = 0
-            self.expert_rearrangement_step = 0
             self.rearrange()
+            self.expert_rearrangement_step = 0

     def rearrange(
         self,
@@ -590,14 +566,23 @@ def rearrange(
         time_start = time.perf_counter()
         logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")

-        if self.load_path is not None and self.expert_rearrangement_step == 0:
-            global_expert_load_window = load_eplb_state(self.load_path).to(
-                self.physical_to_logical_map.device
+        if (
+            self.parallel_config.eplb_config.load_path is not None
+            and self.expert_rearrangement_step == 0
+        ):
+            global_expert_load_windows = load_eplb_state(
+                self.parallel_config.eplb_config.load_path,
+                list(self.model_states.values()),
             )
-        elif global_expert_load is None:
-            if global_expert_loads is None:
+            assert global_expert_load_windows is not None
+        elif global_expert_loads is None:
             # Map the physical expert load to global logical experts
             global_expert_load_windows = []
+            should_save_eplb_state = (
+                self.parallel_config.eplb_config.save_dir is not None
+                and not is_profile
+                and self.expert_rearrangement_step > 0
+            )
             if not execute_shuffle:
                 num_models = torch.tensor(
                     [len(self.model_states)], dtype=torch.int32, device="cpu"
@@ -622,17 +607,6 @@ def rearrange(
                         src=eplb_model_state.expert_load_window,
                     )

-            if (
-                is_main_rank
-                and self.save_dir is not None
-                and not is_profile
-                and self.expert_rearrangement_step > 0
-            ):
-                save_eplb_state(
-                    global_expert_load_window, self.save_dir, self.save_dir_iter
-                )
-                self.save_dir_iter += 1
-
             if not execute_shuffle:
                 metadata = torch.tensor(
                     [
@@ -649,6 +623,16 @@ def rearrange(

                 global_expert_load_window = logical_expert_load_window.sum(dim=0)
                 global_expert_load_windows.append(global_expert_load_window)
+
+            if is_main_rank and should_save_eplb_state:
+                save_eplb_state(
+                    global_expert_load_windows,
+                    self.parallel_config.eplb_config.save_dir,
+                    self.save_dir_iter,
+                    list(self.model_states.values()),
+                )
+                self.save_dir_iter += 1
+
             # Perform all-reduce to get the expert load across all ranks for each model
             global_expert_load_windows = self._allreduce_list(
                 global_expert_load_windows
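The hunk above relies on the should_save_eplb_state flag computed earlier in rearrange(): a snapshot is written only by the main rank, only when a save directory is configured, never during the profiling pass, and never while the rearrangement counter is still at zero. A self-contained restatement of that gate (the function name and the sample values are illustrative; the conditions mirror the diff):

def should_save(is_main_rank: bool, save_dir, is_profile: bool, step: int) -> bool:
    # mirrors `is_main_rank and should_save_eplb_state` from the changes above
    return is_main_rank and save_dir is not None and not is_profile and step > 0

assert should_save(True, "/tmp/eplb", False, 10) is True
assert should_save(True, None, False, 10) is False        # no save_dir configured
assert should_save(True, "/tmp/eplb", True, 10) is False  # profiling pass
assert should_save(True, "/tmp/eplb", False, 0) is False  # counter still at zero
assert should_save(False, "/tmp/eplb", False, 10) is False  # only the main rank writes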