
Commit f535684

wangzaijun committed: add kernel tuning setting config.

1 parent d9ddd46

File tree

5 files changed: 134 additions & 16 deletions

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 18 additions & 16 deletions
@@ -30,6 +30,7 @@
     get_device_sm_shared_mem_num,
     get_device_warp_size,
 )
+from .moe_kernel_configs import MoeGroupedGemmKernelConfig
 
 FFN_MOE_CHUNK_SIZE = 8 * 1024
 
@@ -365,16 +366,25 @@ def grouped_matmul(
     out is tensor shape [token_num * topk_num, out_dim]
     """
     compute_type = tl.bfloat16 if out.dtype == torch.bfloat16 else tl.float16
+    expert_num, n, k = expert_weights.shape
+    assert token_inputs.shape[1] == k
+    assert expert_to_token_index.shape == expert_to_weights.shape
+    assert token_inputs.is_contiguous()
+    assert expert_to_token_num.is_contiguous()
+    assert expert_to_weights.is_contiguous()
+    assert expert_weights.is_contiguous()
 
     if not run_config:
-        run_config = {
-            "BLOCK_SIZE_M": 64,
-            "BLOCK_SIZE_N": 64,
-            "BLOCK_SIZE_K": 32,
-            "GROUP_SIZE_M": 1,
-            "num_warps": 4,
-            "num_stages": 3,
-        }
+        run_config = MoeGroupedGemmKernelConfig.try_to_get_best_config(
+            M=token_inputs.shape[0],
+            N=n,
+            K=k,
+            topk_num=topk_num,
+            expert_num=expert_num,
+            mul_routed_weight=mul_routed_weight,
+            use_fp8_w8a8=use_fp8_w8a8,
+            out_dtype=str(out.dtype),
+        )
     BLOCK_SIZE_M = run_config["BLOCK_SIZE_M"]
     BLOCK_SIZE_N = run_config["BLOCK_SIZE_N"]
     BLOCK_SIZE_K = run_config["BLOCK_SIZE_K"]
@@ -385,14 +395,6 @@ def grouped_matmul(
     if use_fp8_w8a8:
         token_inputs, token_input_scale = ops.scaled_fp8_quant(token_inputs, token_input_scale)
 
-    expert_num, n, k = expert_weights.shape
-    assert token_inputs.shape[1] == k
-    assert expert_to_token_index.shape == expert_to_weights.shape
-    assert token_inputs.is_contiguous()
-    assert expert_to_token_num.is_contiguous()
-    assert expert_to_weights.is_contiguous()
-    assert expert_weights.is_contiguous()
-
     kernel = grouped_matmul_kernel.warmup(
         expert_token_limit,
         k,
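
With this change, grouped_matmul no longer hard-codes its Triton launch parameters: when the caller passes no run_config, it asks the new MoeGroupedGemmKernelConfig class for the best known config for the current problem shape. A minimal sketch of the lookup, with hypothetical MoE shapes (not values from any shipped config file):

from lightllm.common.fused_moe.moe_kernel_configs import MoeGroupedGemmKernelConfig

run_config = MoeGroupedGemmKernelConfig.try_to_get_best_config(
    M=4096,  # rows of token_inputs, i.e. token_num * topk_num
    N=768,
    K=2048,
    topk_num=2,
    expert_num=8,
    mul_routed_weight=True,
    use_fp8_w8a8=False,
    out_dtype="torch.bfloat16",
)
# With no tuned file on disk and M > expert_num, this falls back to:
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32,
#  "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 1}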
lightllm/common/fused_moe/moe_kernel_configs.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import os
+from frozendict import frozendict
+from functools import lru_cache
+from lightllm.common.kernel_config import KernelConfigs
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class MoeGroupedGemmKernelConfig(KernelConfigs):
+    @classmethod
+    @lru_cache(maxsize=200)
+    def try_to_get_best_config(
+        cls,
+        M: int,
+        N: int,
+        K: int,
+        topk_num: int,
+        expert_num: int,
+        mul_routed_weight: bool,
+        use_fp8_w8a8: bool,
+        out_dtype: str,
+    ) -> dict:
+        key_params = {
+            "N": N,
+            "K": K,
+            "topk_num": topk_num,
+            "expert_num": expert_num,
+            "mul_routed_weight": mul_routed_weight,
+            "use_fp8_w8a8": use_fp8_w8a8,
+            "out_dtype": out_dtype,
+        }
+        key_params = frozendict(key_params)
+
+        finded_config = KernelConfigs.get_the_config(key_params, os.path.dirname(os.path.realpath(__file__)))
+
+        if finded_config:
+            config = finded_config[min(finded_config.keys(), key=lambda x: abs(x - M))]
+            return config
+        else:
+            if M <= expert_num:
+                config = {
+                    "BLOCK_SIZE_M": 16,
+                    "BLOCK_SIZE_N": 32,
+                    "BLOCK_SIZE_K": 64,
+                    "GROUP_SIZE_M": 1,
+                    "num_warps": 4,
+                    "num_stages": 1,
+                }
+            else:
+                config = {
+                    "BLOCK_SIZE_M": 64,
+                    "BLOCK_SIZE_N": 64,
+                    "BLOCK_SIZE_K": 32,
+                    "GROUP_SIZE_M": 8,
+                    "num_warps": 4,
+                    "num_stages": 1,
+                }
+            return config
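
When a tuned file is found, the lookup keys it by the M each entry was tuned at and picks the stored M closest to the runtime one; otherwise it falls back to one of the two defaults above. A small sketch of the nearest-M selection, with a made-up table:

finded_config = {  # hypothetical tuned table: {tuned_M: kernel_config}
    1: {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 1},
    128: {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 1},
}
M = 100
best_m = min(finded_config.keys(), key=lambda x: abs(x - M))  # 128 is nearer to 100 than 1
config = finded_config[best_m]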

lightllm/common/kernel_config.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import os
+import json
+import re
+from abc import ABC, abstractmethod
+from typing import Dict, Any, Optional
+from functools import lru_cache
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.device_utils import get_current_device_name
+
+logger = init_logger(__name__)
+
+
+class KernelConfigs(ABC):
+    @classmethod
+    def get_config_file_name(cls, params: Dict[str, Any]) -> str:
+        json_str = json.dumps(params, sort_keys=True)
+        json_str = json_str.replace(" ", "").replace("\n", "").replace('"', "")
+        filename = json_str
+        device_name = get_current_device_name().replace(" ", "_")
+        return f"{filename}_{device_name}.json"
+
+    @staticmethod
+    @lru_cache(maxsize=None)
+    def get_the_config(params: Dict[str, Any], config_dir_path) -> Optional[dict]:
+        json_file_name = KernelConfigs.get_config_file_name(params)
+        config_file_path = os.path.join(config_dir_path, "configs", json_file_name)
+
+        if os.path.exists(config_file_path):
+            # json.load takes a file object, and JSON object keys come back as
+            # strings; convert them to the int M values the nearest-M lookup expects.
+            with open(config_file_path) as file:
+                return {int(m): cfg for m, cfg in json.load(file).items()}
+        else:
+            logger.warning(f"can not find config file {config_file_path}")
+            return None
+
+    @classmethod
+    def store_config(cls, params: Dict[str, Any], config_dir_path: str, dest_json: dict):
+        json_file_name = KernelConfigs.get_config_file_name(params)
+        config_file_path = os.path.join(config_dir_path, "configs", json_file_name)
+        with open(config_file_path, mode="w") as file:
+            json.dump(dest_json, file)
+        return
+
+    @classmethod
+    @abstractmethod
+    def try_to_get_best_config(cls, *args, **kwargs) -> dict:
+        pass
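
KernelConfigs handles persistence: the file name encodes every shape-invariant key parameter plus the GPU model, so there is one tuning file per shape-and-device pair, and the file body maps tuned M values to kernel configs. A sketch of the naming, assuming hypothetical params and an NVIDIA H800 device (the exact string is the sorted json.dumps output with spaces and quotes stripped):

from frozendict import frozendict

params = frozendict({"N": 768, "K": 2048, "expert_num": 8, "topk_num": 2,
                     "mul_routed_weight": True, "use_fp8_w8a8": False,
                     "out_dtype": "torch.bfloat16"})
# KernelConfigs.get_config_file_name(params) would yield roughly:
# {K:2048,N:768,expert_num:8,mul_routed_weight:true,out_dtype:torch.bfloat16,topk_num:2,use_fp8_w8a8:false}_NVIDIA_H800.json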

lightllm/utils/device_utils.py

Lines changed: 12 additions & 0 deletions
@@ -35,3 +35,15 @@ def get_device_warp_size():
 
     properties = driver.active.utils.get_device_properties(0)
     return properties["warpSize"]
+
+
+@lru_cache(maxsize=None)
+def get_current_device_name():
+    import torch
+
+    if torch.cuda.is_available():
+        device = torch.cuda.current_device()
+        gpu_name = torch.cuda.get_device_name(device)
+        return gpu_name
+    else:
+        return None
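
get_current_device_name reports the CUDA device's marketing name, which KernelConfigs embeds (spaces replaced with underscores) in config file names so tuned settings never cross GPU models. An illustrative run, assuming a consumer card:

name = get_current_device_name()  # e.g. "NVIDIA GeForce RTX 4090"
suffix = name.replace(" ", "_")   # "NVIDIA_GeForce_RTX_4090"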

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -81,3 +81,4 @@ prometheus_client==0.20.0
 outlines==0.0.46
 cchardet==2.1.7
 ujson==5.10.0
+frozendict==2.4.6
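
The new frozendict pin exists because try_to_get_best_config and get_the_config sit behind functools.lru_cache, whose arguments must be hashable, and a plain dict of key params is not. A one-line illustration:

from frozendict import frozendict

hash(frozendict({"N": 768, "K": 2048}))  # fine: frozendict is hashable, so it can be a cache key
# hash({"N": 768, "K": 2048})            # would raise TypeError: unhashable type: 'dict'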
