Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions python/perf-kernels/streamk/03-matrix-multiplication-stream-k.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,25 @@
torch.manual_seed(123)
random.seed(123)

total_sm = 304

def is_hip_cdna3():
    """Return True when the active Triton target is an AMD CDNA3 GPU (HIP backend, gfx942)."""
    tgt = triton.runtime.driver.active.get_current_target()
    return (tgt.backend, tgt.arch) == ('hip', 'gfx942')


def is_hip_cdna4():
    """Return True when the active Triton target is an AMD CDNA4 GPU (HIP backend, gfx950)."""
    tgt = triton.runtime.driver.active.get_current_target()
    return (tgt.backend, tgt.arch) == ('hip', 'gfx950')


# Pick the number of compute units for the detected GPU architecture.
# NOTE(review): these are per-architecture maximums; gfx942 parts with fewer
# CUs (e.g. the MI308 80/64-CU variants mentioned in review) would get the
# wrong value — TODO: query the actual CU count from the device instead.
if is_hip_cdna3():
    total_sm = 304
elif is_hip_cdna4():
    total_sm = 256
else:
    # Unsupported target: exit non-zero so callers/CI see the failure
    # (exit(0) here would silently report success).
    print("Unknown target")
    exit(1)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how do we deal with MI308 80/64 CUs ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idk. There must be a way (hopefully a torch API) to query the number of cus from the GPU. Let me find it.

print(f"total SMs: {total_sm}")


Expand Down Expand Up @@ -138,7 +156,8 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch.

## test for tile counts that are not a multiple of 304 tiles
#m, n, k = 4096, 4096, 8192 # some problem size to test
m, n, k = 8192, 8192, 8192 # some problem size to test
m, n, k = 8192, 8192, 512 # some problem size to test
#m, n, k = 8704, 8704, 8192 # some problem size to test
#m, n, k = 512, 512, 512 # some problem size to test

## memory bound sizes
Expand Down Expand Up @@ -171,7 +190,14 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch.
num_warps = 8
waves_per_eu = 0
mfmaInstrSize = 16
kpack = 2

# Per-architecture kernel tuning: kpack differs between CDNA3 and CDNA4.
if is_hip_cdna3():
    kpack = 2
elif is_hip_cdna4():
    kpack = 1
else:
    # Unsupported target: exit non-zero so callers/CI see the failure
    # (exit(0) here would silently report success).
    print("Unknown target")
    exit(1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same questions for MI308


##for total_sm in range(1, 305):
## print(f"{total_sm=}")
Expand All @@ -195,7 +221,7 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch.
P = torch.zeros((total_sm, BLK_M * BLK_N), device="cuda", dtype=torch.float32)
C = matmul.apply(A, B, C, bias, P, locks, total_sm, BLK_M, BLK_N, BLK_K, gsize_m, two_tiles, num_stages, num_warps,
waves_per_eu, mfmaInstrSize, kpack)
#exit(0)
exit(0)
matmul.set_debug(False)
expected = A @ B

Expand Down
9 changes: 7 additions & 2 deletions python/perf-kernels/streamk/streamk_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def streamk_gemm(

acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32

for tile_id in range(pid, total_full_tiles, NUM_SMS):
for tile_id in tl.range(pid, total_full_tiles, NUM_SMS, flatten=True):
Copy link
Member

@xiaohuguo2023 xiaohuguo2023 Apr 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have you observed any perf improvement by enable this loop fusion ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, this is just an experiment and it doesn't compile. I filed a ticket for it https://github.com/ROCm/triton-internal/issues/784

num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
Expand All @@ -74,8 +74,10 @@ def streamk_gemm(
if not EVEN_K:
loop_k -= 1

tl.assume(loop_k > 1)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch!

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
for k in range(0, loop_k):
for k in tl.range(0, loop_k):
a = tl.load(tl.multiple_of(A_BASE, (1, 16)))
b = tl.load(tl.multiple_of(B_BASE, (16, 1)))
acc += tl.dot(a, b)
Expand Down Expand Up @@ -105,6 +107,9 @@ def streamk_gemm(
mask = (rm < M)[:, None] & (rn < N)[None, :]
tl.store(C_, c, mask=mask)

if STREAMK_TILES == 0:
return

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in fact, I was trying to use STREAK_TILES to avoid this "if" with line 113 and 115, otherwise, we could use full tiles and total tiles for this if ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory, when STREAM_TILES==0, start_iter == last_iter and we should not need the while loop. However, they are not constant, so compiler cannot optimize them out. I found this if a very nice way to early return :)

tl.assume(pid >= 0)
total_streamk_iters = STREAMK_TILES * iters_per_tile
streamk_iters_pcu = total_streamk_iters // NUM_SMS
Expand Down