
Conversation

@STwangyingrui (Collaborator)

Add fake balance for EP mode, which is controlled by the environment variable EP_FAKE_BALANCE_ENABLED.

@gemini-code-assist bot (Contributor) left a comment


Summary of Changes

Hello @STwangyingrui, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a configurable "fake balance" mechanism for Expert Parallelism (EP) mode, primarily intended for debugging and testing the load distribution across Mixture-of-Experts (MoE) models. It allows developers to simulate a perfectly balanced expert load by overriding the expert selection process, controlled via an environment variable.

Highlights

  • Simulated Expert Load Balancing: Implements a "fake balance" feature for Expert Parallelism (EP) mode, enabling a controlled simulation of perfectly balanced expert distribution for debugging and testing purposes.
  • Environment Variable Control: The fake balance mechanism is activated by setting the EP_FAKE_BALANCE_ENABLED environment variable to "true", providing an easy toggle for this feature (a short usage sketch follows this list).
  • Deterministic Expert Assignment: Introduces a new BalancedTensor utility that generates and caches expert assignments designed to distribute tokens evenly among experts, overriding the standard top-k selection when enabled.
  • Debugging Instrumentation: Adds commented-out logging statements in the fused_experts_impl function to aid in debugging and observing token distribution during both prefill and decode phases, with a note about cudagraph compatibility.
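
As reviewed in this PR, toggling the feature only requires setting the variable before launch. A minimal Python sketch of the check the code performs (values shown are illustrative):

    import os

    # Set before the server process starts (equivalent to `export EP_FAKE_BALANCE_ENABLED=true`).
    os.environ["EP_FAKE_BALANCE_ENABLED"] = "true"

    # The MoE code path in this PR guards the fake-balance branch with exactly this string
    # comparison, so any value other than the literal "true" leaves normal top-k routing in place.
    fake_balance_enabled = os.environ.get("EP_FAKE_BALANCE_ENABLED") == "true"
    print(fake_balance_enabled)  # True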
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature                Command                Description
Code Review            /gemini review         Performs a code review for the current pull request in its current state.
Pull Request Summary   /gemini summary        Provides a summary of the current pull request in its current state.
Comment                @gemini-code-assist    Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                   /gemini help           Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces a 'fake balance' mode for expert parallelism, controlled by an environment variable. This is useful for debugging and performance analysis. My review has identified a few critical issues and areas for improvement:

  • There's a bug in the singleton_threadsafe decorator that would cause issues if BalancedTensor is used with different parameters across different MoE layers. I've also pointed out that BalancedTensor is being instantiated with hardcoded default values instead of the model's actual configuration.
  • The core logic for generating balanced tensors in generate_balanced_tensor is very inefficient and will likely be a performance bottleneck. I've suggested a much more performant alternative.
  • There are several instances of commented-out debug code that should be cleaned up for better maintainability.

Addressing these points will significantly improve the quality and correctness of this feature.

Comment on lines 231 to 237
# Switch for EP fake load balancing.
if os.environ.get("EP_FAKE_BALANCE_ENABLED") == "true":
    M, _ = hidden_states.shape
    balanced_tensor_collection = BalancedTensor()
    balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(M)
    topk_ids.copy_(balance_topk_ids)

critical

The BalancedTensor is instantiated with default num_experts and num_selected values. This is incorrect as different models or layers can have different numbers of experts and top-k values. This will lead to incorrect behavior. You should pass the correct num_experts and top_k from the current context. The number of experts can be obtained from router_logits.shape[1] and top_k is available as the top_k parameter.

Suggested change
-# Switch for EP fake load balancing.
-if os.environ.get("EP_FAKE_BALANCE_ENABLED") == "true":
-    M, _ = hidden_states.shape
-    balanced_tensor_collection = BalancedTensor()
-    balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(M)
-    topk_ids.copy_(balance_topk_ids)
+# Switch for EP fake load balancing.
+if os.environ.get("EP_FAKE_BALANCE_ENABLED") == "true":
+    M, _ = hidden_states.shape
+    num_experts = router_logits.shape[1]
+    balanced_tensor_collection = BalancedTensor(num_experts=num_experts, num_selected=top_k)
+    balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(M)
+    topk_ids.copy_(balance_topk_ids)

Comment on lines 6 to 24
def singleton_threadsafe(cls):
    instances = {}
    lock = threading.Lock()

    def get_instance(*args, **kwargs):
        with lock:
            if cls not in instances:
                instances[cls] = cls(*args, **kwargs)
            return instances[cls]
    return get_instance

critical

The current implementation of singleton_threadsafe creates a singleton per class, ignoring constructor arguments (*args, **kwargs) on subsequent calls. This will lead to incorrect behavior if BalancedTensor is instantiated with different num_experts or num_selected values, as it will always return the first-created instance with its initial parameters.

To fix this, the singleton's uniqueness key should be based on the class and its instantiation arguments.

Suggested change
-def singleton_threadsafe(cls):
-    instances = {}
-    lock = threading.Lock()
-
-    def get_instance(*args, **kwargs):
-        with lock:
-            if cls not in instances:
-                instances[cls] = cls(*args, **kwargs)
-            return instances[cls]
-    return get_instance
+def singleton_threadsafe(cls):
+    instances = {}
+    lock = threading.Lock()
+
+    def get_instance(*args, **kwargs):
+        # A key that includes the arguments is needed for parameter-dependent singletons.
+        # Using a tuple of args and a frozenset of kwargs items makes it hashable.
+        key = (cls, args, frozenset(kwargs.items()))
+        with lock:
+            if key not in instances:
+                instances[key] = cls(*args, **kwargs)
+            return instances[key]
+    return get_instance
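
For illustration, a standalone sketch (not code from this PR; BalancedTensor here is a stand-in with explicit constructor arguments) showing why the keyed lookup matters when two MoE layers use different shapes:

    import threading

    def singleton_threadsafe(cls):
        instances = {}
        lock = threading.Lock()

        def get_instance(*args, **kwargs):
            # Key on the class plus its arguments so differently-configured instances coexist.
            key = (cls, args, frozenset(kwargs.items()))
            with lock:
                if key not in instances:
                    instances[key] = cls(*args, **kwargs)
                return instances[key]
        return get_instance

    @singleton_threadsafe
    class BalancedTensor:
        def __init__(self, num_experts, num_selected):
            self.num_experts = num_experts
            self.num_selected = num_selected

    # Two layers with different expert counts now get two distinct cached instances;
    # the per-class key would have silently reused the first instance for both.
    a = BalancedTensor(num_experts=64, num_selected=8)
    b = BalancedTensor(num_experts=128, num_selected=6)
    assert a is not b and b.num_experts == 128
    # Repeated construction with the same arguments still returns the cached instance.
    assert BalancedTensor(num_experts=64, num_selected=8) is a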

Comment on lines 24 to 55
def generate_balanced_tensor(self, length):
    # Initialize a length * 8 all-zeros tensor and place it on the GPU
    tensor = torch.zeros((length, self.num_selected), dtype=torch.int, device='cuda')
    # Initialize the load counter for each expert
    expert_load = torch.zeros(self.num_experts, dtype=torch.int, device='cuda')

    for i in range(length):
        available_experts = torch.arange(self.num_experts, device='cuda')
        selected = []
        for _ in range(self.num_selected):
            # Compute the current load of each available expert
            current_load = expert_load[available_experts]
            # Pick the expert with the smallest load
            min_load_indices = torch.where(current_load == current_load.min())[0]
            if len(min_load_indices) > 1:
                # If several experts share the minimum load, pick one at random
                chosen_index = torch.randint(0, len(min_load_indices), (1,), device='cuda').item()
                chosen_expert_index = min_load_indices[chosen_index]
            else:
                chosen_expert_index = min_load_indices[0]
            chosen_expert = available_experts[chosen_expert_index]
            selected.append(chosen_expert)
            # Remove the chosen expert from the list of available experts
            available_experts = torch.cat(
                [available_experts[:chosen_expert_index], available_experts[chosen_expert_index + 1:]])
            # Update that expert's load
            expert_load[chosen_expert] += 1
        tensor[i] = torch.tensor(selected, dtype=torch.int, device='cuda')
    return tensor

high

The generate_balanced_tensor method is highly inefficient and can become a performance bottleneck. It uses nested Python loops and, most critically, torch.cat inside the inner loop to remove an element from available_experts. torch.cat creates a new tensor in every iteration, leading to significant overhead.

A more efficient, vectorized approach should be used. While the current implementation has a specific sequential greedy logic, a slightly different but much faster algorithm that still provides good load balancing would be preferable for this debugging/testing feature.

    def generate_balanced_tensor(self, length):
        # A more performant way to generate a balanced tensor for expert selection.
        tensor = torch.empty((length, self.num_selected), dtype=torch.int, device='cuda')
        expert_load = torch.zeros(self.num_experts, dtype=torch.int, device='cuda')
        
        expert_indices = torch.arange(self.num_experts, device='cuda')

        for i in range(length):
            # To break ties randomly when loads are equal, we can shuffle indices
            # of experts with the same load. A simple way is to shuffle all
            # indices and then sort by load.
            shuffled_indices = expert_indices[torch.randperm(self.num_experts, device='cuda')]
            sorted_shuffled_indices = shuffled_indices[torch.argsort(expert_load[shuffled_indices])]
            
            # Select the top `num_selected` experts with the lowest load
            selected_experts = sorted_shuffled_indices[:self.num_selected]
            
            tensor[i] = selected_experts
            
            # Update loads for the selected experts using an efficient scatter_add
            expert_load.scatter_add_(0, selected_experts, torch.ones_like(selected_experts, dtype=torch.int))
            
        return tensor
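
As a quick sanity check of either variant (a hedged sketch: the BalancedTensor import path is whatever this PR adds, the sizes are illustrative, and a CUDA device is assumed since the tensors are created on 'cuda'):

    import torch

    # Constructor arguments follow the review suggestion above; values are illustrative.
    balanced = BalancedTensor(num_experts=64, num_selected=8)
    topk_ids = balanced.get_balance_topk_ids(1024)  # (1024, 8) expert ids

    # Count how many tokens each expert receives and inspect the spread.
    counts = torch.bincount(topk_ids.flatten().to(torch.int64), minlength=64)
    print(counts.max().item() - counts.min().item())  # a balanced assignment should keep this at 0 or 1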

Comment on lines 146 to 148
# Important logging for debugging load balance
#rank=dist.get_rank()
#logger.info(f"prefill, [{rank}], all_tokens = {all_tokens}, num_recv_tokens_per_expert_list: {num_recv_tokens_per_expert_list}")

medium

This block contains commented-out debugging code. This adds clutter to the code and should be removed before merging. If this logging is valuable for future debugging, consider adding it behind a specific debug flag instead of leaving it commented out.
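
One way to keep the information available without commented-out code, as suggested here, is a small flag-gated helper. A minimal sketch (the environment variable name and helper are hypothetical, not part of this PR):

    import logging
    import os

    logger = logging.getLogger(__name__)

    # Hypothetical opt-in flag: the log line stays in the tree but is silent unless explicitly enabled.
    _EP_BALANCE_LOG_ENABLED = os.environ.get("EP_BALANCE_LOG_ENABLED") == "true"

    def maybe_log_expert_balance(phase: str, rank: int, all_tokens: int, per_expert) -> None:
        if _EP_BALANCE_LOG_ENABLED:
            logger.info(f"{phase}, [{rank}], all_tokens = {all_tokens}, per-expert tokens: {per_expert}")

The same caveat noted in the decode-path comment below still applies: with cudagraph capture enabled the call must be skipped, so the helper should only be invoked when --disable_cudagraph is set.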

Comment on lines 228 to 232
# Important logging for debugging load balance
# NOTE: when decoding graph is open, we can not call logger. Thus it can only be used when --disable_cudagraph
#rank=dist.get_rank()
#all_tokens = sum(masked_m)
#logger.info(f"decode, [{rank}], all_tokens = {all_tokens}, expected_m = {expected_m}, num_recv_tokens_per_expert: {masked_m}")

medium

This block contains commented-out debugging code. This should be removed to keep the codebase clean. The note about disable_cudagraph is useful, but this kind of debugging code is better kept out of the main branch.

@STwangyingrui STwangyingrui force-pushed the yr/ep_balance branch 2 times, most recently from 794a033 to d2141e6 Compare July 8, 2025 03:31
fix PR checks: faster balance algo, more robust balance management, from env control to option control, better logger info control, better format