-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprefetch_rebalance.py.v7
More file actions
63 lines (59 loc) · 2.45 KB
/
prefetch_rebalance.py.v7
File metadata and controls
63 lines (59 loc) · 2.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from sglang.load_balance.prefetch_rebalance_cython import prefetch_impl, rebalance_impl
def prefetch(
    next_topk_ids: torch.Tensor,  # Shape: (num_gpus * seq_len, experts_per_token) - predicted expert IDs for each token
    num_experts: int,
    seq_len: int,
    num_gpus: int,
    cur_rank: int,
    send_map_local_out: torch.Tensor,
    send_map_local_len_out: torch.Tensor,
    recv_map_local_len_out: torch.Tensor,
    max_dup_per_gpu: int = 6,  # (N) maximum number of duplicated experts per GPU
    n_local_dup_per_gpu: int = 4,  # (L) local hot experts to duplicate
    max_send_len: int = 6,  # (P) maximum length in send map
    force_max_dup: bool = False,  # force each GPU to have max_dup_per_gpu experts
    verbose: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Thin wrapper over the Cython ``prefetch_impl`` kernel.

    Reshapes the predicted top-k expert IDs to a per-GPU layout, hands
    everything to the Cython implementation as NumPy views (the ``*_out``
    tensors are written in place through those views), and converts the
    three returned arrays back into fresh torch tensors.

    NOTE(review): ``.numpy()`` requires all tensor arguments to live on CPU
    and (for the ``*_out`` buffers) to be contiguous — presumably guaranteed
    by the caller; verify.  ``verbose`` is accepted but not forwarded to the
    Cython call.
    """
    # Per-GPU view of the predictions: (num_gpus, seq_len, experts_per_token).
    ids_by_gpu = next_topk_ids.reshape(num_gpus, seq_len, -1).numpy()

    mapping, remote_load, recv_experts = prefetch_impl(
        ids_by_gpu,
        num_experts,
        seq_len,
        num_gpus,
        cur_rank,
        send_map_local_out.numpy(),  # written in place by the impl
        send_map_local_len_out.numpy(),  # written in place by the impl
        recv_map_local_len_out.numpy(),  # written in place by the impl
        max_dup_per_gpu,
        n_local_dup_per_gpu,
        max_send_len,
        force_max_dup,
    )

    # torch.tensor(...) copies, so the results are detached from the
    # impl-owned buffers.
    return (
        torch.tensor(mapping),
        torch.tensor(remote_load),
        torch.tensor(recv_experts),
    )
def rebalance(
    topk_ids: torch.Tensor,  # Shape: (num_gpus * seq_len, experts_per_token) - true expert IDs for each token
    expert_gpu_mapping: torch.Tensor,  # Shape: (num_experts, num_gpus) - expert placement matrix
    init_remote_gpu_expert_load: torch.Tensor,
    num_experts: int,
    seq_len: int,
    num_gpus: int,
    remote_gpu_expert_info_out: torch.Tensor,
    remote_gpu_load_out: torch.Tensor,
    remote_gpu_expert_summary_out: torch.Tensor,
    min_dup_experts_tokens: int = 0,  # (Q) minimum limit of tokens that duplicated experts process on a GPU
    verbose: bool = False,
) -> None:
    """Thin wrapper over the Cython ``rebalance_impl`` kernel.

    Reshapes the true top-k expert IDs to a per-GPU layout and invokes the
    Cython implementation.  Results are delivered solely through the three
    ``*_out`` tensors, which the impl fills in place via their NumPy views;
    nothing is returned.

    NOTE(review): ``.numpy()`` requires every tensor argument to be on CPU —
    presumably guaranteed by the caller; verify.  ``verbose`` is accepted
    but not forwarded to the Cython call.
    """
    # Per-GPU view of the ground-truth routing: (num_gpus, seq_len, experts_per_token).
    ids_by_gpu = topk_ids.reshape(num_gpus, seq_len, -1).numpy()

    rebalance_impl(
        ids_by_gpu,
        expert_gpu_mapping.numpy(),
        init_remote_gpu_expert_load.numpy(),
        num_experts,
        seq_len,
        num_gpus,
        remote_gpu_expert_info_out.numpy(),  # written in place by the impl
        remote_gpu_load_out.numpy(),  # written in place by the impl
        remote_gpu_expert_summary_out.numpy(),  # written in place by the impl
        min_dup_experts_tokens,
    )