
Commit 80626a9

stream-k v0.4 (#689)
* add perfci tuning shapes and fallback configs
* make test_correctness work with multiple kernels
* change persistent_gemm kernel name back
* add persistent_gemm unit tests
* adapt tune_streamk so the kernel to tune can be selected from the command line
* fix the output file name issue and merge the final TFLOPS and time into the yaml file
* remove comments
1 parent 4a7afd2 commit 80626a9

9 files changed: +238 -46 lines changed

python/perf-kernels/streamk/03-matrix-multiplication-stream-k.py

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 
 from streamk_kernel import streamk_gemm
 #from streamk_kernel_atomic import streamk_gemm
-#from persistent_gemm import streamk_gemm
+#from persistent_gemm import persistent_gemm
 
 torch.manual_seed(123)
 random.seed(123)

python/perf-kernels/streamk/README.md

Lines changed: 20 additions & 0 deletions

@@ -1,3 +1,23 @@
+# streamk gemm script v0.4
+
+### features added:
+
+- enable tuning with different kernels from the command line; defaults to streamk_kernel
+- change the persistent_gemm kernel name back to persistent_gemm
+
+### command line example
+
+tune shapes using persistent_gemm:
+```
+python tune_streamk.py --gemm_size_file input.yaml --kernel persistent_gemm,persistent_gemm --ngpus 8 --jobs 24
+```
+
+test correctness using streamk_kernel:
+```
+python tune_streamk.py --gemm_size_file tuning_results.yaml --compare_wo_tuning --kernel streamk_kernel,streamk_gemm
+```
+
 # streamk gemm script v0.3
 
 ### features added:

python/perf-kernels/streamk/fallback_tuned_configs.yaml

Lines changed: 69 additions & 0 deletions
Large diffs are not rendered by default.
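
The fallback configs are meant to cover newly added gemm sizes when the weekly tuning has not produced fresh results yet (see the note at the top of the new shapes file below). Purely as an illustrative sketch, assuming the fallback file stores one entry per size/layout in the same "- {'M': ..., 'N': ..., 'K': ..., 'rowMajorA': ..., 'rowMajorB': ..., <tuning fields>}" style that tune_streamk.py writes for its own results, a lookup could work roughly like this (the helper name and file path are hypothetical, not part of this commit):

```
# Hypothetical fallback lookup (not from this commit). Assumes each fallback entry
# carries the size/layout keys plus the tuned kernel parameters in one flat dict.
import yaml

def lookup_fallback(path, M, N, K, row_a, row_b):
    with open(path) as f:
        entries = yaml.safe_load(f) or []
    for e in entries:
        key = (e.get('M'), e.get('N'), e.get('K'), e.get('rowMajorA'), e.get('rowMajorB'))
        if key == (M, N, K, row_a, row_b):
            # everything that is not a size/layout field is treated as the tuned config
            return {k: v for k, v in e.items() if k not in ('M', 'N', 'K', 'rowMajorA', 'rowMajorB')}
    return None

# e.g. lookup_fallback('fallback_tuned_configs.yaml', 4864, 4096, 4096, 'T', 'N')
```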

python/perf-kernels/streamk/persistent_gemm.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 
 
 @triton.jit()
-def streamk_gemm(
+def persistent_gemm(
         A,
         B,
         C,
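
With this rename, each kernel module now exports a distinctly named entry point (streamk_kernel.streamk_gemm vs. persistent_gemm.persistent_gemm), which is what lets the tuner pick a kernel from a `module,function` pair. A minimal sketch of that resolution, mirroring the importlib logic added to tune_streamk.py below (the helper name is ours, and it assumes python/perf-kernels/streamk is on the Python path):

```
# Minimal sketch (helper name is ours): resolve a "module,function" kernel spec
# the same way tune_streamk.py now does.
import importlib

def resolve_kernel(spec):
    module_name, kernel_name = spec.split(',')
    module = importlib.import_module(module_name.strip())
    return getattr(module, kernel_name.strip())

# resolve_kernel('persistent_gemm,persistent_gemm') -> the renamed Triton kernel;
# the old spec 'persistent_gemm,streamk_gemm' would now raise AttributeError.
```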

python/perf-kernels/streamk/tune_streamk.py

Lines changed: 38 additions & 25 deletions

@@ -9,8 +9,7 @@
 import torch
 import triton
 import triton.language as tl
-
-from streamk_kernel import streamk_gemm
+import importlib
 
 from datetime import datetime
 import multiprocessing

@@ -206,9 +205,9 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
     os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid)
     jobId = gpuIdx
     while jobId < jobs:
-        kernel_name = get_filename_profile_driver(M, N, K, jobId)
+        kernelname = get_filename_profile_driver(M, N, K, jobId)
         if verbose:
-            print(f"profiling {kernel_name} on GPU {gpuid}")
+            print(f"profiling {kernelname} on GPU {gpuid}")
         run_bash_command_wrapper(
             f"rocprof --stats -o results_{jobId}.csv python {get_filename_profile_driver(M, N, K, jobId)}",
             # f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {get_filename_profile_driver(M, N, K, jobId)}",

@@ -217,15 +216,15 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
 
 
 def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs,
-                     run_bench, jobs, iters, skipWarmup, verbose=0, num_threads=32, gpus=[0], rotating_buffer_size=256,
-                     bias_size=0, icache_flush=False):
+                     run_bench, jobs, iters, skipWarmup, module_name, kernel_name, verbose=0, num_threads=32, gpus=[0],
+                     rotating_buffer_size=256, bias_size=0, icache_flush=False):
 
     # precompile the kernels in parallel
     start_time = datetime.now()
     if not skipWarmup:
         # Generate kernel out of all configs
         fname = generate_compile_driver(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock,
-                                        init_type, configs, rotating_buffer_size, bias_size)
+                                        init_type, configs, rotating_buffer_size, bias_size, kernel_name)
 
         run_bash_command(f"python {fname} -n {num_threads}", capture=(verbose < 2))
     compile_end = datetime.now()

@@ -235,7 +234,8 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p,
 
     # Generate kernels out of all configs
     generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs,
-                           jobs, iters, run_bench, rotating_buffer_size, bias_size, icache_flush)
+                           jobs, iters, run_bench, rotating_buffer_size, bias_size, icache_flush, module_name,
+                           kernel_name)
 
     # profile generated kernels
     running = [

@@ -377,8 +377,8 @@ def gen_rotating_tensors(M, N, K, dtype_a, need_Trans_a, dtype_b, need_Trans_b,
     return in_outs
 
 
-def matmul(a, b, c, bias, P, locks, num_sms, block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu,
-           mfmaInstrSize, kpack, use_bias):
+def matmul(kernel_func, a, b, c, bias, P, locks, num_sms, block_m, block_n, block_k, group_m, num_warps, num_stages,
+           waves_per_eu, mfmaInstrSize, kpack, use_bias):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     #assert a.is_contiguous(), "Matrix A must be contiguous"

@@ -396,7 +396,7 @@ def matmul(a, b, c, bias, P, locks, num_sms, block_m, block_n, block_k, group_m,
     streamk_tiles = m_tiles * n_tiles % num_sms
     # change num_xcds = 1 if using MI250
     num_xcds = 8
-    streamk_gemm[
+    kernel_func[
         grid,
     ](a, b, c, bias, P, locks, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
       stride_bias=stride_bias, BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m,

@@ -405,7 +405,8 @@ def matmul(a, b, c, bias, P, locks, num_sms, block_m, block_n, block_k, group_m,
     return c
 
 
-def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, bias_vector, verbose):
+def test_correctness(kernel_func, M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, bias_vector,
+                     verbose):
     block_m, block_n, block_k, group_m, num_sms, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
         config)
     use_bias = bias_vector

@@ -423,8 +424,8 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
     c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
     locks = torch.zeros((num_sms, ), device="cuda", dtype=torch.int32)
     P = torch.zeros((num_sms, block_m * block_n), device="cuda", dtype=torch.float32)
-    triton_output = matmul(a, b, c, bias, P, locks, num_sms, block_m, block_n, block_k, group_m, num_warps, num_stages,
-                           waves_per_eu, mfmaInstrSize, kpack, use_bias)
+    triton_output = matmul(kernel_func, a, b, c, bias, P, locks, num_sms, block_m, block_n, block_k, group_m, num_warps,
+                           num_stages, waves_per_eu, mfmaInstrSize, kpack, use_bias)
     torch_output = torch.matmul(a_fp16, b_fp16)
     if use_bias:
         torch_output += bias_fp16[:, None]

@@ -435,7 +436,7 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
     size_str = ''
     if verbose:
         size_str = f'SIZE M: {M}, N: {N}, K: {K}, trans: {row_a_str}{row_b_str}'
-    print(f'{size_str} correctness check')
+    print(f'{kernel_func} {size_str} correctness check')
     torch.testing.assert_close(triton_output.to(torch.float16), torch_output, atol=atol, rtol=rtol)
     print(f'{size_str} Correct✅')
 

@@ -446,6 +447,8 @@ def parse_args():
         allow_abbrev=False,
     )
 
+    parser.add_argument('--kernel', default='streamk_kernel, streamk_gemm',
+                        help='can specify different kernel file name')
     parser.add_argument("-m", type=int, default=0)
     parser.add_argument("-n", type=int, default=0)
     parser.add_argument("-k", type=int, default=0)

@@ -486,11 +489,6 @@ def parse_args():
     parser.add_argument("--hack_triton_compiler", action='store_true', default=False,
                         help="Modify the triton source to avoid backend query")
     args = parser.parse_args()
-    if not args.o:
-        if args.benchmark:
-            args.o = "benchmarking_results.csv"
-        else:
-            args.o = get_default_tuning_result_filename()
 
     return args
 

@@ -542,6 +540,19 @@ def get_rocm_version():
 
 def main():
     args = parse_args()
+    # parse kernel file and kernel function name
+    module_name, kernel_name = args.kernel.split(',')
+    module_name = module_name.strip()
+    kernel_name = kernel_name.strip()
+    module = importlib.import_module(module_name)
+    kernel_func = getattr(module, kernel_name)
+
+    if not args.o:
+        if args.benchmark:
+            args.o = f"benchmarking_results_{kernel_name}.csv"
+        else:
+            args.o = get_default_tuning_result_filename(kernel_name)
+
     matrix_size_file = args.gemm_size_file
     output_file = args.o
     keepTmp = args.keep

@@ -613,7 +624,8 @@ def main():
         for (M, N, K, col_a, col_b, myConfig) in mnks:
             if myConfig is None:
                 raise Exception("kernel config is None, need to provide a tuning config")
-            test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, myConfig, bias_vector, True)
+            test_correctness(kernel_func, M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, myConfig,
+                             bias_vector, True)
         return
 
     configs_full = get_full_tuning_space()

@@ -658,7 +670,7 @@ def main():
        configs += delta_configs
 
     ## Append new configs into the tuning space
-    generate_matmul_kernels(delta_configs)
+    generate_matmul_kernels(delta_configs, module_name, kernel_name)
 
     row_a_str = 'N' if col_a else 'T'
     row_b_str = 'N' if col_b else 'T'

@@ -679,8 +691,9 @@ def main():
         bias_size = M if bias_vector else 0
         minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(
             M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, pruned_configs, run_bench,
-            jobs, iters, skipWarmup, num_threads=args.num_threads, gpus=gpus, verbose=verbose_level,
-            rotating_buffer_size=rotating_buffer_size, bias_size=bias_size, icache_flush=icache_flush)
+            jobs, iters, skipWarmup, module_name, kernel_name, num_threads=args.num_threads, gpus=gpus,
+            verbose=verbose_level, rotating_buffer_size=rotating_buffer_size, bias_size=bias_size,
+            icache_flush=icache_flush)
 
         # post processing the numbers
         perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)

@@ -701,9 +714,9 @@ def main():
 
         sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str}
         sizeDict.update(bestConfig)
+        sizeDict.update({'TFLOPS': formatted_tflops, 'time(us)': minTime})
         if not run_bench:
             f_results.write("- " + str(sizeDict) + " ")
-            f_results.write(f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n')
 
     # remove generated files if asked to
     if not keepTmp:
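
Two behavioral changes in this file are worth calling out: the default output file names now embed the selected kernel name, and the final TFLOPS and time are merged into the YAML dict of each result line instead of being appended as a trailing `#` comment. A rough sketch of the new result-line shape, with hypothetical values, an assumed config key, and an illustrative file name, just to show the format:

```
# Illustrative only -- the numbers, the config key and the file name below are
# hypothetical placeholders, not values from this commit.
bestConfig = {'BLOCK_SIZE_M': 256}                 # stand-in for the winning tuning config
formatted_tflops, minTime = 500.0, 123.4           # stand-in performance numbers

sizeDict = {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}
sizeDict.update(bestConfig)
sizeDict.update({'TFLOPS': formatted_tflops, 'time(us)': minTime})

with open('tuning_results_streamk_gemm.yaml', 'a') as f_results:
    # one flat YAML flow mapping per result line, perf numbers included
    f_results.write("- " + str(sizeDict) + " ")
```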
Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+# If new gemm sizes/configs are added please provide a fallback tuned config in fallback_configs.yaml.
+# The fallback configs will be used in case the weekly tuning has not been run
+#
+- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4288, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+
+- {'M': 4864, 'N': 4096, 'K': 4097, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4098, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4100, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4104, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4112, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+
+- {'M': 4864, 'N': 8192, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 8192, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 8192, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 8192, 'K': 8256, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+
+- {'M': 1024, 'N': 1024, 'K': 1024, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 1024, 'N': 1024, 'K': 1024, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 8192, 'N': 8192, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 8192, 'N': 8192, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+
+- {'M': 4864, 'N': 4096, 'K': 8256, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 8256, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+
+- {'M': 1024, 'N': 8192, 'K': 28672, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 1024, 'N': 8192, 'K': 28672, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 1024, 'N': 28672, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 1024, 'N': 28672, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 1024, 'N': 14336, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 1024, 'N': 14336, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 1, 'N': 8192, 'K': 28672, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 1, 'N': 8192, 'K': 28672, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 1, 'N': 14336, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 1, 'N': 14336, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 1024, 'N': 16384, 'K': 53248, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 1024, 'N': 16384, 'K': 53248, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 1024, 'N': 53248, 'K': 16384, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 1024, 'N': 53248, 'K': 16384, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 32, 'N': 16384, 'K': 53248, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 32, 'N': 16384, 'K': 53248, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 32, 'N': 53248, 'K': 16384, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 32, 'N': 53248, 'K': 16384, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+
+- {'M': 2, 'N': 3584, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 2, 'N': 3584, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 2, 'N': 4096, 'K': 1792, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 2, 'N': 4096, 'K': 1792, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 4096, 'N': 13312, 'K': 8896, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4096, 'N': 13312, 'K': 8896, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 2048, 'N': 17792, 'K': 13312, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 2048, 'N': 17792, 'K': 13312, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 1024, 'N': 13312, 'K': 1664, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 1024, 'N': 13312, 'K': 1664, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+
+- {'M': 8192, 'N': 1536, 'K': 5120, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 8192, 'N': 1536, 'K': 5120, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 8192, 'N': 5120, 'K': 1024, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 8192, 'N': 5120, 'K': 1024, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 8192, 'N': 1024, 'K': 5120, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 8192, 'N': 1024, 'K': 5120, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 32768, 'N': 5120, 'K': 512, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 32768, 'N': 5120, 'K': 512, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 512, 'N': 1536, 'K': 5120, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 512, 'N': 1536, 'K': 5120, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 512, 'N': 5120, 'K': 1024, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 512, 'N': 5120, 'K': 1024, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 512, 'N': 1024, 'K': 5120, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 512, 'N': 1024, 'K': 5120, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 2048, 'N': 5120, 'K': 512, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 2048, 'N': 5120, 'K': 512, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+
+- {'M': 20196, 'N': 512, 'K': 1536, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 20196, 'N': 512, 'K': 1536, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 171792, 'N': 512, 'K': 1536, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 171792, 'N': 512, 'K': 1536, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 173318, 'N': 512, 'K': 1536, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 173318, 'N': 512, 'K': 1536, 'rowMajorA': 'T', 'rowMajorB': 'T'}
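
Each entry above carries the same size/layout fields (M, N, K, rowMajorA, rowMajorB) that tune_streamk.py writes into its result lines and reads back from a `--gemm_size_file`. A small sketch (not part of the commit) of loading one such row and recovering the values the scripts work with, inverting the `row_a_str = 'N' if col_a else 'T'` convention seen in tune_streamk.py:

```
# Sketch (not from the commit): parse one shape entry and recover the
# (M, N, K, col_a, col_b) tuple, inverting "row_a_str = 'N' if col_a else 'T'".
import yaml

entry = yaml.safe_load("- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}")[0]
M, N, K = entry['M'], entry['N'], entry['K']
col_a = entry['rowMajorA'] == 'N'   # 'T' means A is taken as row-major under this (assumed) convention
col_b = entry['rowMajorB'] == 'N'
print(M, N, K, col_a, col_b)        # 4864 4096 4096 False True
```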
