Skip to content

Commit 0097de3

Browse files
committed
Set SM and kpack according to arch and early return if no streamK tiles
1 parent 1468103 commit 0097de3

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

python/perf-kernels/streamk/03-matrix-multiplication-stream-k.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,23 @@
99
torch.manual_seed(123)
1010
random.seed(123)
1111

12-
total_sm = 304
12+
13+
def is_hip_cdna3():
14+
target = triton.runtime.driver.active.get_current_target()
15+
return target.backend == 'hip' and target.arch == 'gfx942'
16+
17+
def is_hip_cdna4():
18+
target = triton.runtime.driver.active.get_current_target()
19+
return target.backend == 'hip' and target.arch == 'gfx950'
20+
21+
if is_hip_cdna3():
22+
total_sm = 304
23+
elif is_hip_cdna4():
24+
total_sm = 256
25+
else:
26+
print("Unknown target")
27+
exit(0)
28+
1329
print(f"total SMs: {total_sm}")
1430

1531

@@ -171,7 +187,14 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch.
171187
num_warps = 8
172188
waves_per_eu = 0
173189
mfmaInstrSize = 16
174-
kpack = 2
190+
191+
if is_hip_cdna3():
192+
kpack = 2
193+
elif is_hip_cdna4():
194+
kpack = 1
195+
else:
196+
print("Unknown target")
197+
exit(0)
175198

176199
##for total_sm in range(1, 305):
177200
## print(f"{total_sm=}")

python/perf-kernels/streamk/streamk_kernel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ def streamk_gemm(
105105
mask = (rm < M)[:, None] & (rn < N)[None, :]
106106
tl.store(C_, c, mask=mask)
107107

108+
if STREAMK_TILES == 0:
109+
return
110+
108111
tl.assume(pid >= 0)
109112
total_streamk_iters = STREAMK_TILES * iters_per_tile
110113
streamk_iters_pcu = total_streamk_iters // NUM_SMS

0 commit comments

Comments
 (0)