Commit d2141e6

fix pre-commit and PR checks: faster balance algorithm, more robust balance management, env-var control replaced by a start option, better logger info control, better formatting
Parent: 88ca5a0

File tree: 5 files changed, +65 −48 lines

lightllm/common/fused_moe/grouped_fused_moe_ep.py
Lines changed: 18 additions & 8 deletions

@@ -5,6 +5,7 @@
 import triton.language as tl
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch.distributed as dist
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
 from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd
@@ -143,9 +144,13 @@ def fused_experts_impl(
     # scatter
     all_tokens = sum(num_recv_tokens_per_expert_list)  # calculate the total padded token count.

-    # Important log for debugging load balance
-    # rank = dist.get_rank()
-    # logger.info(f"prefill, [{rank}], all_tokens = {all_tokens}, num_recv_tokens_per_expert_list: {num_recv_tokens_per_expert_list}")
+    if get_env_start_args().enable_ep_fake_balance:
+        rank = dist.get_rank()
+        if rank == 0:
+            logger.info(
+                f"prefill, [{rank}], all_tokens = {all_tokens}, "
+                f"num_recv_tokens_per_expert_list: {num_recv_tokens_per_expert_list}"
+            )

     # gather_out shape [receive_num_tokens, hidden]
     gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype)
@@ -225,11 +230,16 @@ def fused_experts_impl(
         return_recv_hook=False,
     )

-    # Important log for debugging load balance
-    # When the decoding graph is open we cannot call the logger; --profile can close the cuda graph
-    # rank = dist.get_rank()
-    # all_tokens = sum(masked_m)
-    # logger.info(f"decode, [{rank}], all_tokens = {all_tokens}, expected_m = {expected_m}, num_recv_tokens_per_expert: {masked_m}")
+    # NOTE: when the decoding graph is open we cannot call the logger, so this path is only usable with --disable_cudagraph
+    args = get_env_start_args()
+    if args.enable_ep_fake_balance and args.disable_cudagraph:
+        rank = dist.get_rank()
+        all_tokens = sum(masked_m)
+        if rank == 0:
+            logger.info(
+                f"decode, [{rank}], all_tokens = {all_tokens}, "
+                f"expected_m = {expected_m}, num_recv_tokens_per_expert: {masked_m}"
+            )

     # deepgemm
     gemm_out_b = masked_group_gemm(recv_x, masked_m, hidden_states.dtype, w1, w1_scale, w2, w2_scale, expected_m)
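
Both hunks replace commented-out debug prints with the same pattern: the load-balance log is gated behind the new --enable_ep_fake_balance option and emitted only by rank 0, and the decode-side log additionally requires --disable_cudagraph because the logger cannot run inside a captured decoding graph. A minimal sketch of that pattern outside LightLLM, where args stands in for the object returned by get_env_start_args():

import logging

import torch.distributed as dist

logger = logging.getLogger(__name__)


def log_expert_load(args, num_recv_tokens_per_expert_list):
    # Skip entirely unless the fake-balance debug option is enabled.
    if not args.enable_ep_fake_balance:
        return
    # Only rank 0 logs, so the output is not duplicated across EP ranks.
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    all_tokens = sum(num_recv_tokens_per_expert_list)
    logger.info("all_tokens = %d, per-expert: %s", all_tokens, num_recv_tokens_per_expert_list)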

lightllm/common/fused_moe/topk_select.py
Lines changed: 5 additions & 4 deletions

@@ -21,6 +21,7 @@
 import torch
 from lightllm.utils.sgl_utils import sgl_ops
 from lightllm.utils.light_utils import light_ops
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.balance_utils import BalancedTensor
 from typing import Callable, List, Optional, Tuple
 from lightllm.common.fused_moe.softmax_topk import softmax_topk
@@ -229,10 +230,10 @@ def select_experts(
     )

     # Switch for EP fake load balancing
-    if os.environ.get("EP_FAKE_BALANCE_ENABLED") == "true":
-        M, _ = hidden_states.shape
-        balanced_tensor_collection = BalancedTensor()
-        balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(M)
+    if get_env_start_args().enable_ep_fake_balance:
+        num_tokens, num_experts = router_logits.shape
+        balanced_tensor_collection = BalancedTensor(num_experts=num_experts, num_selected=top_k)
+        balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(num_tokens)
         topk_ids.copy_(balance_topk_ids)

     return topk_weights, topk_ids
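
With the option enabled, select_experts discards whatever the router computed and overwrites topk_ids with a precomputed balanced assignment, now sized from router_logits and the actual top_k rather than the old hard-coded defaults. A rough standalone illustration of the override; the shapes and the torch.topk stand-in for the router are assumptions, not the real call sites:

import torch

from lightllm.utils.balance_utils import BalancedTensor

num_tokens, num_experts, top_k = 16, 256, 8
router_logits = torch.randn(num_tokens, num_experts, device="cuda")

# What the router would normally produce.
topk_weights, topk_ids = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1)

# Fake-balance path: replace the routing decision with a balanced one.
balanced = BalancedTensor(num_experts=num_experts, num_selected=top_k)
topk_ids.copy_(balanced.get_balance_topk_ids(num_tokens))  # copy_ casts int32 -> int64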

lightllm/server/api_cli.py
Lines changed: 3 additions & 0 deletions

@@ -333,6 +333,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway"
     )
+
+    parser.add_argument("--enable_ep_fake_balance", action="store_true", help="Enable fake expert balancing in EP mode")
+
     parser.add_argument("--disable_cudagraph", action="store_true", help="Disable the cudagraph of the decoding stage")

     parser.add_argument(
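
Because the new flag is a plain store_true argument, it can be sanity-checked straight through the parser. A hypothetical check, assuming the parser's remaining arguments all have defaults:

from lightllm.server.api_cli import make_argument_parser

parser = make_argument_parser()
# Only pass the two flags relevant to the fake-balance debug path.
args = parser.parse_args(["--enable_ep_fake_balance", "--disable_cudagraph"])
assert args.enable_ep_fake_balance and args.disable_cudagraph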

lightllm/server/core/objs/start_args_type.py
Lines changed: 1 addition & 0 deletions

@@ -76,6 +76,7 @@ class StartArgs:
     visual_dp: int = field(default=1)
     visual_nccl_ports: List[int] = field(default_factory=lambda: [29500])
     enable_monitor_auth: bool = field(default=False)
+    enable_ep_fake_balance: bool = field(default=False)
     disable_cudagraph: bool = field(default=False)
     graph_max_batch_size: int = field(default=256)
     graph_split_batch_size: int = field(default=32)
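
The mirrored StartArgs field keeps the typed start-args object in sync with the CLI, defaulting the option to off. One way to verify the default without constructing a full StartArgs instance (a sketch, assuming only that StartArgs stays a stdlib dataclass):

from dataclasses import fields

from lightllm.server.core.objs.start_args_type import StartArgs

defaults = {f.name: f.default for f in fields(StartArgs)}
assert defaults["enable_ep_fake_balance"] is False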

lightllm/utils/balance_utils.py
Lines changed: 38 additions & 36 deletions

@@ -3,17 +3,27 @@

 import threading

+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
 def singleton_threadsafe(cls):
     instances = {}
     lock = threading.Lock()

     def get_instance(*args, **kwargs):
+        # A key that includes the arguments is needed for parameter-dependent singletons.
+        # Using a tuple of args and a frozenset of kwargs items makes it hashable.
+        key = (cls, args, frozenset(kwargs.items()))
         with lock:
-            if cls not in instances:
-                instances[cls] = cls(*args, **kwargs)
-            return instances[cls]
+            if key not in instances:
+                instances[key] = cls(*args, **kwargs)
+            return instances[key]
+
     return get_instance

+
 @singleton_threadsafe
 class BalancedTensor:
     def __init__(self, num_experts=256, num_selected=8):
@@ -22,42 +32,34 @@ def __init__(self, num_experts=256, num_selected=8):
         self.num_selected = num_selected

     def generate_balanced_tensor(self, length):
-        # Initialize a length x num_selected all-zero tensor on the GPU
-        tensor = torch.zeros((length, self.num_selected), dtype=torch.int, device='cuda')
-        # Initialize a load counter for every expert
-        expert_load = torch.zeros(self.num_experts, dtype=torch.int, device='cuda')
+        tensor = torch.empty((length, self.num_selected), dtype=torch.int, device="cuda")
+        expert_load = torch.zeros(self.num_experts, dtype=torch.int, device="cuda")
+
+        expert_indices = torch.arange(self.num_experts, device="cuda")

         for i in range(length):
-            available_experts = torch.arange(self.num_experts, device='cuda')
-            selected = []
-            for _ in range(self.num_selected):
-                # Compute the current load of each available expert
-                current_load = expert_load[available_experts]
-                # Pick the expert with the smallest load
-                min_load_indices = torch.where(current_load == current_load.min())[0]
-                if len(min_load_indices) > 1:
-                    # If several experts share the minimum load, pick one at random
-                    chosen_index = torch.randint(0, len(min_load_indices), (1,), device='cuda').item()
-                    chosen_expert_index = min_load_indices[chosen_index]
-                else:
-                    chosen_expert_index = min_load_indices[0]
-                chosen_expert = available_experts[chosen_expert_index]
-                selected.append(chosen_expert)
-                # Remove the chosen expert from the list of available experts
-                available_experts = torch.cat(
-                    [available_experts[:chosen_expert_index], available_experts[chosen_expert_index + 1:]])
-                # Update the chosen expert's load
-                expert_load[chosen_expert] += 1
-            tensor[i] = torch.tensor(selected, dtype=torch.int, device='cuda')
+            # To break ties randomly when loads are equal, shuffle all expert
+            # indices and then sort by load: experts with the same load stay
+            # in the random order produced by the shuffle.
+            shuffled_indices = expert_indices[torch.randperm(self.num_experts, device="cuda")]
+            sorted_shuffled_indices = shuffled_indices[torch.argsort(expert_load[shuffled_indices])]
+
+            # Select the num_selected experts with the lowest load
+            selected_experts = sorted_shuffled_indices[: self.num_selected]
+
+            tensor[i] = selected_experts
+
+            # Update the loads of the selected experts with a single scatter_add_
+            expert_load.scatter_add_(0, selected_experts, torch.ones_like(selected_experts, dtype=torch.int))

         return tensor

-    def get_balance_topk_ids(self, length):
-        if self.balanced_tensors.get(length) is not None:
-            # print("find length ", length)
-            return self.balanced_tensors[length]
+    def get_balance_topk_ids(self, num_tokens):
+        if self.balanced_tensors.get(num_tokens) is not None:
+            # logger.info(f"find balanced tensor for num_tokens={num_tokens}")
+            return self.balanced_tensors[num_tokens]
         else:
-            # print("generate length ", length)
-            tensor = self.generate_balanced_tensor(length)
-            self.balanced_tensors[length] = tensor
+            # logger.info(f"generate balanced tensor for num_tokens={num_tokens}")
+            tensor = self.generate_balanced_tensor(num_tokens)
+            self.balanced_tensors[num_tokens] = tensor
         return tensor
-
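
The rewritten generator drops the O(num_selected) greedy inner loop: each row is filled by shuffling the expert indices, sorting them by current load (the shuffle randomizes the order among equally loaded experts), taking the num_selected least-loaded, and updating the loads with one scatter_add_. The same technique in a self-contained sketch, on CPU tensors for portability (the real class runs on CUDA):

import torch


def pick_least_loaded(expert_load: torch.Tensor, k: int) -> torch.Tensor:
    """Pick k distinct experts with the lowest load, breaking ties randomly."""
    num_experts = expert_load.numel()
    shuffled = torch.randperm(num_experts)        # random tie-break order
    order = torch.argsort(expert_load[shuffled])  # sort the shuffled experts by load
    selected = shuffled[order][:k]
    # One scatter_add_ replaces the per-expert Python-side load updates.
    expert_load.scatter_add_(0, selected, torch.ones_like(selected, dtype=expert_load.dtype))
    return selected


load = torch.zeros(8, dtype=torch.int)
rows = torch.stack([pick_least_loaded(load, 2) for _ in range(8)])
print(load)  # tensor([2, 2, 2, 2, 2, 2, 2, 2]): 8 rows x 2 picks spread evenly over 8 experts

Because singleton_threadsafe now keys its cache on (cls, args, frozenset(kwargs.items())), BalancedTensor(num_experts=256, num_selected=8) and BalancedTensor(num_experts=128, num_selected=6) live as separate cached instances, which the parameterized call in topk_select.py depends on. One caveat of this keying: positional and keyword spellings of the same arguments hash to different keys, so call sites should stick to one spelling.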