|
import functools
import os
import threading

import torch
def singleton_threadsafe(cls):
    """Class decorator: make *cls* a thread-safe singleton.

    The first call to the decorated name constructs the instance with that
    call's arguments; every later call returns the same instance and
    silently ignores its arguments.
    """
    instances = {}
    lock = threading.Lock()

    # wraps() preserves the wrapped class's __name__/__qualname__/__doc__ on
    # the factory; updated=() skips copying the class __dict__ onto the
    # function, which would be meaningless here.
    @functools.wraps(cls, updated=())
    def get_instance(*args, **kwargs):
        # The lock both guards the dict and prevents two threads from racing
        # to construct the first instance.
        with lock:
            if cls not in instances:
                instances[cls] = cls(*args, **kwargs)
            return instances[cls]

    return get_instance
| 16 | + |
| 17 | +@singleton_threadsafe |
class BalancedTensor:
    """Cache of artificially load-balanced top-k expert-id tensors.

    Each cached tensor has shape ``(length, num_selected)``; row ``i`` holds
    ``num_selected`` distinct expert ids chosen greedily so the accumulated
    load per expert stays as even as possible, with ties broken uniformly at
    random. Intended to be wrapped by ``singleton_threadsafe`` so every
    caller shares one cache.

    NOTE(review): the per-length cache itself is not lock-protected; two
    threads first requesting the same length may each generate a tensor
    (last writer wins). Confirm this is acceptable for the call sites.
    """

    def __init__(self, num_experts=256, num_selected=8, device='cuda'):
        # length -> (length, num_selected) int tensor, filled lazily.
        self.balanced_tensors = {}
        self.num_experts = num_experts
        self.num_selected = num_selected
        # Device the tensors are created on; parameterized (default keeps
        # the original 'cuda' behavior) so CPU-only environments work too.
        self.device = device

    def generate_balanced_tensor(self, length):
        """Build a ``(length, num_selected)`` tensor of balanced expert ids.

        Greedy algorithm: for every row, repeatedly pick the expert with the
        lowest accumulated load among the experts not yet used in this row,
        breaking ties uniformly at random.
        """
        tensor = torch.zeros((length, self.num_selected), dtype=torch.int,
                             device=self.device)
        # Running selection count per expert, accumulated across all rows.
        expert_load = torch.zeros(self.num_experts, dtype=torch.int,
                                  device=self.device)

        for i in range(length):
            available_experts = torch.arange(self.num_experts,
                                             device=self.device)
            selected = []
            for _ in range(self.num_selected):
                current_load = expert_load[available_experts]
                # Candidates tied at minimal load; randint covers both the
                # tie case and the single-candidate case (randint(0, 1)
                # always returns 0), so no separate branch is needed.
                min_load_indices = torch.where(
                    current_load == current_load.min())[0]
                pick = torch.randint(0, len(min_load_indices), (1,),
                                     device=self.device).item()
                chosen_expert_index = min_load_indices[pick]
                chosen_expert = available_experts[chosen_expert_index]
                selected.append(chosen_expert)
                # Drop the chosen expert so a row never repeats an id.
                available_experts = torch.cat(
                    [available_experts[:chosen_expert_index],
                     available_experts[chosen_expert_index + 1:]])
                expert_load[chosen_expert] += 1
            # stack() instead of torch.tensor(list-of-tensors): the latter
            # copies through Python scalars and emits a UserWarning.
            tensor[i] = torch.stack(selected).to(torch.int)
        return tensor

    def get_balance_topk_ids(self, length):
        """Return the cached balanced tensor for *length*, generating it once."""
        cached = self.balanced_tensors.get(length)
        if cached is None:
            cached = self.generate_balanced_tensor(length)
            self.balanced_tensors[length] = cached
        return cached
| 63 | + |
0 commit comments