
Commit 9ae19d8

fix
1 parent c6ce96c commit 9ae19d8

File tree: 1 file changed, +26 −26 lines


test/kernel/fuse_moe_tuning.py

Lines changed: 26 additions & 26 deletions
@@ -62,16 +62,16 @@ def test_kernel(
     use_fp8_w8a8: bool,
     is_up: bool,
     block_shape,
-    num_fused_experts: int,
+    num_fused_shared_experts: int,
     **config,
 ):
     set_seed()
     input_tuples = []
 
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1_scale = w2_scale = None
-    if num_fused_experts > 0:
-        expert_num += num_fused_experts
+    if num_fused_shared_experts > 0:
+        expert_num += num_fused_shared_experts
 
     if use_fp8_w8a8:
         init_dtype = dtype
@@ -95,28 +95,28 @@ def test_kernel(
         w1 = torch.randn(expert_num, 2 * n, k, dtype=dtype).cuda()
         w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=dtype).cuda()
 
-    rnd_logics = torch.randn(m, expert_num - num_fused_experts, device="cuda")
+    rnd_logics = torch.randn(m, expert_num - num_fused_shared_experts, device="cuda")
     topk_values, topk_ids = torch.topk(rnd_logics, topk, dim=1)
-    topk_weights = torch.randn((m, topk + num_fused_experts), device="cuda", dtype=dtype) / 10
-
-    if num_fused_experts > 0:
+    if num_fused_shared_experts > 0:
+        # When fused shared experts are present, the shared experts' ids must be padded into topk_ids.
         pad_topk_ids = torch.arange(
-            start=expert_num - num_fused_experts,
+            start=expert_num - num_fused_shared_experts,
             end=expert_num,
             step=1,
             dtype=topk_ids.dtype,
-            device="cuda").view(1, num_fused_experts).repeat(topk_ids.shape[0], 1)
+            device="cuda").view(1, num_fused_shared_experts).repeat(topk_ids.shape[0], 1)
         topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)
+    topk_weights = torch.randn((m, topk + num_fused_shared_experts), device="cuda", dtype=dtype) / 10
 
-    expert_to_tokens = torch.empty((expert_num, (topk + num_fused_experts) * m), dtype=torch.int32, device="cuda")
-    expert_to_weights = torch.empty((expert_num, (topk + num_fused_experts) * m), dtype=torch.float32, device="cuda")
+    expert_to_tokens = torch.empty((expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.int32, device="cuda")
+    expert_to_weights = torch.empty((expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.float32, device="cuda")
     moe_align(topk_ids=topk_ids, out=expert_to_tokens)
     expert_to_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda")
-    moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + num_fused_experts)
+    moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + num_fused_shared_experts)
 
-    out1 = torch.zeros((m * (topk + num_fused_experts), 2 * n), dtype=torch.bfloat16, device="cuda")
-    down_in = torch.zeros((m * (topk + num_fused_experts), n), dtype=torch.bfloat16, device="cuda")
-    out2 = torch.zeros((m * (topk + num_fused_experts), k), dtype=torch.bfloat16, device="cuda")
+    out1 = torch.zeros((m * (topk + num_fused_shared_experts), 2 * n), dtype=torch.bfloat16, device="cuda")
+    down_in = torch.zeros((m * (topk + num_fused_shared_experts), n), dtype=torch.bfloat16, device="cuda")
+    out2 = torch.zeros((m * (topk + num_fused_shared_experts), k), dtype=torch.bfloat16, device="cuda")
 
     for _ in range(test_count):
         input_tuples.append(
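Note (not part of the diff): the hunk above now pads the fused shared experts' ids onto every token's top-k list before sampling topk_weights, so the downstream buffers are sized for m * (topk + num_fused_shared_experts) rows. A minimal CPU-only sketch of that padding, using made-up sizes (4 tokens, 8 routed experts, top-2 routing, 1 fused shared expert):

import torch

# Hypothetical sizes for illustration only.
m, routed_expert_num, topk, num_fused_shared_experts = 4, 8, 2, 1
expert_num = routed_expert_num + num_fused_shared_experts  # 9

# Routing logits only cover the routed experts, as in the updated test.
rnd_logics = torch.randn(m, expert_num - num_fused_shared_experts)
_, topk_ids = torch.topk(rnd_logics, topk, dim=1)  # shape (4, 2)

# Pad the shared-expert ids (here just id 8) onto every token's id list.
pad_topk_ids = torch.arange(
    start=expert_num - num_fused_shared_experts,
    end=expert_num,
    dtype=topk_ids.dtype,
).view(1, num_fused_shared_experts).repeat(topk_ids.shape[0], 1)  # shape (4, 1)
topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)  # shape (4, 3)

# Weights are sampled after padding, so they already cover topk + shared slots.
topk_weights = torch.randn(m, topk + num_fused_shared_experts)
print(topk_ids.shape, topk_weights.shape)  # torch.Size([4, 3]) torch.Size([4, 3])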
@@ -230,7 +230,7 @@ def worker(
     use_fp8_w8a8: bool,
     is_up: bool,
     block_shape,
-    num_fused_experts: int,
+    num_fused_shared_experts: int,
     test_configs,
     queue,
 ):
@@ -247,7 +247,7 @@ def worker(
                 use_fp8_w8a8=use_fp8_w8a8,
                 is_up=is_up,
                 block_shape=block_shape,
-                num_fused_experts=num_fused_experts,
+                num_fused_shared_experts=num_fused_shared_experts,
                 **test_configs[index],
             )
             queue.put(cost_time)  # Put result in queue
@@ -280,7 +280,7 @@ def get_test_configs(split_id, split_count):
         4,
         8,
     ]:
-        for BLOCK_SIZE_M in [32, 64, 128]:
+        for BLOCK_SIZE_M in [16, 32, 64, 128]:
             for BLOCK_SIZE_N in [32, 64, 128]:
                 for BLOCK_SIZE_K in [32, 64, 128]:
                     t_config = {
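Note (not part of the diff): the change above only widens the tile search space by also trying BLOCK_SIZE_M = 16. A rough sketch of the resulting tile grid, ignoring the outer tuning loops (e.g. the 4/8 values visible above the hunk) that this hunk does not show:

from itertools import product

# Candidate tile sizes after this change; BLOCK_SIZE_M now also tries 16.
block_sizes_m = [16, 32, 64, 128]
block_sizes_n = [32, 64, 128]
block_sizes_k = [32, 64, 128]

grid = [
    {"BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": bn, "BLOCK_SIZE_K": bk}
    for bm, bn, bk in product(block_sizes_m, block_sizes_n, block_sizes_k)
]
print(len(grid))  # 36 tile combinations per outer-loop setting, up from 27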
@@ -311,7 +311,7 @@ def tuning_configs(
     use_fp8_w8a8: bool,
     is_up: bool,
     block_shape,
-    num_fused_experts: int,
+    num_fused_shared_experts: int,
 ):
     os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
     best_config, best_cost_time = None, 10000000
@@ -335,7 +335,7 @@ def tuning_configs(
                 use_fp8_w8a8,
                 is_up,
                 block_shape,
-                num_fused_experts,
+                num_fused_shared_experts,
                 test_configs,
                 queue,
             ),
@@ -370,7 +370,7 @@ def tuning_configs(
                 use_fp8_w8a8,
                 is_up,
                 block_shape,
-                num_fused_experts,
+                num_fused_shared_experts,
                 test_configs,
                 queue,
             ),
@@ -437,15 +437,15 @@ def main(args):
                 "use_fp8_w8a8": use_fp8_w8a8,
                 "is_up": True,
                 "block_shape": block_shape,
-                "num_fused_experts": args.num_fused_experts,
+                "num_fused_shared_experts": args.num_fused_shared_experts,
             },
         )
         up_dict[m] = ans
     MoeGroupedGemmKernelConfig.save_config(
         N=n * 2,
         K=hidden_dim,
         topk_num=topk_num,
-        expert_num=expert_num + args.num_fused_experts,
+        expert_num=expert_num + args.num_fused_shared_experts,
         mul_routed_weight=False,
         use_fp8_w8a8=use_fp8_w8a8,
         out_dtype=str(torch.bfloat16),
@@ -467,7 +467,7 @@ def main(args):
                 "use_fp8_w8a8": use_fp8_w8a8,
                 "is_up": False,
                 "block_shape": block_shape,
-                "num_fused_experts": args.num_fused_experts,
+                "num_fused_shared_experts": args.num_fused_shared_experts,
             },
         )
         down_dict[m] = ans
@@ -476,7 +476,7 @@ def main(args):
         N=hidden_dim,
         K=n,
         topk_num=1,
-        expert_num=expert_num + args.num_fused_experts,
+        expert_num=expert_num + args.num_fused_shared_experts,
         mul_routed_weight=True,
         use_fp8_w8a8=use_fp8_w8a8,
         out_dtype=str(torch.bfloat16),
@@ -489,6 +489,6 @@ def main(args):
     parser.add_argument("--model_dir", type=str, default="deepseek-ai/DeepSeek-R1")
     parser.add_argument("--tp", type=int, default=8)
     parser.add_argument("--use_fp8_w8a8", action="store_true")
-    parser.add_argument("--num_fused_experts", type=int, default=0)
+    parser.add_argument("--num_fused_shared_experts", type=int, default=0)
     args = parser.parse_args()
     main(args)
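Note (not part of the diff): with the renamed flag, a tuning run is launched along the lines of `python test/kernel/fuse_moe_tuning.py --model_dir deepseek-ai/DeepSeek-R1 --tp 8 --use_fp8_w8a8 --num_fused_shared_experts 1`. The value 1 here is only an example; the default of 0 keeps the previous behaviour with no fused shared experts.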
