import logging
import math
import os
-from typing import Tuple
+from typing import List, Tuple

from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import get_smem_config

logger.setLevel(os.getenv("PRE_COMPILE_LOG_LEVEL", "INFO"))


-def generate_kn_pairs(model_cfg: dict) -> Tuple[list, list, list]:
+def generate_kn_pairs(args, model_cfg: dict) -> Tuple[List, List, List]:
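+    """Collect [K, N] GEMM shapes to pre-compile for the given model config.
+
+    Returns three lists of [K, N] pairs: dense GEMMs, grouped contiguous GEMMs
+    and grouped masked GEMMs, sized for the requested tensor-/expert-parallel
+    configuration.
+    """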
    hidden_size = model_cfg["hidden_size"]
    intermediate_size = model_cfg["intermediate_size"]
    moe_intermediate_size = model_cfg["moe_intermediate_size"]
    num_attention_heads = model_cfg["num_attention_heads"]
    num_key_value_heads = model_cfg["num_key_value_heads"]
    head_dim = int(hidden_size / num_attention_heads)
-    gemm_kn_pairs = [
+    tp_size = args.tensor_parallel_size
+    ep_size = args.expert_parallel_size
+    has_shared_experts = args.has_shared_experts.lower() == "true"
+
+    gemm_kn_pairs = []
+    grouped_gemm_contiguous_kn_pairs = []
+    grouped_gemm_masked_kn_pairs = []
+    if tp_size > 1 and ep_size == 1:
+        logger.debug("Generating kn pairs for tensor parallel.")
+        # Dense normal gemm
+        gemm_kn_pairs.extend(
+            [
+                [int(intermediate_size / tp_size), hidden_size],
+                [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2) / tp_size)],
+                [hidden_size, int(intermediate_size * 2 / tp_size)],
+                [int(hidden_size / tp_size), hidden_size],
+            ]
+        )
+
+        # Moe grouped gemm contiguous
+        grouped_gemm_contiguous_kn_pairs.extend(
+            [
+                [int(moe_intermediate_size / tp_size), hidden_size],
+                [hidden_size, int(moe_intermediate_size * 2 / tp_size)],
+            ]
+        )
+        if has_shared_experts:
+            logger.debug("Generating kn pairs for models with shared experts.")
+            gemm_kn_pairs.extend(
+                [
+                    [hidden_size, int(moe_intermediate_size * 4 / tp_size)],
+                    [int(moe_intermediate_size * 2 / tp_size), hidden_size],
+                ]
+            )
+    elif tp_size == 1 and ep_size > 1:
+        logger.debug("Generating kn pairs for expert parallel.")
        # Dense normal gemm
-        [hidden_size, intermediate_size * 2],
-        [intermediate_size, hidden_size],
-        [hidden_size, hidden_size],
-        [
-            hidden_size,
-            (num_attention_heads + num_key_value_heads * 2) * head_dim,
-        ],
-    ]
-    grouped_gemm_contiguous_kn_pairs = [
+        gemm_kn_pairs.extend(
+            [
+                [intermediate_size, hidden_size],
+                [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2))],
+                [hidden_size, int(intermediate_size * 2)],
+                [hidden_size, hidden_size],
+            ]
+        )
        # Moe grouped gemm contiguous
-        [hidden_size, moe_intermediate_size * 2],
-        [moe_intermediate_size, hidden_size],
-    ]
-    grouped_gemm_masked_kn_pairs = [
+        grouped_gemm_contiguous_kn_pairs.extend(
+            [
+                [moe_intermediate_size, hidden_size],
+                [hidden_size, int(moe_intermediate_size * 2)],
+            ]
+        )
        # Moe grouped gemm masked
-        [hidden_size, moe_intermediate_size * 2],
-        [moe_intermediate_size, hidden_size],
-    ]
+        grouped_gemm_masked_kn_pairs.extend(
+            [
+                [moe_intermediate_size, hidden_size],
+                [hidden_size, int(moe_intermediate_size * 2)],
+            ]
+        )
+        if has_shared_experts:
+            logger.debug("Generating kn pairs for models with shared experts.")
+            gemm_kn_pairs.extend(
+                [
+                    [hidden_size, int(moe_intermediate_size * 4)],
+                    [int(moe_intermediate_size * 2), hidden_size],
+                ]
+            )
+    elif tp_size > 1 and ep_size > 1:
+        raise ValueError("Enabling EP and TP at the same time is not supported for now.")
+    else:
+        raise ValueError("Please check the tensor parallel size and expert parallel size.")

    return (
        gemm_kn_pairs,
@@ -78,7 +129,8 @@ def generate_json(
    counter = 0
    with open(output_path, "a+", encoding="utf-8") as f:
        for block_m in BLOCK_MS:
-            for block_n in BLOCK_NS:
+            # NOTES: the block sizes cannot both be large, so at least one dim must be at most 128
+            for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, BLOCK_NS):
                if 128 % block_n != 0 and 128 // math.gcd(128, block_n) <= 4:
                    NUM_STAGES = [4, 3]
                else:
@@ -110,33 +162,43 @@ def generate_json(
def main(args):
    with open(os.path.join(args.model, "config.json"), "r") as f:
        model_cfg = json.load(f)
-
+    logger.debug(
+        f"TP Size: {args.tensor_parallel_size}, "
+        f"EP Size: {args.expert_parallel_size}, "
+        f"has shared experts: {args.has_shared_experts}"
+    )
+    logger.info(f"Configurations will be saved to {args.output}")
    (
        gemm_kn_pairs,
        grouped_gemm_contiguous_kn_pairs,
        grouped_gemm_masked_kn_pairs,
-    ) = generate_kn_pairs(model_cfg)
-    num_gemm = generate_json(
-        gemm_kn_pairs,
-        model_cfg["moe_num_experts"],
-        args.output,
-    )
-    num_grouped_contiguous = generate_json(
-        grouped_gemm_contiguous_kn_pairs,
-        model_cfg["moe_num_experts"],
-        args.output,
-        is_grouped_contiguous=True,
-    )
-    num_grouped_masked = generate_json(
-        grouped_gemm_masked_kn_pairs,
-        model_cfg["moe_num_experts"],
-        args.output,
-        is_grouped_masked=True,
-    )
-    logger.info(f"Configurations generated and saved to {args.output}")
-    logger.info(f"Generated {num_gemm} gemm configuration.")
-    logger.info(f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configuration.")
-    logger.info(f"Generated {num_grouped_masked} grouped_gemm_masked configuration.")
+    ) = generate_kn_pairs(args, model_cfg)
+    logger.debug(f"GEMM KN pairs: {gemm_kn_pairs}")
+    logger.debug(f"Grouped GEMM Contiguous KN pairs: {grouped_gemm_contiguous_kn_pairs}")
+    logger.debug(f"Grouped GEMM Masked KN pairs: {grouped_gemm_masked_kn_pairs}")
+    if len(gemm_kn_pairs) > 0:
+        num_gemm = generate_json(
+            gemm_kn_pairs,
+            model_cfg["moe_num_experts"],
+            args.output,
+        )
+        logger.info(f"Generated {num_gemm} gemm configuration.")
+    if len(grouped_gemm_contiguous_kn_pairs) > 0:
+        num_grouped_contiguous = generate_json(
+            grouped_gemm_contiguous_kn_pairs,
+            model_cfg["moe_num_experts"],
+            args.output,
+            is_grouped_contiguous=True,
+        )
+        logger.info(f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configuration.")
+    if len(grouped_gemm_masked_kn_pairs) > 0:
+        num_grouped_masked = generate_json(
+            grouped_gemm_masked_kn_pairs,
+            model_cfg["moe_num_experts"],
+            args.output,
+            is_grouped_masked=True,
+        )
+        logger.info(f"Generated {num_grouped_masked} grouped_gemm_masked configuration.")


if __name__ == "__main__":
@@ -146,6 +208,23 @@ def main(args):
        type=str,
        required=True,
    )
+    parser.add_argument(
+        "--tensor-parallel-size",
+        "--tp",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--expert-parallel-size",
+        "--ep",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--has-shared-experts",
+        type=str,
+        default="False",
+    )
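+    # Example invocation (script name and paths are illustrative; assumes the
+    # model-path flag defined above is "--model"):
+    #   python pre_compile_config_gen.py --model /path/to/model --tp 8 \
+    #       --has-shared-experts True --output ./deep_gemm_pre_compile_config.json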
    parser.add_argument(
        "--output",
        type=str,