diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k.py index 6a0da5c97d9f..769833a88daa 100644 --- a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k.py +++ b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k.py @@ -9,7 +9,25 @@ torch.manual_seed(123) random.seed(123) -total_sm = 304 + +def is_hip_cdna3(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == 'hip' and target.arch == 'gfx942' + + +def is_hip_cdna4(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == 'hip' and target.arch == 'gfx950' + + +if is_hip_cdna3(): + total_sm = 304 +elif is_hip_cdna4(): + total_sm = 256 +else: + print("Unknown target") + exit(0) + print(f"total SMs: {total_sm}") @@ -138,7 +156,8 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch. ## test for tiles that is not multipe of 304 tiles #m, n, k = 4096, 4096, 8192 # some problem size to test -m, n, k = 8192, 8192, 8192 # some problem size to test +m, n, k = 8192, 8192, 512 # some problem size to test +#m, n, k = 8704, 8704, 8192 # some problem size to test #m, n, k = 512, 512, 512 # some problem size to test ## memory bound sizes @@ -171,7 +190,14 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch. num_warps = 8 waves_per_eu = 0 mfmaInstrSize = 16 -kpack = 2 + +if is_hip_cdna3(): + kpack = 2 +elif is_hip_cdna4(): + kpack = 1 +else: + print("Unknown target") + exit(0) ##for total_sm in range(1, 305): ## print(f"{total_sm=}") @@ -195,7 +221,7 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch. P = torch.zeros((total_sm, BLK_M * BLK_N), device="cuda", dtype=torch.float32) C = matmul.apply(A, B, C, bias, P, locks, total_sm, BLK_M, BLK_N, BLK_K, gsize_m, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack) -#exit(0) +exit(0) matmul.set_debug(False) expected = A @ B diff --git a/python/perf-kernels/streamk/streamk_kernel.py b/python/perf-kernels/streamk/streamk_kernel.py index 196ba6936308..a66df84c9711 100644 --- a/python/perf-kernels/streamk/streamk_kernel.py +++ b/python/perf-kernels/streamk/streamk_kernel.py @@ -48,7 +48,7 @@ def streamk_gemm( acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 - for tile_id in range(pid, total_full_tiles, NUM_SMS): + for tile_id in tl.range(pid, total_full_tiles, NUM_SMS, flatten=True): num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = tile_id // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M @@ -74,8 +74,10 @@ def streamk_gemm( if not EVEN_K: loop_k -= 1 + tl.assume(loop_k > 1) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - for k in range(0, loop_k): + for k in tl.range(0, loop_k): a = tl.load(tl.multiple_of(A_BASE, (1, 16))) b = tl.load(tl.multiple_of(B_BASE, (16, 1))) acc += tl.dot(a, b) @@ -105,6 +107,9 @@ def streamk_gemm( mask = (rm < M)[:, None] & (rn < N)[None, :] tl.store(C_, c, mask=mask) + if STREAMK_TILES == 0: + return + tl.assume(pid >= 0) total_streamk_iters = STREAMK_TILES * iters_per_tile streamk_iters_pcu = total_streamk_iters // NUM_SMS