Add EP fake balance #959
Conversation
Summary of Changes
Hello @STwangyingrui, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a configurable "fake balance" mechanism for Expert Parallelism (EP) mode, primarily intended for debugging and testing the load distribution across Mixture-of-Experts (MoE) models. It allows developers to simulate a perfectly balanced expert load by overriding the expert selection process, controlled via an environment variable.
Highlights
- Simulated Expert Load Balancing: Implements a "fake balance" feature for Expert Parallelism (EP) mode, enabling a controlled simulation of perfectly balanced expert distribution for debugging and testing purposes.
- Environment Variable Control: The fake balance mechanism is activated by setting the `EP_FAKE_BALANCE_ENABLED` environment variable to "true", providing an easy toggle for this feature (see the sketch after this list).
- Deterministic Expert Assignment: Introduces a new `BalancedTensor` utility that generates and caches expert assignments designed to distribute tokens evenly among experts, overriding the standard top-k selection when enabled.
- Debugging Instrumentation: Adds commented-out logging statements in the `fused_experts_impl` function to aid in debugging and observing token distribution during both prefill and decode phases, with a note about `cudagraph` compatibility.
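For reference, a minimal sketch of flipping the toggle described above from a test harness; `EP_FAKE_BALANCE_ENABLED` is the variable named in this PR, and the check is a strict string comparison against `"true"`:

```python
import os

# Hypothetical test-harness setup: the fused-experts path reads this variable
# and, when it equals exactly "true", overrides the router's top-k ids with
# BalancedTensor's evenly distributed assignments.
os.environ["EP_FAKE_BALANCE_ENABLED"] = "true"
```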
Code Review
This pull request introduces a 'fake balance' mode for expert parallelism, controlled by an environment variable. This is useful for debugging and performance analysis. My review has identified a few critical issues and areas for improvement:

- There's a bug in the `singleton_threadsafe` decorator that would cause issues if `BalancedTensor` is used with different parameters across different MoE layers. I've also pointed out that `BalancedTensor` is being instantiated with hardcoded default values instead of the model's actual configuration.
- The core logic for generating balanced tensors in `generate_balanced_tensor` is very inefficient and will likely be a performance bottleneck. I've suggested a much more performant alternative.
- There are several instances of commented-out debug code that should be cleaned up for better maintainability.

Addressing these points will significantly improve the quality and correctness of this feature.
```python
# Switch for EP fake load balancing.
if os.environ.get("EP_FAKE_BALANCE_ENABLED") == "true":
    M, _ = hidden_states.shape
    balanced_tensor_collection = BalancedTensor()
    balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(M)
    topk_ids.copy_(balance_topk_ids)
```
The `BalancedTensor` is instantiated with default `num_experts` and `num_selected` values. This is incorrect, as different models or layers can have different numbers of experts and top-k values, and will lead to incorrect behavior. You should pass the correct `num_experts` and `top_k` from the current context: the number of experts can be obtained from `router_logits.shape[1]`, and `top_k` is available as the `top_k` parameter.
Suggested change:

```diff
 # Switch for EP fake load balancing.
 if os.environ.get("EP_FAKE_BALANCE_ENABLED") == "true":
     M, _ = hidden_states.shape
-    balanced_tensor_collection = BalancedTensor()
+    num_experts = router_logits.shape[1]
+    balanced_tensor_collection = BalancedTensor(num_experts=num_experts, num_selected=top_k)
     balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(M)
     topk_ids.copy_(balance_topk_ids)
```
```python
def singleton_threadsafe(cls):
    instances = {}
    lock = threading.Lock()

    def get_instance(*args, **kwargs):
        with lock:
            if cls not in instances:
                instances[cls] = cls(*args, **kwargs)
            return instances[cls]
    return get_instance
```
The current implementation of `singleton_threadsafe` creates a singleton per class, ignoring constructor arguments (`*args`, `**kwargs`) on subsequent calls. This will lead to incorrect behavior if `BalancedTensor` is instantiated with different `num_experts` or `num_selected` values, as it will always return the first-created instance with its initial parameters.

To fix this, the singleton's uniqueness key should be based on the class and its instantiation arguments.
Suggested change:

```diff
 def singleton_threadsafe(cls):
     instances = {}
     lock = threading.Lock()

     def get_instance(*args, **kwargs):
+        # A key that includes the arguments is needed for parameter-dependent singletons.
+        # Using a tuple of args and a frozenset of kwargs items makes it hashable.
+        key = (cls, args, frozenset(kwargs.items()))
         with lock:
-            if cls not in instances:
-                instances[cls] = cls(*args, **kwargs)
-            return instances[cls]
+            if key not in instances:
+                instances[key] = cls(*args, **kwargs)
+            return instances[key]
     return get_instance
```
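As a quick illustration of the fix, a minimal self-contained sketch (the `BalancedTensor` body here is a hypothetical stand-in, used only to show the caching behavior):

```python
import threading

def singleton_threadsafe(cls):
    # Keyed variant from the suggestion above: one instance per (class, arguments) combination.
    instances = {}
    lock = threading.Lock()

    def get_instance(*args, **kwargs):
        key = (cls, args, frozenset(kwargs.items()))
        with lock:
            if key not in instances:
                instances[key] = cls(*args, **kwargs)
        return instances[key]
    return get_instance

@singleton_threadsafe
class BalancedTensor:
    # Hypothetical minimal stand-in; the real class also builds and caches balanced ids.
    def __init__(self, num_experts=256, num_selected=8):
        self.num_experts = num_experts
        self.num_selected = num_selected

a = BalancedTensor(num_experts=64, num_selected=8)
b = BalancedTensor(num_experts=128, num_selected=4)
c = BalancedTensor(num_experts=64, num_selected=8)
assert a is not b  # different parameters now yield different instances
assert a is c      # identical parameters return the cached instance
```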
```python
def generate_balanced_tensor(self, length):
    # Initialize a length x num_selected all-zero tensor on the GPU
    tensor = torch.zeros((length, self.num_selected), dtype=torch.int, device='cuda')
    # Initialize the load counter for each expert
    expert_load = torch.zeros(self.num_experts, dtype=torch.int, device='cuda')

    for i in range(length):
        available_experts = torch.arange(self.num_experts, device='cuda')
        selected = []
        for _ in range(self.num_selected):
            # Compute the current load of each available expert
            current_load = expert_load[available_experts]
            # Pick the expert(s) with the minimum load
            min_load_indices = torch.where(current_load == current_load.min())[0]
            if len(min_load_indices) > 1:
                # If several experts share the minimum load, pick one at random
                chosen_index = torch.randint(0, len(min_load_indices), (1,), device='cuda').item()
                chosen_expert_index = min_load_indices[chosen_index]
            else:
                chosen_expert_index = min_load_indices[0]
            chosen_expert = available_experts[chosen_expert_index]
            selected.append(chosen_expert)
            # Remove the chosen expert from the list of available experts
            available_experts = torch.cat(
                [available_experts[:chosen_expert_index], available_experts[chosen_expert_index + 1:]])
            # Update that expert's load
            expert_load[chosen_expert] += 1
        tensor[i] = torch.tensor(selected, dtype=torch.int, device='cuda')
    return tensor
```
The `generate_balanced_tensor` method is highly inefficient and can become a performance bottleneck. It uses nested Python loops and, most critically, `torch.cat` inside the inner loop to remove an element from `available_experts`. `torch.cat` creates a new tensor in every iteration, leading to significant overhead.

A more efficient, vectorized approach should be used. While the current implementation has a specific sequential greedy logic, a slightly different but much faster algorithm that still provides good load balancing would be preferable for this debugging/testing feature.
```python
def generate_balanced_tensor(self, length):
    # A more performant way to generate a balanced tensor for expert selection.
    tensor = torch.empty((length, self.num_selected), dtype=torch.int, device='cuda')
    expert_load = torch.zeros(self.num_experts, dtype=torch.int, device='cuda')
    expert_indices = torch.arange(self.num_experts, device='cuda')
    for i in range(length):
        # To break ties randomly when loads are equal, we can shuffle indices
        # of experts with the same load. A simple way is to shuffle all
        # indices and then sort by load.
        shuffled_indices = expert_indices[torch.randperm(self.num_experts, device='cuda')]
        sorted_shuffled_indices = shuffled_indices[torch.argsort(expert_load[shuffled_indices])]
        # Select the top `num_selected` experts with the lowest load
        selected_experts = sorted_shuffled_indices[:self.num_selected]
        tensor[i] = selected_experts
        # Update loads for the selected experts using an efficient scatter_add
        expert_load.scatter_add_(0, selected_experts, torch.ones_like(selected_experts, dtype=torch.int))
    return tensor
```
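To verify the balancing property of either generator, one could bin-count the produced ids. This sketch assumes a CUDA device and a `BalancedTensor` that accepts `num_experts`/`num_selected` keywords, as in the suggestion above:

```python
import torch

# Hypothetical sanity check (requires a CUDA device, per the implementation above).
bt = BalancedTensor(num_experts=8, num_selected=2)
ids = bt.generate_balanced_tensor(1000)                    # shape (1000, 2)
counts = torch.bincount(ids.flatten().long(), minlength=8)
# 1000 rows * 2 selections over 8 experts -> ~250 tokens per expert; both the
# greedy and the sorted variants keep per-expert loads within 1 of each other.
assert counts.max() - counts.min() <= 1
```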
```python
# Important logs for debugging load balancing
#rank=dist.get_rank()
#logger.info(f"prefill, [{rank}], all_tokens = {all_tokens}, num_recv_tokens_per_expert_list: {num_recv_tokens_per_expert_list}")
```
```python
# Important logs for debugging load balancing
# NOTE: when the decode graph is enabled, we cannot call the logger, so this can only be used with --disable_cudagraph
#rank=dist.get_rank()
#all_tokens = sum(masked_m)
#logger.info(f"decode, [{rank}], all_tokens = {all_tokens}, expected_m = {expected_m}, num_recv_tokens_per_expert: {masked_m}")
```
Force-pushed from 794a033 to d2141e6
fix PR checks: faster balance algo, more robust balance management, from env control to option control, better logger info control, better format
Add fake balance for EP mode, controlled by the EP_FAKE_BALANCE_ENABLED environment variable.