Skip to content

Commit 11b6ade

Browse files
author
root
committed
Use a faster balancing algorithm
1 parent 66c0f2d commit 11b6ade

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

lightllm/utils/balance_utils.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,31 +34,27 @@ def __init__(self, num_experts=256, num_selected=8):
3434
def generate_balanced_tensor(self, num_tokens, device="cuda"):
    """Evenly distribute ``num_tokens`` tokens across experts.

    For each token, greedily pick ``self.num_selected`` distinct experts out of
    ``self.num_experts``: each slot takes a currently least-loaded expert the
    token has not already picked, breaking ties uniformly at random. This keeps
    the per-expert load nearly uniform across all tokens.

    Args:
        num_tokens: Number of token rows to generate.
        device: Torch device for all tensors. Defaults to ``"cuda"`` to keep
            the original behavior; pass ``"cpu"`` for host-side use/testing.

    Returns:
        A ``torch.int`` tensor of shape ``(num_tokens, self.num_selected)`` on
        ``device``; row ``i`` holds the expert indices chosen for token ``i``,
        all distinct within the row.
    """
    tensor = torch.empty((num_tokens, self.num_selected), dtype=torch.int, device=device)
    expert_load = torch.zeros(self.num_experts, dtype=torch.int, device=device)
    # Loop-invariant sentinel: "infinite" load used to mask out experts the
    # current token has already selected (hoisted out of the loops).
    int_max = torch.iinfo(expert_load.dtype).max

    for i in range(num_tokens):
        # Experts already chosen by token i — they must not repeat in the row.
        selected_mask = torch.zeros(self.num_experts, dtype=torch.bool, device=device)
        for j in range(self.num_selected):
            # Already-selected experts get sentinel load so min() skips them.
            load_view = torch.where(selected_mask, int_max, expert_load)
            min_load_indices = torch.where(load_view == load_view.min())[0]

            if len(min_load_indices) > 1:
                # Multiple equally least-loaded experts: break the tie randomly.
                rand_idx = torch.randint(0, len(min_load_indices), (1,), device=device).item()
                chosen_expert = min_load_indices[rand_idx]
            else:
                chosen_expert = min_load_indices[0]

            tensor[i, j] = chosen_expert
            expert_load[chosen_expert] += 1
            selected_mask[chosen_expert] = True

    return tensor
6460

0 commit comments

Comments
 (0)