
Commit 88ca5a0

add ep fake balance
1 parent 81c5f61 commit 88ca5a0

File tree

lightllm/common/fused_moe/grouped_fused_moe_ep.py
lightllm/common/fused_moe/topk_select.py
lightllm/utils/balance_utils.py

3 files changed: +83 −0 lines changed

lightllm/common/fused_moe/grouped_fused_moe_ep.py

Lines changed: 12 additions & 0 deletions
@@ -142,6 +142,11 @@ def fused_experts_impl(
 
     # scatter
     all_tokens = sum(num_recv_tokens_per_expert_list)  # calc the total padded token count.
+
+    # Important log for debugging load balancing
+    # rank = dist.get_rank()
+    # logger.info(f"prefill, [{rank}], all_tokens = {all_tokens}, num_recv_tokens_per_expert_list: {num_recv_tokens_per_expert_list}")
+
     # gather_out shape [receive_num_tokens, hidden]
     gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype)
     if all_tokens > 0:

@@ -219,6 +224,13 @@ def fused_experts_impl(
         async_finish=False,
         return_recv_hook=False,
     )
+
+    # Important log for debugging load balancing
+    # When the decode CUDA graph is enabled, we cannot call the logger; running with --profile disables the CUDA graph.
+    # rank = dist.get_rank()
+    # all_tokens = sum(masked_m)
+    # logger.info(f"decode, [{rank}], all_tokens = {all_tokens}, expected_m = {expected_m}, num_recv_tokens_per_expert: {masked_m}")
+
     # deepgemm
     gemm_out_b = masked_group_gemm(recv_x, masked_m, hidden_states.dtype, w1, w1_scale, w2, w2_scale, expected_m)
     # low latency combine
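The decode-path log has to stay commented out because it cannot execute while the decode CUDA graph is in use. A minimal graph-safe sketch, not part of this commit: it reuses the module's logger, dist, masked_m, and expected_m, and guards on torch.cuda.is_current_stream_capturing() (a standard PyTorch check) so the host sync is skipped during graph capture.

import torch

def log_decode_balance(logger, dist, masked_m, expected_m):
    # Skip while a CUDA graph is being captured: .item() forces a host sync,
    # which is not allowed (and not meaningful) inside graph capture.
    if torch.cuda.is_current_stream_capturing():
        return
    rank = dist.get_rank()
    all_tokens = int(masked_m.sum().item())  # assumes masked_m is a GPU int tensor
    logger.info(
        f"decode, [{rank}], all_tokens = {all_tokens}, expected_m = {expected_m}, "
        f"num_recv_tokens_per_expert: {masked_m.tolist()}"
    )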

lightllm/common/fused_moe/topk_select.py

Lines changed: 8 additions & 0 deletions
@@ -21,6 +21,7 @@
 import torch
 from lightllm.utils.sgl_utils import sgl_ops
 from lightllm.utils.light_utils import light_ops
+from lightllm.utils.balance_utils import BalancedTensor
 from typing import Callable, List, Optional, Tuple
 from lightllm.common.fused_moe.softmax_topk import softmax_topk
 

@@ -227,4 +228,11 @@ def select_experts(
         hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
     )
 
+    # Switch for EP fake load balancing
+    if os.environ.get("EP_FAKE_BALANCE_ENABLED") == "true":
+        M, _ = hidden_states.shape
+        balanced_tensor_collection = BalancedTensor()
+        balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(M)
+        topk_ids.copy_(balance_topk_ids)
+
     return topk_weights, topk_ids
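Note that the override only fires when the environment variable is the exact string "true" ("1" or "TRUE" do not match), and topk_ids.copy_ requires the cached tensor's shape (M, num_selected) to match the router output, so the BalancedTensor defaults (256 experts, top-8) must agree with the model. A minimal sketch of the override in isolation, with a stand-in router output (CUDA required):

import os
import torch
from lightllm.utils.balance_utils import BalancedTensor

os.environ["EP_FAKE_BALANCE_ENABLED"] = "true"  # exact lowercase string "true"

M = 16  # tokens in this batch
topk_ids = torch.zeros((M, 8), dtype=torch.int, device="cuda")  # stand-in for the router's top-8 ids
if os.environ.get("EP_FAKE_BALANCE_ENABLED") == "true":
    # Same in-place overwrite that select_experts performs; topk_weights are untouched.
    topk_ids.copy_(BalancedTensor().get_balance_topk_ids(M))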

lightllm/utils/balance_utils.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+import torch
+import os
+
+import threading
+
+def singleton_threadsafe(cls):
+    instances = {}
+    lock = threading.Lock()
+
+    def get_instance(*args, **kwargs):
+        with lock:
+            if cls not in instances:
+                instances[cls] = cls(*args, **kwargs)
+            return instances[cls]
+    return get_instance
+
+@singleton_threadsafe
+class BalancedTensor:
+    def __init__(self, num_experts=256, num_selected=8):
+        self.balanced_tensors = {}
+        self.num_experts = num_experts
+        self.num_selected = num_selected
+
+    def generate_balanced_tensor(self, length):
+        # Initialize an all-zero length x num_selected tensor on the GPU
+        tensor = torch.zeros((length, self.num_selected), dtype=torch.int, device='cuda')
+        # Initialize the load counter for each expert
+        expert_load = torch.zeros(self.num_experts, dtype=torch.int, device='cuda')
+
+        for i in range(length):
+            available_experts = torch.arange(self.num_experts, device='cuda')
+            selected = []
+            for _ in range(self.num_selected):
+                # Compute the current load of each available expert
+                current_load = expert_load[available_experts]
+                # Pick the expert with the smallest load
+                min_load_indices = torch.where(current_load == current_load.min())[0]
+                if len(min_load_indices) > 1:
+                    # If several experts share the minimum load, pick one at random
+                    chosen_index = torch.randint(0, len(min_load_indices), (1,), device='cuda').item()
+                    chosen_expert_index = min_load_indices[chosen_index]
+                else:
+                    chosen_expert_index = min_load_indices[0]
+                chosen_expert = available_experts[chosen_expert_index]
+                selected.append(chosen_expert)
+                # Remove the chosen expert from the list of available experts
+                available_experts = torch.cat(
+                    [available_experts[:chosen_expert_index], available_experts[chosen_expert_index + 1:]])
+                # Update that expert's load
+                expert_load[chosen_expert] += 1
+            tensor[i] = torch.tensor(selected, dtype=torch.int, device='cuda')
+        return tensor
+
+    def get_balance_topk_ids(self, length):
+        if self.balanced_tensors.get(length) is not None:
+            # print("find length ", length)
+            return self.balanced_tensors[length]
+        else:
+            # print("generate length ", length)
+            tensor = self.generate_balanced_tensor(length)
+            self.balanced_tensors[length] = tensor
+            return tensor
+
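A hypothetical smoke test, not part of the commit, for the two properties the module relies on: the decorator collapses every construction into one instance, and the generator spreads load evenly. Run it in a fresh process (the singleton ignores constructor arguments after the first call) on a machine with a CUDA device:

import torch
from lightllm.utils.balance_utils import BalancedTensor

bt = BalancedTensor(num_experts=16, num_selected=4)  # small sizes for a quick check
assert BalancedTensor() is bt                        # singleton: same object on every call

ids = bt.get_balance_topk_ids(32)                    # shape (32, 4), generated once then cached
assert ids.shape == (32, 4)
assert bt.get_balance_topk_ids(32) is ids            # cache hit returns the stored tensor

load = torch.bincount(ids.flatten().long(), minlength=16)
print(load)  # expect 8 per expert: 32*4 picks, and the greedy keeps loads within one of each other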

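One caveat on cost: generate_balanced_tensor runs a Python loop with several small GPU kernels and a host sync (.item()) per pick, so the first call for a new length can be slow; the per-length cache amortizes that across steps. For comparison, a hypothetical single-shot round-robin construction with the same two properties (distinct experts per row, per-expert load within one of the mean):

import torch

def round_robin_topk_ids(length, num_experts=256, num_selected=8):
    # Token i takes experts i*num_selected .. i*num_selected + num_selected - 1 (mod num_experts).
    # Consecutive values are distinct within a row (num_selected <= num_experts),
    # and cycling through expert ids keeps every expert's load within one of the mean.
    base = torch.arange(length, device="cuda").unsqueeze(1) * num_selected
    offset = torch.arange(num_selected, device="cuda").unsqueeze(0)
    return ((base + offset) % num_experts).to(torch.int)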