Skip to content

Commit c243ada

Browse files
committed
Set SM and kpack according to arch and early return if no streamK tiles
1 parent 1468103 commit c243ada

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

python/perf-kernels/streamk/03-matrix-multiplication-stream-k.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,25 @@
99
torch.manual_seed(123)
1010
random.seed(123)
1111

12-
total_sm = 304
12+
13+
def is_hip_cdna3():
14+
target = triton.runtime.driver.active.get_current_target()
15+
return target.backend == 'hip' and target.arch == 'gfx942'
16+
17+
18+
def is_hip_cdna4():
19+
target = triton.runtime.driver.active.get_current_target()
20+
return target.backend == 'hip' and target.arch == 'gfx950'
21+
22+
23+
if is_hip_cdna3():
24+
total_sm = 304
25+
elif is_hip_cdna4():
26+
total_sm = 256
27+
else:
28+
print("Unknown target")
29+
exit(0)
30+
1331
print(f"total SMs: {total_sm}")
1432

1533

@@ -171,7 +189,14 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch.
171189
num_warps = 8
172190
waves_per_eu = 0
173191
mfmaInstrSize = 16
174-
kpack = 2
192+
193+
if is_hip_cdna3():
194+
kpack = 2
195+
elif is_hip_cdna4():
196+
kpack = 1
197+
else:
198+
print("Unknown target")
199+
exit(0)
175200

176201
##for total_sm in range(1, 305):
177202
## print(f"{total_sm=}")

python/perf-kernels/streamk/streamk_kernel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ def streamk_gemm(
105105
mask = (rm < M)[:, None] & (rn < N)[None, :]
106106
tl.store(C_, c, mask=mask)
107107

108+
if STREAMK_TILES == 0:
109+
return
110+
108111
tl.assume(pid >= 0)
109112
total_streamk_iters = STREAMK_TILES * iters_per_tile
110113
streamk_iters_pcu = total_streamk_iters // NUM_SMS

0 commit comments

Comments
 (0)