
Commit 743ddc3

fix moe_align1 kernel performance issue in prefill stage. (#718)
old version (test_groped_fused_moe_speed.py):
token_num: 200, cost time: 0.0011632442474365234 s
token_num: 256, cost time: 0.0011243820198429688 s
token_num: 8192, cost time: 0.05202174186706543 s

new version (test_groped_fused_moe_speed.py):
token_num: 200, cost time: 0.0011744499206542969 s
token_num: 256, cost time: 0.0010919570922851562 s
token_num: 8192, cost time: 0.003216266632080078 s

At 8192 tokens the new kernel is over 10x faster (~16x by the numbers above).
1 parent b8cfd70 commit 743ddc3
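
A quick back-of-envelope check on the quoted times (as printed by the speed test below):

# Times quoted in the commit message, seconds per timed loop.
old_8192 = 0.05202174186706543
new_8192 = 0.003216266632080078
print(f"speedup at 8192 tokens: {old_8192 / new_8192:.1f}x")  # ~16x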

File tree

3 files changed: +89 −26 lines changed

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 33 additions & 20 deletions
@@ -114,29 +114,37 @@ def moe_align1_kernel(
     experts_topk_weight,  # [expert_num, token_num * topk_num]
     experts_topk_weight_stride0,
     experts_topk_weight_stride1,
-    TOKEN_BLOCK_N: tl.constexpr,
+    TOKEN_BLOCK_SIZE: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
 ):
 
     expert_id = tl.program_id(axis=0)
-    n_range = tl.arange(0, TOKEN_BLOCK_N)
 
-    topk_weights_data = tl.load(topk_weights + n_range, mask=n_range < experts_info_n, other=0)
-    expert_data = tl.load(
-        experts_info_ptr + expert_id * experts_info_stride0 + n_range, mask=n_range < experts_info_n, other=0
-    )
-    cumsum_expert_data = tl.cumsum(expert_data)
+    off_n = tl.arange(0, TOKEN_BLOCK_SIZE)
 
-    tl.store(expert_token_num_ptr + expert_id, tl.max(cumsum_expert_data))
-    tl.store(
-        experts_info_ptr + expert_id * experts_info_stride0 + cumsum_expert_data - 1,
-        n_range,
-        mask=(expert_data == 1) & (n_range < experts_info_n),
-    )
-    tl.store(
-        experts_topk_weight + expert_id * experts_topk_weight_stride0 + cumsum_expert_data - 1,
-        topk_weights_data,
-        mask=(expert_data == 1) & (n_range < experts_info_n),
-    )
+    pre_sum = 0
+
+    for start_loc in tl.range(0, experts_info_n, TOKEN_BLOCK_SIZE, num_stages=NUM_STAGE):
+        n_range = start_loc + off_n
+        topk_weights_data = tl.load(topk_weights + n_range, mask=n_range < experts_info_n, other=0)
+        expert_data = tl.load(
+            experts_info_ptr + expert_id * experts_info_stride0 + n_range, mask=n_range < experts_info_n, other=0
+        )
+        cumsum_expert_data = tl.cumsum(expert_data) + pre_sum
+        pre_sum = tl.max(cumsum_expert_data)
+        tl.store(
+            experts_info_ptr + expert_id * experts_info_stride0 + cumsum_expert_data - 1,
+            n_range,
+            mask=(expert_data == 1) & (n_range < experts_info_n),
+        )
+        tl.store(
+            experts_topk_weight + expert_id * experts_topk_weight_stride0 + cumsum_expert_data - 1,
+            topk_weights_data,
+            mask=(expert_data == 1) & (n_range < experts_info_n),
+        )
+
+    tl.store(expert_token_num_ptr + expert_id, pre_sum)
+    return
 
 
 def moe_align1(
@@ -184,7 +192,11 @@ def moe_align1(
     assert token_num_mul_topk <= FFN_MOE_CHUNK_SIZE * topk_num, "need split to handle seq len too long"
     assert exports_token_num.shape[0] == expert_num
     assert topk_weights.is_contiguous()
-    TOKEN_BLOCK_N = triton.next_power_of_2(token_num_mul_topk)
+    if token_num_mul_topk <= 512:
+        TOKEN_BLOCK_SIZE = 256
+    else:
+        TOKEN_BLOCK_SIZE = 512 if token_num_mul_topk <= 4 * 1024 else 2048
+
     grid = (expert_num,)
     moe_align1_kernel[grid](
         experts_info,
@@ -197,7 +209,8 @@ def moe_align1(
         experts_weight_info,
         experts_weight_info.stride(0),
         experts_weight_info.stride(1),
-        TOKEN_BLOCK_N=TOKEN_BLOCK_N,
+        TOKEN_BLOCK_SIZE=TOKEN_BLOCK_SIZE,
+        NUM_STAGE=4,
        num_warps=8,
        num_stages=1,
    )
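
The shape of the fix: the old kernel built one tl.arange block of width triton.next_power_of_2(token_num * topk_num) and did a single tl.cumsum over it, which becomes very expensive in prefill where that width can reach tens of thousands of elements. The new kernel walks the same row in fixed TOKEN_BLOCK_SIZE chunks and carries the running total in pre_sum. A minimal PyTorch sketch of that blocked-cumsum idea (the helper name blocked_cumsum is ours, not part of the repo):

import torch

def blocked_cumsum(x: torch.Tensor, block: int = 512) -> torch.Tensor:
    # Cumulative sum computed chunk by chunk, carrying the running prefix
    # across iterations, mirroring how the updated kernel threads pre_sum
    # through its tl.range loop.
    out = torch.empty_like(x)
    pre_sum = 0
    for start in range(0, x.numel(), block):
        chunk = x[start:start + block]
        csum = torch.cumsum(chunk, dim=0) + pre_sum  # local cumsum + carried prefix
        out[start:start + block] = csum
        pre_sum = csum[-1].item()  # kernel uses tl.max(cumsum); same value for 0/1 data
    return out

x = torch.randint(0, 2, (8192,))
assert torch.equal(blocked_cumsum(x), torch.cumsum(x, dim=0))

The block-size ladder in moe_align1 (256 / 512 / 2048) keeps small batches to one or two loop iterations while capping the per-iteration footprint for long prefill sequences.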

lightllm/common/fused_moe/test_groped_fused_moe.py renamed to unit_tests/common/fused_moe/test_groped_fused_moe.py

Lines changed: 5 additions & 6 deletions
@@ -1,17 +1,18 @@
 import torch
 import time
-from .grouped_fused_moe import moe_align, moe_align1, grouped_matmul
+import pytest
+from lightllm.common.fused_moe.grouped_fused_moe import moe_align, moe_align1, grouped_matmul
 from lightllm.utils.log_utils import init_logger
 
+logger = init_logger(__name__)
+
 seed = 42
 torch.manual_seed(seed)
 
 if torch.cuda.is_available():
     torch.cuda.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
 
-logger = init_logger(__name__)
-
 
 def test_moe_align():
     expert_num = 5
@@ -137,6 +138,4 @@ def test_grouped_matmul():
 
 
 if __name__ == "__main__":
-    test_moe_align()
-    test_moe_align1()
-    test_grouped_matmul()
+    pytest.main()
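
With the test relocated under unit_tests/ and the entry point switched to pytest, a direct invocation might look like this (an illustrative sketch, assuming pytest is installed and the working directory is the repo root):

import pytest

# Collect and run only the relocated test file, verbosely.
pytest.main(["unit_tests/common/fused_moe/test_groped_fused_moe.py", "-v"])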
New file · Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import torch
+import time
+import pytest
+from lightllm.common.fused_moe.grouped_fused_moe import moe_align, moe_align1, grouped_matmul
+from lightllm.utils.log_utils import init_logger
+
+seed = 42
+torch.manual_seed(seed)
+
+if torch.cuda.is_available():
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+logger = init_logger(__name__)
+
+
+@pytest.mark.parametrize("token_num", [200, 256, 8 * 1024])
+def test_moe_align1(token_num):
+    expert_num = 160
+    topk_num = 6
+    print(token_num)
+
+    def get_one():
+        rnd_logics = torch.randn(token_num, expert_num, device="cuda")
+        topk_values, topk_ids = torch.topk(rnd_logics, topk_num, dim=1)
+
+        experts_info = torch.zeros((expert_num, token_num * topk_num), dtype=torch.int32, device="cuda")
+        experts_info.fill_(0)
+        moe_align(topk_ids, experts_info)
+
+        topk_weights = torch.randn((token_num, topk_num), dtype=torch.float32, device="cuda")
+        experts_token_num = torch.zeros((expert_num,), dtype=torch.int32, device="cuda")
+        experts_weights = torch.zeros(experts_info.shape, dtype=torch.float32, device="cuda")
+        return experts_info, topk_weights, experts_weights, experts_token_num
+
+    test_datas = [get_one() for _ in range(100)]
+
+    moe_align1(*test_datas[0], topk_num)
+
+    torch.cuda.synchronize()
+    start = time.time()
+
+    for i in range(60):
+        moe_align1(*test_datas[i + 1], topk_num)
+    torch.cuda.synchronize()
+
+    print(f"token_num: {token_num} cost time: {time.time() - start} s")
+
+
+if __name__ == "__main__":
+    pytest.main()
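
The speed test builds 100 input sets up front and does one warmup call, so allocation and Triton JIT compilation stay outside the timed region. If sub-millisecond precision matters, CUDA events are a common alternative to host wall-clock timing; a sketch under that assumption (the time_cuda helper is illustrative, not part of this commit):

import torch

def time_cuda(fn, iters: int = 60) -> float:
    # Measure GPU time with CUDA events instead of time.time().
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0  # elapsed_time() returns milliseconds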
