
Commit 3787ccc

Author: wangzaijun

fix tuning code.

1 parent d240c6e commit 3787ccc

File tree

2 files changed: +42 -32 lines changed

test/kernel/fuse_moe_tuning_bf16.py
test/kernel/fuse_moe_tuning_fp8.py
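
Both files get the same refactor: the hard-coded DeepSeek-V2 shape constants are hoisted into named variables at the top of the tuning block, so the mp_tuning argument dict and both MoeGroupedGemmKernelConfig.save_config calls derive from one set of definitions instead of repeating magic numbers. A minimal sketch of the pattern, using the values from the bf16 script below:

    # one set of named constants replaces the repeated literals
    expert_num = 64    # routed experts
    n = 1408 // 2      # per-GEMM width; the up/gate GEMM is saved with N = n * 2
    hidden_dim = 2048  # model hidden size
    topk_num = 6       # experts selected per token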

test/kernel/fuse_moe_tuning_bf16.py

Lines changed: 21 additions & 16 deletions
@@ -360,17 +360,22 @@ def tuning_configs(
     from lightllm.utils.tuning_utils import mp_tuning
     from lightllm.common.fused_moe.moe_kernel_configs import MoeGroupedGemmKernelConfig
 
-    # tuning to get deepseekv2 lite configs and store
+    # tuning to get deepseekv2 lite configs and store tp 1
+    expert_num = 64
+    n = 1408 // 2  # up is n * 2
+    hidden_dim = 2048
+    topk_num = 6
+
     up_dict = {}
     for m in [1, 8, 64, 128, 256, 512, 1024, 4096, 8192]:
         ans = mp_tuning(
             tuning_configs,
             {
-                "expert_num": 64,
+                "expert_num": expert_num,
                 "m": m,
-                "n": 1408 // 2,
-                "k": 2048,
-                "topk": 6,
+                "n": n,
+                "k": hidden_dim,
+                "topk": topk_num,
                 "dtype": torch.bfloat16,
                 "test_count": 20,
                 "use_fp8_w8a8": False,
@@ -379,10 +384,10 @@ def tuning_configs(
         )
         up_dict[m] = ans
     MoeGroupedGemmKernelConfig.save_config(
-        N=1408,
-        K=2048,
-        topk_num=6,
-        expert_num=64,
+        N=n * 2,
+        K=hidden_dim,
+        topk_num=topk_num,
+        expert_num=expert_num,
         mul_routed_weight=False,
         use_fp8_w8a8=False,
         out_dtype=str(torch.bfloat16),
@@ -394,11 +399,11 @@ def tuning_configs(
         ans = mp_tuning(
             tuning_configs,
             {
-                "expert_num": 64,
+                "expert_num": expert_num,
                 "m": m,
-                "n": 1408 // 2,
-                "k": 2048,
-                "topk": 6,
+                "n": n,
+                "k": hidden_dim,
+                "topk": topk_num,
                 "dtype": torch.bfloat16,
                 "test_count": 20,
                 "use_fp8_w8a8": False,
@@ -407,10 +412,10 @@ def tuning_configs(
         )
         down_dict[m] = ans
     MoeGroupedGemmKernelConfig.save_config(
-        N=2048,
-        K=1408 // 2,
+        N=hidden_dim,
+        K=n,
         topk_num=1,
-        expert_num=64,
+        expert_num=expert_num,
         mul_routed_weight=True,
         use_fp8_w8a8=False,
         out_dtype=str(torch.bfloat16),
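
For reference, a minimal sketch (not part of the commit; expert_up_down_shapes is a hypothetical helper) of the two grouped-GEMM shapes the script now derives from n and hidden_dim:

    def expert_up_down_shapes(n: int, hidden_dim: int):
        # up/gate GEMM: tuned at width n, saved with N = n * 2 ("up is n * 2")
        up = {"N": n * 2, "K": hidden_dim}
        # down GEMM: projects back to the hidden size, saved with topk_num=1
        down = {"N": hidden_dim, "K": n}
        return up, down

    up, down = expert_up_down_shapes(1408 // 2, 2048)
    assert up == {"N": 1408, "K": 2048}   # matches the first save_config call
    assert down == {"N": 2048, "K": 704}  # matches the second save_config call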

test/kernel/fuse_moe_tuning_fp8.py

Lines changed: 21 additions & 16 deletions
@@ -363,17 +363,22 @@ def tuning_configs(
     from lightllm.utils.tuning_utils import mp_tuning
     from lightllm.common.fused_moe.moe_kernel_configs import MoeGroupedGemmKernelConfig
 
-    # tuning to get deepseekv2 large configs and store in H800
+    # tuning to get deepseekv2 large configs and store in H800, tp 8
+    expert_num = 160
+    n = 192  # up is n * 2
+    hidden_dim = 5120
+    topk_num = 6
+
     up_dict = {}
     for m in [1, 8, 64, 128, 256, 512, 1024, 4096, 8192]:
         ans = mp_tuning(
             tuning_configs,
             {
-                "expert_num": 160,
+                "expert_num": expert_num,
                 "m": m,
-                "n": 192,
-                "k": 5120,
-                "topk": 6,
+                "n": n,
+                "k": hidden_dim,
+                "topk": topk_num,
                 "dtype": torch.bfloat16,
                 "test_count": 20,
                 "use_fp8_w8a8": True,
@@ -382,10 +387,10 @@ def tuning_configs(
         )
         up_dict[m] = ans
     MoeGroupedGemmKernelConfig.save_config(
-        N=192 * 2,
-        K=5120,
-        topk_num=6,
-        expert_num=160,
+        N=n * 2,
+        K=hidden_dim,
+        topk_num=topk_num,
+        expert_num=expert_num,
         mul_routed_weight=False,
         use_fp8_w8a8=True,
         out_dtype=str(torch.bfloat16),
@@ -397,11 +402,11 @@ def tuning_configs(
         ans = mp_tuning(
             tuning_configs,
             {
-                "expert_num": 160,
+                "expert_num": expert_num,
                 "m": m,
-                "n": 192,
-                "k": 5120,
-                "topk": 6,
+                "n": n,
+                "k": hidden_dim,
+                "topk": topk_num,
                 "dtype": torch.bfloat16,
                 "test_count": 20,
                 "use_fp8_w8a8": True,
@@ -411,10 +416,10 @@ def tuning_configs(
         down_dict[m] = ans
 
     MoeGroupedGemmKernelConfig.save_config(
-        N=5120,
-        K=192,
+        N=hidden_dim,
+        K=n,
         topk_num=1,
-        expert_num=160,
+        expert_num=expert_num,
         mul_routed_weight=True,
         use_fp8_w8a8=True,
         out_dtype=str(torch.bfloat16),
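
A rough sketch of where n = 192 plausibly comes from under the "tp 8" comment, assuming each tensor-parallel rank holds a 1/tp_size slice of the expert intermediate dimension and assuming DeepSeek-V2's moe_intermediate_size of 1536 (neither value is stated in this diff):

    def per_rank_n(moe_intermediate_size: int, tp_size: int) -> int:
        # each TP rank tunes and saves kernel configs for its own slice
        assert moe_intermediate_size % tp_size == 0
        return moe_intermediate_size // tp_size

    assert per_rank_n(1536, 8) == 192  # matches n in the fp8 script above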
