
Commit 2ed2c59

Adding Stream-K persistent matmul kernel (#309)
Differential Revision: D77983954
Pull Request resolved: #324
1 parent 66b557c commit 2ed2c59

2 files changed: 281 additions & 11 deletions

tritonbench/operators/gemm/operator.py

Lines changed: 14 additions & 7 deletions
@@ -10,7 +10,7 @@
 
 from tritonbench.operators.gemm.kernels import matmul as kernels
 from tritonbench.operators.gemm.partition_k import matmul_partition_k
-from tritonbench.operators.gemm.stream_k import streamk_matmul
+from tritonbench.operators.gemm.stream_k import streamk_amd_matmul, streamk_cuda_matmul
 from tritonbench.operators.gemm.warp_spec_persistent_matmul import (
     blackwell_matmul_descriptor_persistent,
     blackwell_matmul_tma,
@@ -311,12 +311,19 @@ def pt2_matmul_maxautotune(self, a, b, bias) -> Callable:
 
         return lambda: compiled(a, b)
 
-    @register_benchmark()
+    @register_benchmark(enabled=not is_cuda())
     def streamk_matmul(self, a, b, bias) -> Callable:
-        if bias is not None:
-            return lambda: streamk_matmul(a, b, bias)
-        else:
-            return lambda: streamk_matmul(a, b)
+        return lambda: streamk_amd_matmul(a, b, bias) if bias else streamk_amd_matmul(a, b)
+
+    @register_benchmark(enabled=is_cuda())
+    def streamk_matmul(self, a, b, bias) -> Callable:
+        print(f"Testing shape: {a.shape} x {b.shape}...")
+        streamk = torch.matmul(a, b)
+        b = b.T.contiguous()
+        baseline = streamk_cuda_matmul(a, b)
+        if not torch.allclose(streamk, baseline):
+            print(f"StreamK matmul on {a.shape} x {b.shape} result does not match baseline matmul result. Max abs(streamk/baseline - 1): {torch.max(torch.abs(streamk / baseline - 1))}")
+        return lambda: streamk_cuda_matmul(a, b) + bias if bias else streamk_cuda_matmul(a, b)
 
     @register_benchmark(enabled=is_cuda())
     def pt2_cutlass_matmul(self, a, b, bias) -> Callable:
@@ -335,7 +342,7 @@ def pt2_cutlass_matmul(self, a, b, bias) -> Callable:
             compiled(a, b)
         return lambda: compiled(a, b)
 
-    @register_benchmark()
+    @register_benchmark(enabled=False)
     def matmul_decompose_k(self, a, b, bias) -> Callable:
         def decompose_func(a_in, b_in):
             M, K = a_in.shape
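The new CUDA-only streamk_matmul benchmark above transposes b before calling streamk_cuda_matmul, because the CUDA Stream-K kernel reads B in (N, K) layout and transposes each tile on load (tl.dot(a, b.T, ...) in stream_k.py below). A minimal standalone sketch of that calling convention; the shapes and tolerances are illustrative only and are not part of this commit:

import torch

from tritonbench.operators.gemm.stream_k import streamk_cuda_matmul

# Hypothetical sizes, for illustration only.
M, K, N = 4096, 4096, 4096
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)

# streamk_cuda_matmul expects B with shape (N, K), so hand it b.T,
# mirroring the b = b.T.contiguous() step in the benchmark above.
c = streamk_cuda_matmul(a, b.T.contiguous())

# Loose tolerances, since the kernel accumulates in fp32 but stores fp16/fp8.
torch.testing.assert_close(c, a @ b, rtol=2e-2, atol=2e-2)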

tritonbench/operators/gemm/stream_k.py

Lines changed: 267 additions & 4 deletions
@@ -12,11 +12,24 @@
 
 import triton
 import triton.language as tl
+from triton.tools.tensor_descriptor import TensorDescriptor
 
-from tritonbench.utils.env_utils import is_hip_mi300
+from tritonbench.utils.env_utils import is_cuda, is_fbcode, is_hip_mi300
 
 from .triton_matmul_configs import get_full_amd_config_space
 
+if not is_fbcode():
+    if is_cuda():
+        from triton._C.libtriton import nvidia
+
+        cublas_workspace = torch.empty(
+            32 * 1024 * 1024, device="cuda", dtype=torch.uint8
+        )
+        cublas = nvidia.cublas.CublasLt(cublas_workspace)
+    else:
+        cublas = None
+
+
 if os.environ.get("FULL_AUTOTUNING_AMD", "0") == "1" and torch.version.hip is not None:
     tuning_configs = get_full_amd_config_space(False)
 else:
@@ -56,7 +69,7 @@
     }
 )
 @triton.jit
-def streamk_gemm(
+def streamk_amd_gemm(
     A,
     B,
     C,
@@ -274,7 +287,7 @@ def streamk_gemm(
         start_iter = end_iter
 
 
-def streamk_matmul(a, b, bias=None):
+def streamk_amd_matmul(a, b, bias=None):
     M, K = a.shape
     _, N = b.shape
     dtype = a.dtype
@@ -350,7 +363,7 @@ def streamk_matmul(a, b, bias=None):
        and c.stride(0) >= 0
        and c.stride(1) >= 0
    )
-    streamk_gemm[(grids,)](
+    streamk_amd_gemm[(grids,)](
        a,
        b,
        c,
@@ -376,3 +389,253 @@ def streamk_matmul(a, b, bias=None):
     # print(c)
     # print(a @ b)
     return c
+
+
+def _matmul_launch_metadata(grid, kernel, args):
+    ret = {}
+    M, N, K = args["M"], args["N"], args["K"]
+    ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
+    ret["flops8"] = 2.0 * M * N * K
+    if "c_ptr" in args:
+        bytes_per_elem = args["c_ptr"].element_size()
+    else:
+        bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
+    ret["bytes"] = bytes_per_elem * (M * K + N * K)
+    return ret
+
+
+def matmul_get_configs(pre_hook=None):
+    return [
+        triton.Config(
+            {"BLOCK_M": BM, "BLOCK_N": BN, "BLOCK_K": BK, "SK_BLOCK_K": skBK, "GROUP_M": 8},
+            num_stages=s,
+            num_warps=w,
+            pre_hook=pre_hook,
+        )  #
+        for BM in [128, 256]  #
+        for BN in [128, 256]  #
+        for BK in [32, 64, 128]  #
+        for skBK in [16, 32, 64, 128]  #
+        for s in ([2, 3, 4])  #
+        for w in [4, 8]  #
+    ]
+
+
+def matmul_tma_set_block_size_hook(nargs):
+    BLOCK_M = nargs["BLOCK_M"]
+    BLOCK_N = nargs["BLOCK_N"]
+    BLOCK_K = nargs["BLOCK_K"]
+    nargs["a_desc"].block_shape = [BLOCK_M, BLOCK_K]
+    nargs["b_desc"].block_shape = [BLOCK_N, BLOCK_K]
+    nargs["c_desc"].block_shape = [BLOCK_M, BLOCK_N]
+
+    SK_BLOCK_K = nargs["SK_BLOCK_K"]
+    nargs["a_desc_sk"].block_shape = [BLOCK_M, SK_BLOCK_K]
+    nargs["b_desc_sk"].block_shape = [BLOCK_N, SK_BLOCK_K]
+
+
+@triton.autotune(
+    configs=matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook),
+    key=["M", "N", "K"],
+)
+@triton.jit(launch_metadata=_matmul_launch_metadata)
+def streamk_cuda_gemm(
+    # Pointer to a [BLOCK_M, BLOCK_K] TensorDescriptor
+    a_desc,
+    # Pointer to b [BLOCK_N, BLOCK_K] TensorDescriptor
+    b_desc,
+    # Pointer to a [BLOCK_M, SK_BLOCK_K] TensorDescriptor
+    a_desc_sk,
+    # Pointer to b [BLOCK_N, SK_BLOCK_K] TensorDescriptor
+    b_desc_sk,
+    # Pointer to c [BLOCK_M, BLOCK_N] TensorDescriptor
+    c_desc,
+    #
+    M,
+    N,
+    K,
+    # Tile dimensions for both phases
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    # K block dimension for DDP phase
+    BLOCK_K: tl.constexpr,
+    # K block dimension for Stream-K phase
+    SK_BLOCK_K: tl.constexpr,
+    # Group size for both phases
+    GROUP_M: tl.constexpr,
+    # TRUE if lowering for FP8 output
+    FP8_OUTPUT: tl.constexpr,
+    #
+    ENABLE_BUFFER_OPS_ASSUMES: tl.constexpr,
+    # Number of SMs on the device
+    NUM_SMS: tl.constexpr,
+):
+    if ENABLE_BUFFER_OPS_ASSUMES:
+        tl.assume(M >= 0)
+        tl.assume(N >= 0)
+        tl.assume(K >= 0)
+
+    dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
+
+    pid = tl.program_id(0)
+    num_pid = tl.num_programs(0)
+    num_tile_m = tl.cdiv(M, BLOCK_M)
+    num_tile_n = tl.cdiv(N, BLOCK_N)
+    num_tile_in_group = GROUP_M * num_tile_n
+
+    total_tiles = num_tile_m * num_tile_n
+
+    # number of full waves
+    W = total_tiles // NUM_SMS
+    # number of tiles in the partial wave
+    R = total_tiles % NUM_SMS
+    if W == 0 or R == 0:
+        total_ddp_tiles = num_pid
+        streamk_sms = 0
+    else:
+        # hybrid Stream-K + DDP: DDP on first W-1 waves, Stream-K on last wave with full SM occupancy
+        total_ddp_tiles = num_pid - NUM_SMS
+        streamk_sms = NUM_SMS
+
+    # ----------------------------------------------------------------------------
+    # DDP phase
+    # ----------------------------------------------------------------------------
+    if pid < total_ddp_tiles:
+        # Each DDP-assigned program computes 1 full tile
+        group_id = pid // num_tile_in_group
+        first_tile_m = group_id * GROUP_M
+        group_size_m = min(num_tile_m - first_tile_m, GROUP_M)
+        tile_m = first_tile_m + (pid % group_size_m)
+        tile_n = (pid % num_tile_in_group) // group_size_m
+
+        offs_am = tile_m * BLOCK_M
+        offs_bn = tile_n * BLOCK_N
+
+        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+        work_units_per_tile = tl.cdiv(K, BLOCK_K)
+
+        for k in tl.range(0, work_units_per_tile, warp_specialize=True):
+            offs_k = k * BLOCK_K
+            a = a_desc.load([offs_am, offs_k])
+            b = b_desc.load([offs_bn, offs_k])
+            accumulator = tl.dot(a, b.T, accumulator)
+
+        c = accumulator.to(dtype)
+        c_desc.store([offs_am, offs_bn], c)
+
+    # ----------------------------------------------------------------------------
+    # Stream-K phase
+    # ----------------------------------------------------------------------------
+    else:
+        # index each Stream-K program as if it were a single SM (num_pid - total_ddp_tiles = streamk_sms)
+        worker_id = pid - total_ddp_tiles
+
+        work_units_per_tile = tl.cdiv(K, SK_BLOCK_K)
+        total_work_units = (total_tiles - total_ddp_tiles) * work_units_per_tile
+
+        # `evenly` distribute work units across SMs, with rem tiles assigned contiguously to the first rem programs
+        base = total_work_units // streamk_sms
+        rem = total_work_units % streamk_sms
+        work = tl.where(worker_id < rem, base + 1, base)
+        start = tl.where(
+            worker_id < rem,
+            worker_id * (base + 1),
+            rem * (base + 1) + (worker_id - rem) * base,
+        )
+        end = start + work - 1
+
+        # if start >= total_work_units, nothing to do
+        if start >= total_work_units:
+            return
+
+        # this program is responsible for computing tiles [(st_tile_streamk, st_k_streamk), (en_tile_streamk, en_k_streamk)]
+        # *_k_streamk indexes along the K dimension and is one of {0, 1, ..., work_units_per_tile - 1}
+        st_tile_streamk = start // work_units_per_tile + total_ddp_tiles
+        st_k_streamk = start % work_units_per_tile
+        en_tile_streamk = end // work_units_per_tile + total_ddp_tiles
+        en_k_streamk = end % work_units_per_tile
+
+        for curr_tile in tl.range(st_tile_streamk, en_tile_streamk + 1, flatten=True):
+            # Compute the tile associated with this work unit --- consistent with the DDP phase
+            group_id = curr_tile // num_tile_in_group
+            first_tile_m = group_id * GROUP_M
+            group_size_m = min(num_tile_m - first_tile_m, GROUP_M)
+            tile_m = first_tile_m + (curr_tile % group_size_m)
+            tile_n = (curr_tile % num_tile_in_group) // group_size_m
+
+            offs_am = tile_m * BLOCK_M
+            offs_bn = tile_n * BLOCK_N
+
+            # compute the start and end K index on this tile for this work unit
+            curr_st_k = tl.where(curr_tile == st_tile_streamk, st_k_streamk, 0)
+            curr_en_k = tl.where(curr_tile == en_tile_streamk, en_k_streamk, work_units_per_tile - 1)
+
+            accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+            for k in tl.range(curr_st_k, curr_en_k + 1, warp_specialize=True):
+                offs_k = k * SK_BLOCK_K
+                # if same TensorDescriptor shape is used for both phases, just use DDP's (better performance)
+                if BLOCK_K == SK_BLOCK_K:
+                    a = a_desc.load([offs_am, offs_k])
+                    b = b_desc.load([offs_bn, offs_k])
+                else:
+                    a = a_desc_sk.load([offs_am, offs_k])
+                    b = b_desc_sk.load([offs_bn, offs_k])
+                accumulator = tl.dot(a, b.T, accumulator)
+
+            c = accumulator.to(dtype)
+
+            if curr_st_k == 0 and curr_en_k == work_units_per_tile - 1:
+                c_desc.store([offs_am, offs_bn], c)
+            else:
+                # NOTE: known correctness issue with atomic_add
+                c_desc.atomic_add([offs_am, offs_bn], c)
+
+
+def streamk_cuda_matmul(a, b):
+    assert a.dtype == b.dtype, "Incompatible dtypes"
+
+    M, K = a.shape
+    N, K = b.shape
+    dtype = a.dtype
+
+    c = torch.zeros((M, N), device=a.device, dtype=dtype)
+
+    dummy_block = [1, 1]
+    a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
+    b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
+    c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)
+
+    a_desc_sk = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
+    b_desc_sk = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
+
+    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
+
+    def grid(META):
+        nonlocal a_desc, b_desc, c_desc
+        BLOCK_M = META["BLOCK_M"]
+        BLOCK_N = META["BLOCK_N"]
+        num_tiles = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
+        W = num_tiles // num_sms
+        R = num_tiles % num_sms
+        if W == 0 or R == 0:
+            total_ddp_tiles = num_tiles
+            streamk_sms = 0
+        else:
+            total_ddp_tiles = (W - 1) * num_sms
+            streamk_sms = num_sms
+        return (total_ddp_tiles + streamk_sms,)
+
+    streamk_cuda_gemm[grid](
+        a_desc,
+        b_desc,
+        a_desc_sk,
+        b_desc_sk,
+        c_desc,  #
+        M,
+        N,
+        K,  #
+        FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
+        ENABLE_BUFFER_OPS_ASSUMES=True,  #
+        NUM_SMS=num_sms,  #
+    )
+    return c
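The launch grid and the kernel split work the same way: full waves of tiles go to the data-parallel (DDP) phase, and the last full wave plus the leftover partial wave is redistributed as K-chunks across all SMs in the Stream-K phase. A small host-side sketch of that split, written only to illustrate the arithmetic; the function name and example sizes below are not part of the commit:

import math

def split_tiles(M, N, BLOCK_M, BLOCK_N, num_sms):
    # Mirrors the grid() helper and the W/R logic inside streamk_cuda_gemm.
    num_tiles = math.ceil(M / BLOCK_M) * math.ceil(N / BLOCK_N)
    W, R = divmod(num_tiles, num_sms)  # full waves, tiles left in the partial wave
    if W == 0 or R == 0:
        # Tiny or perfectly balanced problem: every tile is handled data-parallel.
        return num_tiles, 0
    # Otherwise DDP covers the first W - 1 full waves; the remaining
    # R + num_sms tiles are split into K-chunks across num_sms Stream-K programs.
    return (W - 1) * num_sms, num_sms

# Example: 4096x4096 output with 128x128 tiles on a 132-SM GPU (illustrative numbers).
ddp_tiles, streamk_programs = split_tiles(4096, 4096, 128, 128, 132)
print(ddp_tiles, streamk_programs)  # 792 132 -> grid size 924; Stream-K owns 232 tiles

Each Stream-K program then takes a contiguous run of work units: full tiles are written with a plain store, partial tiles with atomic_add, as in the kernel above.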
