 limitations under the License.
 """

+import argparse
+import pprint
+
 import torch
 from torch.nn import functional as F
-from triton.testing import do_bench

-import flashinfer
 import flashinfer.fused_moe as fused_moe
 from flashinfer import fp4_quantize
+from flashinfer.autotuner import AutoTuner, autotune, get_config_path
+from flashinfer.testing.utils import bench_gpu_time_with_cudagraph

-BATCH_SIZES = [
-    1,
-    2,
-    4,
-    8,
-    16,
-    24,
-    32,
-    48,
-    64,
-    96,
-    128,
-    256,
-    512,
-    1024,
-    1536,
-    2048,
-    3072,
-    4096,
-]
-
-configs = []
-hidden_size = 7168
-num_experts = [32, 256]
-top_k = [8]
-intermediate_size = [256, 2048]
 FLOAT4_E2M1_MAX = 6.0
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
-FP8_DTYPE = torch.float8_e4m3fn
+

 test_configs = [
     {
@@ -96,6 +73,7 @@ def bench_cutlass_fused_moe(
     num_experts,
     top_k,
     intermediate_size,
+    skip_autotune,
 ):
     torch.manual_seed(42)
     quant_blocksize = 16
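The hunks below replace `triton.testing.do_bench` with flashinfer's `bench_gpu_time_with_cudagraph`, which, as used here, returns a list of per-replay timings (hence the `sum(ms_list) / len(ms_list)` averaging). As a rough sketch of what CUDA-graph-based timing involves, built only on public `torch.cuda` APIs and not on flashinfer's actual implementation:

```python
# Illustrative only: a minimal CUDA-graph timing loop using public torch.cuda
# APIs. bench_gpu_time_with_cudagraph is assumed to behave similarly,
# returning one sample per replay.
import torch

def time_with_cudagraph_sketch(fn, iters=10):
    for _ in range(3):
        fn()  # warm up so capture sees steady-state allocations
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):  # record one invocation into a graph
        fn()
    samples_ms = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        g.replay()  # replays the captured kernels without Python launch overhead
        end.record()
        torch.cuda.synchronize()
        samples_ms.append(start.elapsed_time(end))  # milliseconds
    return samples_ms
```

Replaying a captured graph removes per-launch CPU overhead, which matters at small token counts such as the default `--num-tokens 32` introduced further down.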
@@ -165,12 +143,24 @@ def bench_cutlass_fused_moe(
     ]
     hidden_states = x
     hidden_states, input_sf = fp4_quantize(x, a1_gs)
-    repeats = 3
-    from flashinfer.autotuner import AutoTuner, autotune

-    AutoTuner.get().clear_cache()
-    with torch.inference_mode(), autotune():
-        for _ in range(2):
+    # Warmup
+    for _ in range(3):
+        _ = fused_moe.cutlass_fused_moe(
+            hidden_states,
+            selected_experts.to(torch.int),
+            routing_weights,
+            w1_q.contiguous().view(torch.long),
+            w2_q.contiguous().view(torch.long),
+            otype,
+            quant_scales=quant_scales,
+            input_sf=input_sf,
+            output=flash_output,
+            tune_max_num_tokens=16384,
+        )
+
+    if not skip_autotune:
+        with torch.inference_mode(), autotune(True):
             _ = fused_moe.cutlass_fused_moe(
                 hidden_states,
                 selected_experts.to(torch.int),
@@ -181,8 +171,9 @@ def bench_cutlass_fused_moe(
                 quant_scales=quant_scales,
                 input_sf=input_sf,
                 output=flash_output,
+                tune_max_num_tokens=16384,
             )
-    ms = do_bench(
+    ms_list = bench_gpu_time_with_cudagraph(
         lambda: fused_moe.cutlass_fused_moe(
             hidden_states,
             selected_experts.to(torch.int),
@@ -195,23 +186,44 @@ def bench_cutlass_fused_moe(
             output=flash_output,
         )
     )
+    avg_ms = sum(ms_list) / len(ms_list)
+    print(f"{'input':<15} {'weight1':<20} {'weight2':<20} {'time(ms)'}")
     print(
-        f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}"
+        f"{str(tuple(hidden_states.shape)):<15} {str(tuple(w1.shape)):<20} {str(tuple(w2.shape)):<20} {avg_ms:.3f}"
     )
-    print(f"execution time: {ms} ms")


 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--update-config",
+        action="store_true",
+        help="Update the config file with the new profiling results",
+    )
+    parser.add_argument(
+        "--num-tokens", type=int, default=32, help="Number of tokens to profile"
+    )
+    parser.add_argument("--skip-autotune", action="store_true", help="Skip autotuning")
+    args = parser.parse_args()
+    AutoTuner.get().clear_cache()
+
     for config in test_configs:
-        hidden_size = config["hidden_size"]
-        num_experts = config["num_experts"]
-        top_k = config["top_k"]
-        intermediate_size = config["intermediate_size"]
-        for batch_size in BATCH_SIZES:
-            bench_cutlass_fused_moe(
-                batch_size,
-                hidden_size,
-                num_experts,
-                top_k,
-                intermediate_size,
-            )
+        bench_cutlass_fused_moe(
+            args.num_tokens,
+            config["hidden_size"],
+            config["num_experts"],
+            config["top_k"],
+            config["intermediate_size"],
+            args.skip_autotune,
+        )
+
+    configs = AutoTuner.get().profiling_cache
+    if args.update_config and configs:
+        # The original key contains a runner's hash in k[2] which might be different across machines.
+        # So, we remove it for now. v[0] and v[1] are the runner id and the tactic.
+        converted = {str((k[0], k[1], k[3])): (v[0], v[1]) for k, v in configs.items()}
+        config_path = get_config_path(is_module=False)
+        with open(config_path, "w") as f:
+            f.write("best_configs = ")
+            pprint.pprint(converted, stream=f)
+        print(f"Saved the cache to {config_path}")