
Commit 8681bda

Merge pull request #3 from Clorist33/Bugfix_eplb_4490
Bugfix eplb 4490
2 parents ea54388 + 03490f6

File tree: 11 files changed (+569 / -110 lines)

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 1 deletion
@@ -34,6 +34,7 @@ class AscendConfig:

     def __init__(self, vllm_config):
         additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
+        self.mix_placement = additional_config.get("mix_placement", False)
         torchair_graph_config = additional_config.get("torchair_graph_config",
                                                       {})

@@ -349,4 +350,4 @@ def check_ascend_config(vllm_config, enforce_eager):
         logger.warning(
             "ACL Graph is currently experimental. Please "
             "raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
-            " if you encourage any Error")
\ No newline at end of file
+            " if you encourage any Error")

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 67 additions & 5 deletions
@@ -28,19 +28,21 @@

 class VllmEplbAdaptor(EplbAdaptor):

-    def __init__(self, model, **args):
+    def __init__(self, model, mtp_instance=None, num_mtp_layers=0, **args):
         super().__init__(**args)
         self.model = model
         self.rank_id = dist.get_rank()
         self.world_size = dist.get_world_size()
         self.param_dict = dict(self.model.named_parameters())
+        self.mtp_instance = mtp_instance
+        self.num_mtp_layers = num_mtp_layers
         if self.model.config.model_type == "qwen3_moe":
             self.num_dense_layers = 0
             self.global_expert_num = self.model.config.num_experts
         else:
             self.num_dense_layers = self.model.config.first_k_dense_replace
             self.global_expert_num = self.model.config.n_routed_experts
-        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
+        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers  # MTP not included
         self.init_redundancy_expert = get_ascend_config(
         ).init_redundancy_expert

@@ -64,6 +66,18 @@ def __init__(self, model, **args):
         else:
             self.expert_weight_names = ["w13_weight", "w2_weight"]

+        if self.mtp_instance is not None:
+            if any("w13_weight_offset" in name
+                   for name, _ in self.mtp_instance.named_parameters()):
+                self.mtp_expert_weight_names = [
+                    "w13_weight", "w2_weight", "w13_weight_scale",
+                    "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
+                ]
+            else:
+                self.mtp_expert_weight_names = ["w13_weight", "w2_weight"]
+        else:
+            self.mtp_expert_weight_names = []
+
         self.expert_map_per_layer = dict(
         )  # reference to expert map on device for expert map update
         self.expert_map_per_layer_cpu = dict(
@@ -72,6 +86,12 @@ def __init__(self, model, **args):
             self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \
                 self.model.get_expert_map(self.num_dense_layers + layer_idx)

+        # Currently, MTP only support one layer.
+        if self.mtp_instance is not None:
+            for mtp_layer_idx in range(self.num_mtp_layers):
+                self.expert_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
+                    self.mtp_instance.model.get_expert_map(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx)
+
         # TODO: here we set number of buffer tensor equal to number of expert in each laryer, which can be improved
         num_buffer_tensor = torch.where(
             self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
@@ -88,6 +108,11 @@ def __init__(self, model, **args):
             self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
                 self.model.get_log2phy_map(self.num_dense_layers + layer_idx)

+        if self.mtp_instance is not None:
+            for mtp_layer_idx in range(self.num_mtp_layers):
+                self.log2phy_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
+                    self.mtp_instance.model.get_log2phy_map(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx)
+
         self.all_topk_ids = []

     def init_buffer_tensor(self, num_buffer_tensor):
@@ -131,12 +156,46 @@ def init_expert_param_per_layer(self):
                             name][0].data[local_expert_id])
                 self.expert_param_per_layer[layer_idx].append(per_expert_param)

+        if self.mtp_instance is not None:
+            mtp_param_dict = dict(self.mtp_instance.named_parameters())
+            for mtp_layer_idx in range(self.num_mtp_layers):
+                self.expert_param_per_layer[self.num_dense_layers +
+                                            self.num_moe_layers +
+                                            mtp_layer_idx] = list()
+            for local_expert_id in range(num_local_expert):
+                for mtp_layer_idx in range(self.num_mtp_layers):
+                    self.expert_param_per_layer[
+                        self.num_dense_layers + self.num_moe_layers +
+                        mtp_layer_idx].append([
+                            mtp_param_dict["model.layers." +
+                                           str(self.num_dense_layers +
+                                               self.num_moe_layers +
+                                               mtp_layer_idx) +
+                                           ".mtp_block.mlp.experts." +
+                                           name].data[local_expert_id]
+                            for name in self.mtp_expert_weight_names
+                        ])
+
     def get_rank_expert_workload(self) -> torch.Tensor:
         self.moe_load = self.model.get_all_moe_loads()
+        if self.mtp_instance is not None:
+            self.moe_load = torch.cat([
+                self.moe_load,
+                self.mtp_instance.model.get_all_moe_loads().to(
+                    device=self.moe_load.device)
+            ],
+                                      dim=0)
         return self.moe_load

     def get_init_expert_map(self, num_moe_layers):
         expert_map = self.model.get_all_expert_map(num_moe_layers)
+        if self.mtp_instance is not None:
+            expert_map = torch.cat([
+                expert_map,
+                self.mtp_instance.model.get_all_expert_map().to(
+                    device=expert_map.device)
+            ],
+                                   dim=0)
         if dist.is_initialized():
             world_size = dist.get_world_size()

@@ -288,7 +347,9 @@ def determine_expert_map_all(self):
         local_num_experts = self.global_expert_num // self.world_size

         expert_map_all = torch.full(
-            (self.num_moe_layers, self.world_size, self.global_expert_num),
+            (self.num_moe_layers if self.mtp_instance is None else
+             (self.num_moe_layers + self.num_mtp_layers), self.world_size,
+             self.global_expert_num),
             -1,
             dtype=torch.int32)

@@ -311,6 +372,7 @@ def determine_expert_map_all(self):

             local_ids = torch.arange(local_count, dtype=torch.int32)
             expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(
-                self.num_moe_layers, -1)
+                self.num_moe_layers if self.mtp_instance is None else
+                (self.num_moe_layers + self.num_mtp_layers), -1)

-        return expert_map_all
\ No newline at end of file
+        return expert_map_all
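
The index arithmetic repeated through these hunks is easier to see on its own: with an MTP instance attached, global layer ids keep counting past the main model, so MTP expert maps, log2phy maps, and weight buffers are all keyed by `num_dense_layers + num_moe_layers + mtp_layer_idx`. A sketch with made-up layer counts (the numbers are illustrative; only the formula comes from the diff):

# Illustrative counts for a DeepSeek-style model; not taken from this PR.
num_dense_layers = 3   # e.g. first_k_dense_replace
num_moe_layers = 58    # num_hidden_layers - num_dense_layers, MTP excluded
num_mtp_layers = 1     # the diff notes MTP currently supports one layer

moe_ids = list(range(num_dense_layers, num_dense_layers + num_moe_layers))
mtp_ids = [num_dense_layers + num_moe_layers + i
           for i in range(num_mtp_layers)]
print(moe_ids[-1], mtp_ids)  # 60 [61]: MTP slots in right after the MoE stack

The same conditional widening shows up in `determine_expert_map_all`, where the first dimension of `expert_map_all` grows from `num_moe_layers` to `num_moe_layers + num_mtp_layers` whenever `mtp_instance` is set.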

vllm_ascend/eplb/core/eplb_device_transfer_loader.py

Lines changed: 1 addition & 5 deletions
@@ -50,10 +50,6 @@ def generate_expert_d2d_transfer_task(self, expert_send_info,
             )
             return

-        # If neither send nor receive task is needed for this layer on this rank, return
-        if not (expert_send_info or expert_recv_info):
-            return
-
         self.updated_expert_map = updated_expert_map

         self.layer_id = layer_id
@@ -135,4 +131,4 @@ def update_expert_map_and_weight(self, reqs):
         self.state = ExpertWeightUpdateState.WAITING

     def load_impl(self, old_expert_table, new_expert_table):
-        raise NotImplementedError
\ No newline at end of file
+        raise NotImplementedError
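
Context for the removed guard: with the early-out gone, a rank that has nothing to send or receive for a layer still records `updated_expert_map` and `layer_id` and walks the rest of the method, which plausibly keeps all ranks in step for the transfer phase that follows (the PR itself does not state the motivation). A sketch of the shape of the change, with surrounding names assumed from the context lines:

# Sketch only; everything except the two assignments is assumed.
def generate_expert_d2d_transfer_task(self, expert_send_info,
                                      expert_recv_info,
                                      updated_expert_map, layer_id):
    # Previously `if not (expert_send_info or expert_recv_info): return`
    # skipped the bookkeeping below, so an idle rank never stored the new
    # map. Now even a rank with zero transfers records it and proceeds.
    self.updated_expert_map = updated_expert_map
    self.layer_id = layer_id
    # ... then build the (possibly empty) device-to-device transfer ops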

vllm_ascend/eplb/eplb_updator.py

Lines changed: 7 additions & 3 deletions
@@ -35,9 +35,11 @@ def __init__(self, ascend_config, loader, eplb_process: EplbProcess,
         self.eplb_process = eplb_process
         self.shared_dict = self.eplb_process.shared_dict

-    def set_adaptor(self, adaptor):
+    def set_adaptor(self, adaptor, num_mtp_layers):
         self.adaptor = adaptor
-        self.num_moe_layers = self.adaptor.num_moe_layers
+        self.num_moe_layers = (self.adaptor.num_moe_layers
+                               if self.adaptor.mtp_instance is None else
+                               self.adaptor.num_moe_layers + num_mtp_layers)
         self.global_expert_num = self.adaptor.global_expert_num

     def init_eplb(self, expert_map_path, process):
@@ -84,6 +86,8 @@ def update_iteration(self):
                 self.expert_map_record_path)

             self.adaptor.model.clear_all_moe_loads()
+            if self.adaptor.mtp_instance is not None:
+                self.adaptor.mtp_instance.model.clear_all_moe_loads()
             if not self.gate_eplb:
                 self.cur_iterations = 0

@@ -207,4 +211,4 @@ def shutdown(self):
         if self.process.is_alive():
             self.process.terminate()
             self.process.join()
-        logger.info("[ModelRunner] EPLB process terminated")
\ No newline at end of file
+        logger.info("[ModelRunner] EPLB process terminated")
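
A self-contained toy mirroring the `set_adaptor` logic above: when the adaptor carries an MTP instance, the updater folds the extra layers into its count; otherwise nothing changes. The class names are stand-ins, not the real `VllmEplbAdaptor` or updater types:

class Adaptor:  # stand-in for VllmEplbAdaptor
    def __init__(self, num_moe_layers, mtp_instance=None):
        self.num_moe_layers = num_moe_layers
        self.mtp_instance = mtp_instance

class Updator:  # stand-in for the EPLB updater
    def set_adaptor(self, adaptor, num_mtp_layers):
        self.adaptor = adaptor
        self.num_moe_layers = (adaptor.num_moe_layers
                               if adaptor.mtp_instance is None else
                               adaptor.num_moe_layers + num_mtp_layers)

u = Updator()
u.set_adaptor(Adaptor(58, mtp_instance=object()), num_mtp_layers=1)
print(u.num_moe_layers)  # 59: the MTP layer is counted into the EPLB total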

vllm_ascend/eplb/utils.py

Lines changed: 61 additions & 32 deletions
@@ -18,45 +18,73 @@
 import types

 import torch
+from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictor


 def get_expert_map(self, layer_id):
-    return self.model.layers[layer_id].mlp.experts.get_map()
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        return self.model.layers[layer_id].mlp.experts.get_map()
+    else:
+        return self.layers[str(layer_id)].mtp_block.mlp.experts.get_map()


 def get_log2phy_map(self, layer_id):
-    return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
-
-
-def get_all_expert_map(self, num_moe_layers):
-    all_loads = []
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    for layer_id in range(num_moe_layers):
-        load_tensor = self.get_expert_map(
-            layer_id + num_dense_layers)  # (num_experts_per_layer,)
-        all_loads.append(load_tensor)
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
+    else:
+        return self.layers[str(
+            layer_id)].mtp_block.mlp.experts.get_log2phy_map()
+
+
+def get_all_expert_map(self, num_moe_layers=None):
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        all_loads = []
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        for layer_id in range(num_moe_layers):
+            load_tensor = self.get_expert_map(
+                layer_id + num_dense_layers)  # (num_experts_per_layer,)
+            all_loads.append(load_tensor)
+    else:
+        all_loads = []
+        for layer_id in range(self.mtp_start_layer_idx,
+                              self.mtp_start_layer_idx + self.num_mtp_layers):
+            load_tensor = self.get_expert_map(layer_id)
+            all_loads.append(load_tensor)

     return torch.stack(all_loads, dim=0)


 def get_all_moe_loads(self):
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    all_moe_loads = torch.stack(
-        [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
-         for layer_id in range(self.num_moe_layers)],
-        dim=0
-    )
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        all_moe_loads = torch.stack(
+            [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
+             for layer_id in range(self.num_moe_layers)],
+            dim=0
+        )
+    else:
+        all_moe_loads = torch.stack(
+            [self.layers[str(idx)].mtp_block.mlp.experts.moe_load \
+             for idx in range(self.mtp_start_layer_idx,
+                              self.mtp_start_layer_idx + self.num_mtp_layers)],
+            dim=0
+        )
     return all_moe_loads


 def clear_all_moe_loads(self):
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    for layer_id in range(self.num_moe_layers):
-        self.model.layers[layer_id +
-                          num_dense_layers].mlp.experts.clear_moe_load()
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        for layer_id in range(self.num_moe_layers):
+            self.model.layers[layer_id +
+                              num_dense_layers].mlp.experts.clear_moe_load()
+    else:
+        for layer_id in range(self.mtp_start_layer_idx,
+                              self.mtp_start_layer_idx + self.num_mtp_layers):
+            self.layers[str(layer_id)].mtp_block.mlp.experts.clear_moe_load()


 def model_register(model, model_config):
@@ -66,12 +94,13 @@ def model_register(model, model_config):
     model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model)
     model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model)

-    config = model_config.hf_config
+    if not isinstance(model, DeepSeekMultiTokenPredictor):
+        config = model_config.hf_config

-    if config.model_type == "qwen3_moe":
-        model.num_moe_layers = config.num_hidden_layers
-    elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
-        model.num_dense_layers = config.first_k_dense_replace
-        model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
-    else:
-        raise NotImplementedError("EPLB is not supported.")
+        if config.model_type == "qwen3_moe":
+            model.num_moe_layers = config.num_hidden_layers
+        elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
+            model.num_dense_layers = config.first_k_dense_replace
+            model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
+        else:
+            raise NotImplementedError("EPLB is not supported.")
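
`model_register` binds these free functions onto whatever model it receives via `types.MethodType`, so the `isinstance(self, DeepSeekMultiTokenPredictor)` checks let one set of helpers serve both the main model (integer-indexed `self.model.layers[i].mlp.experts`) and the MTP predictor (string-keyed `self.layers[str(i)].mtp_block.mlp.experts`). A stripped-down, runnable sketch of that dispatch pattern; the stub classes stand in for the real vLLM types:

import types

class MainModelStub:  # stands in for a DeepSeek/Qwen MoE model
    pass

class MtpStub:  # stands in for DeepSeekMultiTokenPredictor
    pass

def get_expert_map(self, layer_id):
    # Main model: integer index. MTP predictor: string key.
    if not isinstance(self, MtpStub):
        return f"main path, layers[{layer_id}]"
    return f"mtp path, layers[{str(layer_id)!r}].mtp_block"

for m in (MainModelStub(), MtpStub()):
    m.get_expert_map = types.MethodType(get_expert_map, m)
    print(m.get_expert_map(61))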

vllm_ascend/ops/fused_moe/experts_selector.py

Lines changed: 16 additions & 1 deletion
@@ -33,6 +33,8 @@ def select_experts(hidden_states: torch.Tensor,
                    routed_scaling_factor=1.0,
                    e_score_correction_bias: Optional[torch.Tensor] = None,
                    indices_type: Optional[torch.dtype] = None,
+                   mix_placement: Optional[bool] = False,
+                   num_logical_experts: int = -1,
                    global_num_experts: int = -1):
     """
     Fused experts with select experts.
@@ -95,6 +97,19 @@ def select_experts(hidden_states: torch.Tensor,
         e_score_correction_bias=e_score_correction_bias,
         global_num_experts=global_num_experts,
     )
+    if mix_placement:
+        pad_shared_expert_ids = torch.full((topk_ids.shape[0], 1),
+                                           num_logical_experts,
+                                           dtype=topk_ids.dtype,
+                                           device=topk_ids.device)
+
+        pad_shared_expert_weights = torch.full((topk_weights.shape[0], 1),
+                                               0.4,
+                                               dtype=topk_weights.dtype,
+                                               device=topk_weights.device)
+        topk_ids = torch.cat([topk_ids, pad_shared_expert_ids], dim=1)
+        topk_weights = torch.cat([topk_weights, pad_shared_expert_weights],
+                                 dim=1)
     return topk_weights, topk_ids


@@ -302,4 +317,4 @@ def _native_select_experts(
     topk_ids = topk_ids.to(torch.int32)
     topk_weights = _renormalize_topk_weights(topk_weights, renormalize)

-    return topk_weights, topk_ids
\ No newline at end of file
+    return topk_weights, topk_ids
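
The `mix_placement` branch gives every token one extra routing slot for the shared expert: its id is `num_logical_experts` (one past the last routed expert) and its gate weight is the hard-coded `0.4` seen in the diff. A toy reproduction of the padding (the 0.4 weight and the id convention come from the diff; shapes and values are illustrative):

import torch

# 2 tokens, top-2 routing over 8 logical experts; the shared expert is
# appended as logical id 8 with a fixed 0.4 weight.
num_logical_experts = 8
topk_ids = torch.tensor([[1, 5], [3, 7]], dtype=torch.int32)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4]])

pad_ids = torch.full((topk_ids.shape[0], 1), num_logical_experts,
                     dtype=topk_ids.dtype, device=topk_ids.device)
pad_weights = torch.full((topk_weights.shape[0], 1), 0.4,
                         dtype=topk_weights.dtype,
                         device=topk_weights.device)

topk_ids = torch.cat([topk_ids, pad_ids], dim=1)          # [[1,5,8],[3,7,8]]
topk_weights = torch.cat([topk_weights, pad_weights], dim=1)  # extra 0.4 col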
