Skip to content

Commit 783ca91

Browse files
committed
offline eplb with eagle
Signed-off-by: Patryk Saffer <[email protected]>
1 parent 45c9ce4 commit 783ca91

File tree

2 files changed

+68
-82
lines changed

2 files changed

+68
-82
lines changed

vllm/distributed/eplb/eplb_state.py

Lines changed: 66 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -50,26 +50,6 @@
5050
logger = init_logger(__name__)
5151

5252

53-
def save_eplb_state(tensor: torch.Tensor, save_dir: Path, dir_iter: int) -> None:
54-
try:
55-
file_path = f"{save_dir}/global_expert_load_window_i{dir_iter}.safetensors" # noqa: E501
56-
torch.save(
57-
{
58-
"global_expert_load_window": tensor,
59-
},
60-
file_path,
61-
)
62-
logger.info("Successfully saved to %s.", file_path)
63-
except Exception as e:
64-
logger.error("An error occurred while saving the tensor: %s.", e)
65-
66-
67-
def load_eplb_state(eplb_load_path: Path) -> torch.Tensor:
68-
loaded_tensors = torch.load(eplb_load_path)
69-
logger.info("Successfully loaded %s.", eplb_load_path)
70-
return loaded_tensors["global_expert_load_window"]
71-
72-
7353
@dataclass
7454
class EplbModelState:
7555
"""EPLB metrics."""
@@ -151,21 +131,44 @@ class EplbModelState:
151131
See:
152132
https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
153133
"""
154-
load_path: Path | None = None
155-
"""
156-
Path for loading eplb initial state.
157-
"""
158-
save_dir: Path | None = None
159-
"""
160-
Path where eplb states will be saved.
161-
"""
162-
expert_load_window_step: int = 0
163-
"""
164-
Current step in the sliding window.
165134
model_name: str
166135
model: MixtureOfExperts
167136

168137

138+
def save_eplb_state(
139+
global_expert_load_windows: list[torch.Tensor],
140+
save_dir: Path,
141+
dir_iter: int,
142+
model_states: list[EplbModelState],
143+
) -> None:
144+
tensors = {}
145+
for eplb_model_state, global_expert_load_window in zip(
146+
model_states, global_expert_load_windows
147+
):
148+
name = type(eplb_model_state.model).__name__
149+
tensors[f"global_expert_load_window_{name}"] = global_expert_load_window
150+
try:
151+
file_path = f"{save_dir}/global_expert_load_window_i{dir_iter}.safetensors" # noqa: E501
152+
torch.save(tensors, file_path)
153+
logger.info("Successfully saved to %s.", file_path)
154+
except Exception as e:
155+
logger.error("An error occurred while saving the tensor: %s.", e)
156+
157+
158+
def load_eplb_state(
159+
eplb_load_path: Path, model_states: list[EplbModelState]
160+
) -> list[torch.Tensor]:
161+
loaded_tensors = torch.load(eplb_load_path)
162+
global_load_windows = []
163+
for eplb_model_state in model_states:
164+
name = type(eplb_model_state.model).__name__
165+
tensor = loaded_tensors[f"global_expert_load_window_{name}"]
166+
tensor = tensor.to(eplb_model_state.expert_load_window.device)
167+
global_load_windows.append(tensor)
168+
logger.info("Successfully loaded %s.", eplb_load_path)
169+
return global_load_windows
170+
171+
169172
class EplbState:
170173
"""
171174
EplbState of each expert parallel model. Key is the model config hash.
@@ -177,21 +180,6 @@ def __init__(self, parallel_config: ParallelConfig, device: torch.device):
177180
self.model_states: dict[str, EplbModelState] = {}
178181
"""
179182
Current step in the sliding window.
180-
181-
NOTE: Keep in mind that all EP ranks need to have the same
182-
`expert_rearrangement_step` value to ensure synchronization.
183-
Otherwise, the rearrangement will hang at collective
184-
communication calls.
185-
"""
186-
expert_rearrangement_step_interval: int = 0
187-
"""
188-
Interval for expert rearrangement steps.
189-
This is a constant and is taken from the config.
190-
"""
191-
save_dir_iter = 0
192-
"""
193-
Saving directory iteration, used to save the expert load window.
194-
"""
195183
Different from `expert_rearrangement_step`,
196184
each EP rank may have its own `expert_load_window_step`.
197185
"""
@@ -216,6 +204,7 @@ def __init__(self, parallel_config: ParallelConfig, device: torch.device):
216204
This is a constant and is taken from the config.
217205
"""
218206
self.expert_rearrangement_step_interval: int = 0
207+
self.save_dir_iter = 0
219208

220209
@staticmethod
221210
def build_initial_global_physical_to_logical_map(
@@ -363,22 +352,16 @@ def add_model(
363352
device=self.device,
364353
)
365354

366-
eplb_load_path = parallel_config.eplb_config.load_path
367-
eplb_save_dir = parallel_config.eplb_config.save_dir
368-
369-
eplb_step_interval = parallel_config.eplb_config.step_interval
355+
eplb_load_path = self.parallel_config.eplb_config.load_path
356+
eplb_save_dir = self.parallel_config.eplb_config.save_dir
357+
eplb_step_interval = self.parallel_config.eplb_config.step_interval
370358
if eplb_load_path is not None or eplb_save_dir is not None:
371-
expert_rearrangement_step = 0
359+
self.expert_rearrangement_step = 0
372360
else:
373361
# Set the initial progress of rearrangement to 3/4
374-
expert_rearrangement_step = max(
362+
self.expert_rearrangement_step = max(
375363
0, eplb_step_interval - eplb_step_interval // 4
376364
)
377-
# Set the initial progress of rearrangement to 3/4
378-
eplb_step_interval = self.parallel_config.eplb_config.step_interval
379-
self.expert_rearrangement_step = max(
380-
0, eplb_step_interval - eplb_step_interval // 4
381-
)
382365
self.expert_rearrangement_step_interval = eplb_step_interval
383366

384367
if global_expert_load is not None:
@@ -448,11 +431,6 @@ def add_model(
448431
logical_replica_count,
449432
expert_load_pass,
450433
expert_load_window,
451-
eplb_load_path,
452-
eplb_save_dir,
453-
expert_load_window_size=expert_load_window_size,
454-
expert_rearrangement_step=expert_rearrangement_step,
455-
expert_rearrangement_step_interval=eplb_step_interval,
456434
model_config.model,
457435
model,
458436
)
@@ -524,7 +502,7 @@ def step(
524502

525503
if ep_group.rank() == 0:
526504
logger.info(
527-
"EPLB step: %d for model %s: avg_tokens=%.2f, "
505+
"EPLB step: %d for model %s: avg_tokens=%.2f, "
528506
"max_tokens=%d, balancedness=%.4f",
529507
self.expert_rearrangement_step,
530508
eplb_model_state.model_name,
@@ -551,10 +529,8 @@ def step(
551529
# performing collective communication.
552530
self.expert_rearrangement_step += 1
553531
if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
554-
self.rearrange(model)
555-
self.expert_rearrangement_step = 0
556-
self.expert_rearrangement_step = 0
557532
self.rearrange()
533+
self.expert_rearrangement_step = 0
558534

559535
def rearrange(
560536
self,
@@ -590,14 +566,23 @@ def rearrange(
590566
time_start = time.perf_counter()
591567
logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
592568

593-
if self.load_path is not None and self.expert_rearrangement_step == 0:
594-
global_expert_load_window = load_eplb_state(self.load_path).to(
595-
self.physical_to_logical_map.device
569+
if (
570+
self.parallel_config.eplb_config.load_path is not None
571+
and self.expert_rearrangement_step == 0
572+
):
573+
global_expert_load_windows = load_eplb_state(
574+
self.parallel_config.eplb_config.load_path,
575+
list(self.model_states.values()),
596576
)
597-
elif global_expert_load is None:
598-
if global_expert_loads is None:
577+
assert global_expert_load_windows is not None
578+
elif global_expert_loads is None:
599579
# Map the physical expert load to global logical experts
600580
global_expert_load_windows = []
581+
should_save_eplb_state = (
582+
self.parallel_config.eplb_config.save_dir is not None
583+
and not is_profile
584+
and self.expert_rearrangement_step > 0
585+
)
601586
if not execute_shuffle:
602587
num_models = torch.tensor(
603588
[len(self.model_states)], dtype=torch.int32, device="cpu"
@@ -622,17 +607,6 @@ def rearrange(
622607
src=eplb_model_state.expert_load_window,
623608
)
624609

625-
if (
626-
is_main_rank
627-
and self.save_dir is not None
628-
and not is_profile
629-
and self.expert_rearrangement_step > 0
630-
):
631-
save_eplb_state(
632-
global_expert_load_window, self.save_dir, self.save_dir_iter
633-
)
634-
self.save_dir_iter += 1
635-
636610
if not execute_shuffle:
637611
metadata = torch.tensor(
638612
[
@@ -649,6 +623,16 @@ def rearrange(
649623

650624
global_expert_load_window = logical_expert_load_window.sum(dim=0)
651625
global_expert_load_windows.append(global_expert_load_window)
626+
627+
if is_main_rank and should_save_eplb_state:
628+
save_eplb_state(
629+
global_expert_load_windows,
630+
self.parallel_config.eplb_config.save_dir,
631+
self.save_dir_iter,
632+
list(self.model_states.values()),
633+
)
634+
self.save_dir_iter += 1
635+
652636
# Perform all-reduce to get the expert load across all ranks for each model
653637
global_expert_load_windows = self._allreduce_list(
654638
global_expert_load_windows

vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def apply(
100100
apply_router_weight_on_input: bool = False,
101101
activation: str = "silu",
102102
enable_eplb: bool = False,
103+
eplb_record_metrics: bool = False,
103104
expert_load_view: torch.Tensor | None = None,
104105
logical_to_physical_map: torch.Tensor | None = None,
105106
logical_replica_count: torch.Tensor | None = None,
@@ -133,6 +134,7 @@ def apply(
133134
e_score_correction_bias=e_score_correction_bias,
134135
indices_type=self.topk_indices_dtype,
135136
enable_eplb=enable_eplb,
137+
eplb_record_metrics=eplb_record_metrics,
136138
expert_map=expert_map,
137139
expert_load_view=expert_load_view,
138140
logical_to_physical_map=logical_to_physical_map,

0 commit comments

Comments
 (0)