
Commit 83871ea

Clean up GEMM kernel (#730)
1) Triton now assumes TF32 by default; explicitly request F32 precision.
2) Change the text "rocblas" to "hipblasLT".
3) Make the hipblaslt GEMMs TN to be consistent with Triton.
4) Set the float8 dtypes correctly for gfx950.
5) Remove a couple of unused configs that would never make sense.
1 parent 2a90b5b commit 83871ea

File tree

1 file changed: +18 -16 lines changed


python/perf-kernels/gemm.py

Lines changed: 18 additions & 16 deletions
@@ -27,15 +27,9 @@
         triton.Config(
             {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'waves_per_eu': 0},
             num_warps=8, num_stages=2),
-        triton.Config(
-            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
-            num_warps=4, num_stages=2),
         triton.Config(
             {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
             num_warps=8, num_stages=2),
-        triton.Config(
-            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 32, 'waves_per_eu': 2},
-            num_warps=4, num_stages=2),
     ],
     key=['M', 'N', 'K'],
     use_cuda_graph=True,
@@ -122,7 +116,7 @@ def matmul_kernel(
         else:
             a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
             b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
-        accumulator += tl.dot(a, b)
+        accumulator += tl.dot(a, b, input_precision="ieee")

         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
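
The switch to input_precision="ieee" is what item 1 of the commit message refers to: newer Triton releases may lower fp32 tl.dot operands to TF32 by default where the hardware supports it, so full fp32 precision has to be requested explicitly. A minimal, self-contained sketch of that knob (not part of gemm.py; the kernel name, sizes, and tolerances are made up for illustration, and it assumes a GPU with a working Triton install):

import torch
import triton
import triton.language as tl


@triton.jit
def single_tile_dot(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Load one BLOCK x BLOCK row-major tile of each operand.
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    # "ieee" requests full fp32 math instead of the TF32 default for fp32 inputs.
    c = tl.dot(a, b, input_precision="ieee")
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)


if __name__ == "__main__":
    BLOCK = 32
    a = torch.randn((BLOCK, BLOCK), dtype=torch.float32, device="cuda")
    b = torch.randn((BLOCK, BLOCK), dtype=torch.float32, device="cuda")
    c = torch.empty_like(a)
    single_tile_dot[(1, )](a, b, c, BLOCK=BLOCK)
    torch.testing.assert_close(c, a @ b, atol=1e-4, rtol=1e-4)
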
@@ -179,29 +173,36 @@ def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
     )


+def is_cdna4():
+    return triton.runtime.driver.active.get_current_target().arch == 'gfx950'
+
+
+e5m2_type = torch.float8_e5m2 if is_cdna4() else torch.float8_e5m2fnuz
+e4m3_type = torch.float8_e4m3fn if is_cdna4() else torch.float8_e4m3fnuz
+
 name_to_torch_types = {
     'int8': torch.int8,
     'int32': torch.int32,
     'fp16': torch.float16,
     'fp32': torch.float32,
     'bf16': torch.bfloat16,
-    'fp8e5': torch.float8_e5m2fnuz,
-    'fp8e4': torch.float8_e4m3fnuz,
+    'fp8e5': e5m2_type,
+    'fp8e4': e4m3_type,
 }

 dtype_max = {
     dtype: (torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)).max
     for dtype in [
-        torch.float8_e5m2fnuz,
-        torch.float8_e4m3fnuz,
+        e5m2_type,
+        e4m3_type,
         torch.int8,
     ]
 }


 def dtype_is_8_bit(dtype):
-    return (dtype is torch.float8_e5m2fnuz) or \
-           (dtype is torch.float8_e4m3fnuz) or \
+    return (dtype is e5m2_type) or \
+           (dtype is e4m3_type) or \
            (dtype is torch.int8)

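
The is_cdna4() switch above selects the OCP float8 formats (torch.float8_e5m2 / torch.float8_e4m3fn) on gfx950 and the *_fnuz variants on earlier CDNA parts; the two e4m3 flavours have different dynamic ranges, which is why dtype_max is built from the selected types. A small sketch of the same idea (the helper name pick_fp8_types is invented for illustration; it assumes a ROCm build of PyTorch and Triton):

import torch
import triton


def pick_fp8_types():
    # Same arch query the kernel file uses; gfx950 is CDNA4.
    arch = triton.runtime.driver.active.get_current_target().arch
    if arch == 'gfx950':
        return torch.float8_e5m2, torch.float8_e4m3fn
    return torch.float8_e5m2fnuz, torch.float8_e4m3fnuz


if __name__ == "__main__":
    e5m2_type, e4m3_type = pick_fp8_types()
    # The representable maxima differ: e4m3fn tops out at 448, e4m3fnuz at 240.
    print(e5m2_type, torch.finfo(e5m2_type).max)
    print(e4m3_type, torch.finfo(e4m3_type).max)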

@@ -278,7 +279,7 @@ def get_type(provider):
         x_vals=get_x_vals(),
         line_arg='provider',
         line_vals=[
-            'rocblas(fp16)', 'rocblas(bf16)', 'triton(fp16)', 'triton(bf16)', 'triton(int8)', 'triton(fp8e4)',
+            'hipblaslt(fp16)', 'hipblaslt(bf16)', 'triton(fp16)', 'triton(bf16)', 'triton(int8)', 'triton(fp8e4)',
             'triton(fp8e5)'
         ],
         line_names=[
@@ -293,9 +294,10 @@ def benchmark(M, N, K, provider, model=None):
     out_dtype = in_dtype

     quantiles = [0.5, 0.2, 0.8]
-    if 'rocblas' in provider:
+    if 'hipblaslt' in provider:
         a = torch.randn((M, K), dtype=in_dtype, device='cuda')
-        b = torch.randn((K, N), dtype=in_dtype, device='cuda')
+        b = torch.randn((N, K), dtype=in_dtype, device='cuda')
+        b = b.T

         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     else:  # triton, different data types
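
The (N, K) allocation followed by b.T above is how the hipblaslt reference path is made TN: the transpose is a zero-copy view, so torch.matmul sees a (K, N) operand whose strides mark it as transposed, and the library GEMM is benchmarked on the same A (M, K) x B^T problem shape the Triton kernel solves. A small standalone sketch of that layout trick (sizes are arbitrary; it assumes a CUDA/ROCm device is available):

import torch

M, N, K = 512, 512, 256
a = torch.randn((M, K), dtype=torch.float16, device='cuda')
b = torch.randn((N, K), dtype=torch.float16, device='cuda')
bt = b.T                 # (K, N) view; no data movement, strides flag it as transposed
c = torch.matmul(a, bt)  # the underlying library GEMM sees B in transposed (TN) layout
assert c.shape == (M, N)
assert bt.stride() == (1, K)   # column-major strides of the transposed view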
