22 changes: 22 additions & 0 deletions lightllm/common/fused_moe/grouped_fused_moe_ep.py
@@ -5,6 +5,7 @@
import triton.language as tl
from typing import Any, Callable, Dict, Optional, Tuple
import torch.distributed as dist
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.log_utils import init_logger
from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd
@@ -142,6 +143,15 @@ def fused_experts_impl(

# scatter
all_tokens = sum(num_recv_tokens_per_expert_list)  # total number of padded tokens across all experts.

if get_env_start_args().enable_ep_fake_balance:
rank = dist.get_rank()
if rank == 0:
logger.info(
f"prefill, [{rank}], all_tokens = {all_tokens}, "
f"num_recv_tokens_per_expert_list: {num_recv_tokens_per_expert_list}"
)

# gather_out shape [receive_num_tokens, hidden]
gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype)
if all_tokens > 0:
@@ -219,6 +229,18 @@ def fused_experts_impl(
async_finish=False,
return_recv_hook=False,
)

# NOTE: when the decoding CUDA graph is enabled, the logger cannot be called during capture,
# so this logging is only available with --disable_cudagraph.
args = get_env_start_args()
if args.enable_ep_fake_balance and args.disable_cudagraph:
rank = dist.get_rank()
all_tokens = sum(masked_m)
if rank == 0:
logger.info(
f"decode, [{rank}], all_tokens = {all_tokens}, "
f"expected_m = {expected_m}, num_recv_tokens_per_expert: {masked_m}"
)

# deepgemm
gemm_out_b = masked_group_gemm(recv_x, masked_m, hidden_states.dtype, w1, w1_scale, w2, w2_scale, expected_m)
# low latency combine
9 changes: 9 additions & 0 deletions lightllm/common/fused_moe/topk_select.py
@@ -21,6 +21,8 @@
import torch
from lightllm.utils.sgl_utils import sgl_ops
from lightllm.utils.light_utils import light_ops
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.balance_utils import BalancedTensor
from typing import Callable, List, Optional, Tuple
from lightllm.common.fused_moe.softmax_topk import softmax_topk

@@ -227,4 +229,11 @@ def select_experts(
hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
)

# EP fake load-balancing switch
if get_env_start_args().enable_ep_fake_balance:
num_tokens, num_experts = router_logits.shape
balanced_tensor_collection = BalancedTensor(num_experts=num_experts, num_selected=top_k)
balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(num_tokens)
topk_ids.copy_(balance_topk_ids)

return topk_weights, topk_ids
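
The override above keeps the router-produced topk_weights and only replaces topk_ids, so each expert receives a near-uniform share of tokens regardless of the gating output. A minimal sketch of that effect, using the BalancedTensor helper added in lightllm/utils/balance_utils.py below (sizes are hypothetical; a CUDA device is required because the helper allocates its tensors on "cuda"):

import torch

from lightllm.utils.balance_utils import BalancedTensor

num_tokens, num_experts, top_k = 64, 8, 2

# Greedy lowest-load selection keeps per-expert loads within 1 of each other,
# so when num_tokens * top_k is divisible by num_experts every expert ends up
# with exactly num_tokens * top_k / num_experts assignments (16 here).
balanced = BalancedTensor(num_experts=num_experts, num_selected=top_k)
fake_topk_ids = balanced.get_balance_topk_ids(num_tokens)

counts = torch.bincount(fake_topk_ids.flatten(), minlength=num_experts)
print(counts)  # expected: every entry equals 16
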
3 changes: 3 additions & 0 deletions lightllm/server/api_cli.py
@@ -333,6 +333,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway"
)

parser.add_argument("--enable_ep_fake_balance", action="store_true", help="Enable the fake balance of the EP mode")

parser.add_argument("--disable_cudagraph", action="store_true", help="Disable the cudagraph of the decoding stage")

parser.add_argument(
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
@@ -76,6 +76,7 @@ class StartArgs:
visual_dp: int = field(default=1)
visual_nccl_ports: List[int] = field(default_factory=lambda: [29500])
enable_monitor_auth: bool = field(default=False)
enable_ep_fake_balance: bool = field(default=False)
disable_cudagraph: bool = field(default=False)
graph_max_batch_size: int = field(default=256)
graph_split_batch_size: int = field(default=32)
65 changes: 65 additions & 0 deletions lightllm/utils/balance_utils.py
@@ -0,0 +1,65 @@
import os
import threading

import torch

from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


def singleton_threadsafe(cls):
instances = {}
lock = threading.Lock()

def get_instance(*args, **kwargs):
# A key that includes the arguments is needed for parameter-dependent singletons.
# Using a tuple of args and a frozenset of kwargs items makes it hashable.
key = (cls, args, frozenset(kwargs.items()))
with lock:
if key not in instances:
instances[key] = cls(*args, **kwargs)
return instances[key]

return get_instance


@singleton_threadsafe
class BalancedTensor:
def __init__(self, num_experts=256, num_selected=8):
self.balanced_tensors = {}
self.num_experts = num_experts
self.num_selected = num_selected

def generate_balanced_tensor(self, length):
tensor = torch.empty((length, self.num_selected), dtype=torch.int, device="cuda")
expert_load = torch.zeros(self.num_experts, dtype=torch.int, device="cuda")

expert_indices = torch.arange(self.num_experts, device="cuda")

for i in range(length):
# To break ties randomly when loads are equal, we can shuffle indices
# of experts with the same load. A simple way is to shuffle all
# indices and then sort by load.
shuffled_indices = expert_indices[torch.randperm(self.num_experts, device="cuda")]
sorted_shuffled_indices = shuffled_indices[torch.argsort(expert_load[shuffled_indices])]

# Select the top `num_selected` experts with the lowest load
selected_experts = sorted_shuffled_indices[: self.num_selected]

tensor[i] = selected_experts

# Update loads for the selected experts using an efficient scatter_add
expert_load.scatter_add_(0, selected_experts, torch.ones_like(selected_experts, dtype=torch.int))

return tensor

def get_balance_topk_ids(self, num_tokens):
if self.balanced_tensors.get(num_tokens) is not None:
# logger.info(f"find balanced tensor for num_tokens={num_tokens}")
return self.balanced_tensors[num_tokens]
else:
# logger.info(f"generate balanced tensor for num_tokens={num_tokens}")
tensor = self.generate_balanced_tensor(num_tokens)
self.balanced_tensors[num_tokens] = tensor
return tensor
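
Because BalancedTensor is wrapped in singleton_threadsafe, every call site constructing it with the same arguments shares a single instance, and get_balance_topk_ids memoizes its result per num_tokens, so the O(num_tokens * num_experts) generation loop runs only once per batch size. A small sketch of that caching behaviour (hypothetical sizes; needs a CUDA device):

import torch

from lightllm.utils.balance_utils import BalancedTensor

# Identical constructor arguments resolve to the same cached singleton instance.
a = BalancedTensor(num_experts=16, num_selected=4)
b = BalancedTensor(num_experts=16, num_selected=4)
assert a is b

# Repeated requests for the same num_tokens return the cached tensor object,
# so the greedy generation loop is paid only on the first call per batch size.
ids_first = a.get_balance_topk_ids(32)
ids_again = a.get_balance_topk_ids(32)
assert ids_first is ids_again
assert ids_first.shape == (32, 4)
assert ids_first.dtype == torch.int32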