Commit ed322fc

more fixes
Signed-off-by: Patryk Saffer <[email protected]>
1 parent ecec01d commit ed322fc

File tree: 3 files changed, +22 -12 lines changed

vllm/config/parallel.py

Lines changed: 13 additions & 1 deletion
@@ -620,5 +620,17 @@ def _verify_args(self) -> Self:
             raise ValueError(
                 "Unable to use nsight profiling unless workers run with Ray."
             )
-
+
+        if self.eplb_config.load_initial_load_window and self.eplb_config.load_path is None:
+            raise ValueError(
+                "load_initial_load_window is set to True, but load_path is not provided."
+            )
+        if self.eplb_config.save_load_window and self.eplb_config.save_dir is None:
+            raise ValueError(
+                "save_load_window is set to True, but save_dir is not provided."
+            )
+        if self.eplb_config.save_load_window and self.eplb_config.static:
+            raise ValueError(
+                "save_load_window is set to True, but static is set to True."
+            )
         return self
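
To make the new constraints easy to try in isolation, here is a minimal standalone sketch of the three rules this hunk enforces. EplbSettings and verify are hypothetical stand-ins for vLLM's actual EplbConfig and _verify_args; only the field names and error messages are taken from the diff above.

    from dataclasses import dataclass


    @dataclass
    class EplbSettings:
        # Field names mirror the eplb_config attributes checked above.
        load_initial_load_window: bool = False
        load_path: str | None = None
        save_load_window: bool = False
        save_dir: str | None = None
        static: bool = False


    def verify(cfg: EplbSettings) -> None:
        # Loading an initial load window requires a path to read it from.
        if cfg.load_initial_load_window and cfg.load_path is None:
            raise ValueError("load_initial_load_window is set to True, "
                             "but load_path is not provided.")
        # Saving load windows requires a directory to write them into.
        if cfg.save_load_window and cfg.save_dir is None:
            raise ValueError("save_load_window is set to True, "
                             "but save_dir is not provided.")
        # A static expert placement never rearranges, so there is no
        # load window to save.
        if cfg.save_load_window and cfg.static:
            raise ValueError("save_load_window is set to True, "
                             "but static is set to True.")


    verify(EplbSettings(save_load_window=True, save_dir="/tmp/eplb"))  # passes
    try:
        verify(EplbSettings(save_load_window=True, static=True))
    except ValueError as e:
        print(e)  # save_load_window is set to True, but static is set to True.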

vllm/distributed/eplb/eplb_state.py

Lines changed: 7 additions & 9 deletions
@@ -32,6 +32,7 @@
 from pathlib import Path
 
 import torch
+from safetensors.torch import load_file, save_file
 from torch.distributed import ProcessGroup, all_reduce
 
 from vllm.config import ModelConfig, ParallelConfig
@@ -149,7 +150,7 @@ def save_eplb_state(
         tensors[f"global_expert_load_window_{name}"] = global_expert_load_window
     try:
         file_path = f"{save_dir}/global_expert_load_window_i{dir_iter}.safetensors"  # noqa: E501
-        torch.save(tensors, file_path)
+        save_file(tensors, file_path)
         logger.info("Successfully saved to %s.", file_path)
     except Exception as e:
         logger.error("An error occurred while saving the tensor: %s.", e)
@@ -158,7 +159,7 @@
 def load_eplb_state(
     eplb_load_path: Path, model_states: list[EplbModelState]
 ) -> list[torch.Tensor]:
-    loaded_tensors = torch.load(eplb_load_path)
+    loaded_tensors = load_file(eplb_load_path)
     global_load_windows = []
     for eplb_model_state in model_states:
         name = type(eplb_model_state.model).__name__
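
The switch from torch.save/torch.load to safetensors is the substance of these two hunks. A quick round-trip for reference; the key and path are illustrative, with "MyMoEModel" a made-up name standing in for type(model).__name__ from the loop above:

    import torch
    from safetensors.torch import load_file, save_file

    # safetensors stores a flat dict of named tensors. Unlike torch.load,
    # load_file never unpickles arbitrary objects, so reading an EPLB
    # state file cannot execute code.
    tensors = {"global_expert_load_window_MyMoEModel": torch.zeros(4, 8)}
    save_file(tensors, "/tmp/global_expert_load_window_i0.safetensors")

    loaded = load_file("/tmp/global_expert_load_window_i0.safetensors")
    assert torch.equal(loaded["global_expert_load_window_MyMoEModel"],
                       tensors["global_expert_load_window_MyMoEModel"])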
@@ -527,13 +528,14 @@ def step(
         # performing collective communication.
         self.expert_rearrangement_step += 1
         if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
-            self.rearrange()
             self.expert_rearrangement_step = 0
+            self.rearrange()
 
     def rearrange(
         self,
         is_profile: bool = False,
         execute_shuffle: bool = True,
+        load_initial_load_window: bool = False,
         global_expert_loads: list[torch.Tensor] | None = None,
         rank_mapping: dict[int, int] | None = None,
     ) -> torch.Tensor | None:
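
Note the reordering in step(): the counter is now zeroed before rearrange() runs, so code inside rearrange() can no longer use expert_rearrangement_step == 0 to recognize the initial, load-seeded rearrangement; the two hunks below replace that test with the explicit load_initial_load_window flag. A toy version of the pattern, with hypothetical names:

    class RearrangeTrigger:
        """Toy stand-in for the periodic-trigger logic in step()."""

        def __init__(self, interval: int) -> None:
            self.interval = interval
            self.step_count = 0

        def tick(self) -> None:
            self.step_count += 1
            if self.step_count >= self.interval:
                self.step_count = 0   # reset first...
                self.rearrange()      # ...so rearrange() always sees 0

        def rearrange(self, load_initial_load_window: bool = False) -> None:
            # With the counter already reset, an explicit flag (rather than
            # step_count == 0) must mark the initial rearrangement.
            print("rearranging; initial =", load_initial_load_window)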
@@ -564,10 +566,7 @@ def rearrange(
         time_start = time.perf_counter()
         logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
 
-        if (
-            self.parallel_config.eplb_config.load_path is not None
-            and self.expert_rearrangement_step == 0
-        ):
+        if load_initial_load_window:
             global_expert_load_windows = load_eplb_state(
                 self.parallel_config.eplb_config.load_path,
                 list(self.model_states.values()),
@@ -578,8 +577,7 @@
             global_expert_load_windows = []
         should_save_eplb_state = (
             self.parallel_config.eplb_config.save_load_window
-            and not is_profile
-            and self.expert_rearrangement_step > 0
+            and not is_profile and not load_initial_load_window
         )
         if not execute_shuffle:
             num_models = torch.tensor(

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 2 deletions
@@ -3160,8 +3160,8 @@ def load_model(self, eep_scale_up: bool = False) -> None:
                 rank_mapping,
             )
             if self.parallel_config.eplb_config.load_path is not None:
-                self.eplb_state.rearrange(self.model)
-            if self.parallel_config.eplb_config.save_dir is None:
+                self.eplb_state.rearrange(self.model, load_initial_load_window=True)
+            if self.parallel_config.eplb_config.static:
                 self.eplb_state = None
 
             if (
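
Taken together, the runner-side change gives the startup flow below. A hedged sketch: init_eplb and cfg are illustrative names, and only load_path, static, and the rearrange call mirror the diff.

    def init_eplb(eplb_state, model, cfg):
        # Seed the expert placement from a previously saved load window.
        if cfg.load_path is not None:
            eplb_state.rearrange(model, load_initial_load_window=True)
        # Static placement: drop the EPLB state entirely so no periodic
        # rearrangement happens after the optional initial one (before this
        # commit, the state was dropped whenever save_dir was unset).
        if cfg.static:
            eplb_state = None
        return eplb_state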
