3232from pathlib import Path
3333
3434import torch
35+ from safetensors .torch import load_file , save_file
3536from torch .distributed import ProcessGroup , all_reduce
3637
3738from vllm .config import ModelConfig , ParallelConfig
@@ -149,7 +150,7 @@ def save_eplb_state(
149150 tensors [f"global_expert_load_window_{ name } " ] = global_expert_load_window
150151 try :
151152 file_path = f"{ save_dir } /global_expert_load_window_i{ dir_iter } .safetensors" # noqa: E501
152- torch . save (tensors , file_path )
153+ save_file (tensors , file_path )
153154 logger .info ("Successfully saved to %s." , file_path )
154155 except Exception as e :
155156 logger .error ("An error occurred while saving the tensor: %s." , e )
@@ -158,7 +159,7 @@ def save_eplb_state(
158159def load_eplb_state (
159160 eplb_load_path : Path , model_states : list [EplbModelState ]
160161) -> list [torch .Tensor ]:
161- loaded_tensors = torch . load (eplb_load_path )
162+ loaded_tensors = load_file (eplb_load_path )
162163 global_load_windows = []
163164 for eplb_model_state in model_states :
164165 name = type (eplb_model_state .model ).__name__
@@ -527,13 +528,14 @@ def step(
527528 # performing collective communication.
528529 self .expert_rearrangement_step += 1
529530 if self .expert_rearrangement_step >= self .expert_rearrangement_step_interval :
530- self .rearrange ()
531531 self .expert_rearrangement_step = 0
532+ self .rearrange ()
532533
533534 def rearrange (
534535 self ,
535536 is_profile : bool = False ,
536537 execute_shuffle : bool = True ,
538+ load_initial_load_window : bool = False ,
537539 global_expert_loads : list [torch .Tensor ] | None = None ,
538540 rank_mapping : dict [int , int ] | None = None ,
539541 ) -> torch .Tensor | None :
@@ -564,10 +566,7 @@ def rearrange(
564566 time_start = time .perf_counter ()
565567 logger .info ("Rearranging experts %s..." , "(profile)" if is_profile else "" )
566568
567- if (
568- self .parallel_config .eplb_config .load_path is not None
569- and self .expert_rearrangement_step == 0
570- ):
569+ if load_initial_load_window :
571570 global_expert_load_windows = load_eplb_state (
572571 self .parallel_config .eplb_config .load_path ,
573572 list (self .model_states .values ()),
@@ -578,8 +577,7 @@ def rearrange(
578577 global_expert_load_windows = []
579578 should_save_eplb_state = (
580579 self .parallel_config .eplb_config .save_load_window
581- and not is_profile
582- and self .expert_rearrangement_step > 0
580+ and not is_profile and not load_initial_load_window
583581 )
584582 if not execute_shuffle :
585583 num_models = torch .tensor (
0 commit comments