
Commit 4bd6a9f

[Bugs] Fix DeepGEMM pre-compile tools. (#3351)
Fix some cache miss problems. Add README.md.
1 parent d4e3a20 commit 4bd6a9f

4 files changed: +198 −52 lines changed

tools/deep_gemm_pre-compile/README.md

Lines changed: 61 additions & 0 deletions
# DeepGEMM Pre-compilation Tool

This tool provides pre-compilation functionality for DeepGEMM kernels to optimize performance.

## Usage

### 1. Using Shell Script (Recommended)
```bash
bash pre_compile.sh \
    [MODEL_PATH] \
    [TP_SIZE] \
    [EP_SIZE] \
    [HAS_SHARED_EXPERTS] \
    [OUTPUT_FILE]
```

The script will:
1. Generate configurations
2. Pre-compile all kernels

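For example, a hypothetical run for a MoE model deployed with 8-way expert parallelism and shared experts (the model path and sizes below are illustrative, not defaults):

```bash
# Illustrative values only; replace the model path and parallel sizes with your own.
bash pre_compile.sh \
    /models/my_moe_model \
    1 \
    8 \
    True \
    ./deep_gemm_pre_compile_config.jsonl
```
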
### 2. Alternative: Manual Steps
If you need more control, you can run the steps manually:

#### Generate Configuration
```bash
python generate_config.py \
    --model /path/to/model \
    --tensor-parallel-size [TP_SIZE] \
    --expert-parallel-size [EP_SIZE] \
    --has-shared-experts [True/False] \
    --output [CONFIG_FILE]
```

Arguments:
- `--model`: Path to the model directory containing `config.json`
- `--tensor-parallel-size`: Tensor parallel size (default: 1)
- `--expert-parallel-size`: Expert parallel size (default: 1)
- `--has-shared-experts`: Whether the model has shared experts (default: False)
- `--output`: Output config file path (default: ./deep_gemm_pre_compile_config.jsonl)

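As a concrete (hypothetical) example, generating configurations for a model with shared experts under 8-way expert parallelism:

```bash
# Illustrative invocation; the model path is a placeholder.
python generate_config.py \
    --model /models/my_moe_model \
    --tensor-parallel-size 1 \
    --expert-parallel-size 8 \
    --has-shared-experts True \
    --output ./deep_gemm_pre_compile_config.jsonl
```
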
#### Pre-compile Kernels
```bash
python pre_compile.py \
    --config-file [CONFIG_FILE] \
    --expert-parallel-size [EP_SIZE] \
    --num-threads [NUM_THREADS]
```

Arguments:
- `--config-file`: Path to the config file generated in step 1
- `--expert-parallel-size`: Expert parallel size (must match step 1)
- `--num-threads`: Number of compilation threads (default: 16)

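For example, compiling the configurations generated above with 16 threads (the thread count is illustrative):

```bash
# --expert-parallel-size must match the value used in generate_config.py.
python pre_compile.py \
    --config-file ./deep_gemm_pre_compile_config.jsonl \
    --expert-parallel-size 8 \
    --num-threads 16
```
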
## Environment Variables
- `PRE_COMPILE_LOG_LEVEL`: Set log level (DEBUG/INFO/WARNING/ERROR)
- `DG_CACHE_DIR`: Cache directory for compiled kernels (default: ./deep_gemm_cache)

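For example, to get verbose logs and keep compiled kernels in a custom location (the cache path below is illustrative):

```bash
export PRE_COMPILE_LOG_LEVEL=DEBUG
export DG_CACHE_DIR=/data/deep_gemm_cache
```
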
## Notes
- For best performance, set `--num-threads` to the number of available CPU cores
- The compilation process may take significant time depending on configuration size
- Compiled kernels will be cached in `DG_CACHE_DIR`

tools/deep_gemm_pre-compile/generate_config.py

Lines changed: 122 additions & 43 deletions
```diff
@@ -17,7 +17,7 @@
 import logging
 import math
 import os
-from typing import Tuple
+from typing import List, Tuple
 
 from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import get_smem_config
 
@@ -27,33 +27,84 @@
 logger.setLevel(os.getenv("PRE_COMPILE_LOG_LEVEL", "INFO"))
 
 
-def generate_kn_pairs(model_cfg: dict) -> Tuple[list, list, list]:
+def generate_kn_pairs(args, model_cfg: dict) -> Tuple[List, List, List]:
     hidden_size = model_cfg["hidden_size"]
     intermediate_size = model_cfg["intermediate_size"]
     moe_intermediate_size = model_cfg["moe_intermediate_size"]
     num_attention_heads = model_cfg["num_attention_heads"]
     num_key_value_heads = model_cfg["num_key_value_heads"]
     head_dim = int(hidden_size / num_attention_heads)
-    gemm_kn_pairs = [
+    tp_size = args.tensor_parallel_size
+    ep_size = args.expert_parallel_size
+    has_shared_experts = args.has_shared_experts.lower() == "true"
+
+    gemm_kn_pairs = []
+    grouped_gemm_contiguous_kn_pairs = []
+    grouped_gemm_masked_kn_pairs = []
+    if tp_size > 1 and ep_size == 1:
+        logger.debug("Generating kn pairs for tensor parallel.")
+        # Dense normal gemm
+        gemm_kn_pairs.extend(
+            [
+                [int(intermediate_size / tp_size), hidden_size],
+                [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2) / tp_size)],
+                [hidden_size, int(intermediate_size * 2 / tp_size)],
+                [int(hidden_size / tp_size), hidden_size],
+            ]
+        )
+
+        # Moe grouped gemm contiguous
+        grouped_gemm_contiguous_kn_pairs.extend(
+            [
+                [int(moe_intermediate_size / tp_size), hidden_size],
+                [hidden_size, int(moe_intermediate_size * 2 / tp_size)],
+            ]
+        )
+        if has_shared_experts:
+            logger.debug("Generating kn pairs for models with shared experts.")
+            gemm_kn_pairs.extend(
+                [
+                    [hidden_size, int(moe_intermediate_size * 4 / tp_size)],
+                    [int(moe_intermediate_size * 2 / tp_size), hidden_size],
+                ]
+            )
+    elif tp_size == 1 and ep_size > 1:
+        logger.debug("Generating kn pairs for expert parallel.")
         # Dense normal gemm
-        [hidden_size, intermediate_size * 2],
-        [intermediate_size, hidden_size],
-        [hidden_size, hidden_size],
-        [
-            hidden_size,
-            (num_attention_heads + num_key_value_heads * 2) * head_dim,
-        ],
-    ]
-    grouped_gemm_contiguous_kn_pairs = [
+        gemm_kn_pairs.extend(
+            [
+                [intermediate_size, hidden_size],
+                [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2))],
+                [hidden_size, int(intermediate_size * 2)],
+                [hidden_size, hidden_size],
+            ]
+        )
         # Moe grouped gemm contiguous
-        [hidden_size, moe_intermediate_size * 2],
-        [moe_intermediate_size, hidden_size],
-    ]
-    grouped_gemm_masked_kn_pairs = [
+        grouped_gemm_contiguous_kn_pairs.extend(
+            [
+                [moe_intermediate_size, hidden_size],
+                [hidden_size, int(moe_intermediate_size * 2)],
+            ]
+        )
         # Moe grouped gemm masked
-        [hidden_size, moe_intermediate_size * 2],
-        [moe_intermediate_size, hidden_size],
-    ]
+        grouped_gemm_masked_kn_pairs.extend(
+            [
+                [moe_intermediate_size, hidden_size],
+                [hidden_size, int(moe_intermediate_size * 2)],
+            ]
+        )
+        if has_shared_experts:
+            logger.debug("Generating kn pairs for models with shared experts.")
+            gemm_kn_pairs.extend(
+                [
+                    [hidden_size, int(moe_intermediate_size * 4)],
+                    [int(moe_intermediate_size * 2), hidden_size],
+                ]
+            )
+    elif tp_size > 1 and ep_size > 1:
+        raise ValueError("Not supported to enable EP and TP at the same time for now.")
+    else:
+        raise ValueError("Please check the tensor parallel size and expert parallel size.")
 
     return (
         gemm_kn_pairs,
@@ -78,7 +129,8 @@ def generate_json(
     counter = 0
     with open(output_path, "a+", encoding="utf-8") as f:
         for block_m in BLOCK_MS:
-            for block_n in BLOCK_NS:
+            # NOTES: the block sizes can not be too large, so at least one dim less than 128
+            for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, BLOCK_NS):
                 if 128 % block_n != 0 and 128 // math.gcd(128, block_n) <= 4:
                     NUM_STAGES = [4, 3]
                 else:
@@ -110,33 +162,43 @@
 def main(args):
     with open(os.path.join(args.model, "config.json"), "r") as f:
         model_cfg = json.load(f)
-
+    logger.debug(
+        f"TP Size: {args.tensor_parallel_size}, "
+        f"EP Size: {args.expert_parallel_size}, "
+        f"has shared experts: {args.has_shared_experts}"
+    )
+    logger.info(f"Configurations generated and saved to {args.output}")
     (
         gemm_kn_pairs,
         grouped_gemm_contiguous_kn_pairs,
         grouped_gemm_masked_kn_pairs,
-    ) = generate_kn_pairs(model_cfg)
-    num_gemm = generate_json(
-        gemm_kn_pairs,
-        model_cfg["moe_num_experts"],
-        args.output,
-    )
-    num_grouped_contiguous = generate_json(
-        grouped_gemm_contiguous_kn_pairs,
-        model_cfg["moe_num_experts"],
-        args.output,
-        is_grouped_contiguous=True,
-    )
-    num_grouped_masked = generate_json(
-        grouped_gemm_masked_kn_pairs,
-        model_cfg["moe_num_experts"],
-        args.output,
-        is_grouped_masked=True,
-    )
-    logger.info(f"Configurations generated and saved to {args.output}")
-    logger.info(f"Generated {num_gemm} gemm configuration.")
-    logger.info(f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configuration.")
-    logger.info(f"Generated {num_grouped_masked} grouped_gemm_masked configuration.")
+    ) = generate_kn_pairs(args, model_cfg)
+    logger.debug(f"GEMM KN pairs: {gemm_kn_pairs}")
+    logger.debug(f"Grouped GEMM Contiguous KN pairs: {grouped_gemm_contiguous_kn_pairs}")
+    logger.debug(f"Grouped GEMM Masked KN pairs: {grouped_gemm_masked_kn_pairs}")
+    if len(gemm_kn_pairs) > 0:
+        num_gemm = generate_json(
+            gemm_kn_pairs,
+            model_cfg["moe_num_experts"],
+            args.output,
+        )
+        logger.info(f"Generated {num_gemm} gemm configuration.")
+    if len(grouped_gemm_contiguous_kn_pairs) > 0:
+        num_grouped_contiguous = generate_json(
+            grouped_gemm_contiguous_kn_pairs,
+            model_cfg["moe_num_experts"],
+            args.output,
+            is_grouped_contiguous=True,
+        )
+        logger.info(f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configuration.")
+    if len(grouped_gemm_masked_kn_pairs) > 0:
+        num_grouped_masked = generate_json(
+            grouped_gemm_masked_kn_pairs,
+            model_cfg["moe_num_experts"],
+            args.output,
+            is_grouped_masked=True,
+        )
+        logger.info(f"Generated {num_grouped_masked} grouped_gemm_masked configuration.")
 
 
 if __name__ == "__main__":
@@ -146,6 +208,23 @@ def main(args):
         type=str,
         required=True,
     )
+    parser.add_argument(
+        "--tensor-parallel-size",
+        "--tp",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--expert-parallel-size",
+        "--ep",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--has-shared-experts",
+        type=str,
+        default="False",
+    )
     parser.add_argument(
         "--output",
         type=str,
```
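A quick way to see the effect of the new TP-only vs. EP-only branching is to run the generator with debug logging, which should surface the generated KN pairs before the config file is written. This is a hypothetical invocation using the short `--tp`/`--ep` aliases defined above; the model path is a placeholder:

```bash
# Debug logging prints the GEMM / grouped-GEMM KN pairs selected for this parallel setup.
PRE_COMPILE_LOG_LEVEL=DEBUG python generate_config.py \
    --model /path/to/model \
    --tp 1 \
    --ep 8 \
    --has-shared-experts True \
    --output ./deep_gemm_pre_compile_config.jsonl
```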

tools/deep_gemm_pre-compile/pre_compile.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -162,25 +162,25 @@ def pre_compile_from_config(config_file: str, num_threads: int, expert_parallel:
 
 
 def main(args):
-    pre_compile_from_config(args.config_file, args.num_threads, args.expert_parallel)
+    pre_compile_from_config(args.config_file, args.num_threads, args.expert_parallel_size)
 
 
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--config_file",
+        "--config-file",
         type=str,
         default="./deep_gemm_pre_compile_config.jsonl",
     )
     parser.add_argument(
-        "--expert_parallel",
+        "--expert-parallel-size",
         "--ep",
         type=int,
         default=8,
     )
     parser.add_argument(
-        "--num_threads",
+        "--num-threads",
         type=int,
         default=16,
     )
```

tools/deep_gemm_pre-compile/pre_compile.sh

Lines changed: 11 additions & 5 deletions
```diff
@@ -18,14 +18,20 @@ export DG_CACHE_DIR=$(pwd)/deep_gemm_cache
 echo DeepGEMM Cache Dir: $DG_CACHE_DIR
 
 MODEL_PATH=${1:-"/path/to/model"}
-EXPERT_PARALLEL=${2:-"8"}
+TENSOR_PARALLEL_SIZE=${2:-"1"}
+EXPERT_PARALLEL_SIZE=${3:-"8"}
+HAS_SHARED_EXPERTS=${4:-"False"}
+OUTPUT_FILE=${5:-"./deep_gemm_pre_compile_config.jsonl"}
 nproc=$(nproc)
 
 python generate_config.py \
     --model $MODEL_PATH \
-    --output=./deep_gemm_pre_compile_config.jsonl
+    --tensor-parallel-size $TENSOR_PARALLEL_SIZE \
+    --expert-parallel-size $EXPERT_PARALLEL_SIZE \
+    --has-shared-experts $HAS_SHARED_EXPERTS \
+    --output $OUTPUT_FILE
 
 python pre_compile.py \
-    --config_file=./deep_gemm_pre_compile_config.jsonl \
-    --expert_parallel=$EXPERT_PARALLEL \
-    --num_threads=$nproc
+    --config-file $OUTPUT_FILE \
+    --expert-parallel-size $EXPERT_PARALLEL_SIZE \
+    --num-threads $nproc
```
