
Commit 46ca63c

Added support for lower precision b tensor in GEMMs (#738)
* Added support for lower precision b tensor in GEMMs. Modified pytest to check mixed precision configs
* Formatting and minor testcase fix
* More formatting
* Added mixed precision cases to perf benchmarking
* Formatting
* Add configs. Fix benchmarking case
* Address comments
* Better formatting of test cases
* Removed duplicate config
1 parent ade5449 commit 46ca63c
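
This commit lets the b operand of the GEMM use a narrower dtype than a (for example fp16 x fp8). Below is a minimal usage sketch, separate from the diff, that follows the call pattern used in test_correctness further down; it assumes gemm.py from this directory is importable and that a ROCm/CUDA device is available.

import torch
from gemm import matmul, gen_input, name_to_torch_types  # helpers defined in gemm.py

M, N, K = 512, 512, 512
dtype_a = name_to_torch_types['fp16']   # higher-precision a operand
dtype_b = name_to_torch_types['fp8e4']  # lower-precision b operand

# gen_input returns the typed tensor, an fp32 copy, and a per-tensor scale
# (the scale may be None for non-8-bit dtypes; the kernel falls back to 1.0).
a, a_fp32, a_scale = gen_input(M, K, dtype_a, False, 1, device='cuda')
b, b_fp32, b_scale = gen_input(K, N, dtype_b, True, 2, device='cuda')

c = torch.empty((M, N), device='cuda', dtype=dtype_a)
# b is 8-bit, so take the scaled path; b is upcast to a's element type inside the kernel.
matmul(a, b, c, a_scale, b_scale, scale_a8_b8=True, activation="")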

File tree

1 file changed (+47, -31 lines)


python/perf-kernels/gemm.py

Lines changed: 47 additions & 31 deletions
@@ -11,13 +11,18 @@
 
 @triton.autotune(
     configs=[
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2,
+                'kpack': 2, 'matrix_instr_nonkdim': 16
+            }, num_warps=8, num_stages=2),
         triton.Config(
             {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 4, 'waves_per_eu': 0},
             num_warps=8, num_stages=2),
         triton.Config(
             {
-                'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'waves_per_eu': 2,
-                'kpack': 2, 'matrix_instr_nonkdim': 16
+                'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2,
+                'kpack': 1, 'matrix_instr_nonkdim': 16
             }, num_warps=8, num_stages=2),
         triton.Config(
             {
@@ -128,7 +133,7 @@ def matmul_kernel(
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     if APPLY_SCALE:
-        a_scale = tl.load(a_scale_ptr)
+        a_scale = tl.load(a_scale_ptr) if (a_scale_ptr) else 1.0
         b_scale = tl.load(b_scale_ptr)
 
     acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32
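
The a_scale change above makes the scale pointer optional: when only b is 8-bit, no scale is passed for a and the kernel substitutes 1.0. A CPU-side sketch of the equivalent reference arithmetic used by the test (the concrete scale value here is illustrative, not from the source):

import torch

a_fp32 = torch.randn(64, 32)
b_fp32 = torch.randn(32, 16)
a_scale = None    # no scale for the higher-precision a operand
b_scale = 0.125   # illustrative per-tensor scale for an 8-bit b operand

# Mirrors `a_scale = tl.load(a_scale_ptr) if (a_scale_ptr) else 1.0` in the kernel
# and `(a_scale or 1.0)` in test_correctness below.
reference = torch.matmul(a_fp32, b_fp32) * (a_scale or 1.0) * b_scale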
@@ -143,6 +148,8 @@ def matmul_kernel(
         else:
             a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
             b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        # Type conversion to support mixed precision GEMMs where b is lower precision than a
+        b = b.to(a_ptr.type.element_ty)
         accumulator += tl.dot(a, b, input_precision="ieee")
 
         # Advance the ptrs to the next K block.
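
The two added lines upcast b to a's element type before tl.dot. A small torch-only sketch of the same ordering, using bf16 -> fp32 as a stand-in for the fp8 -> fp16/bf16 cases that run on the GPU:

import torch

a = torch.randn(8, 4, dtype=torch.float32)
b = torch.randn(4, 2, dtype=torch.bfloat16)  # lower-precision b operand

b_upcast = b.to(a.dtype)         # same idea as `b = b.to(a_ptr.type.element_ty)` in the kernel
out = torch.matmul(a, b_upcast)  # multiply at a's precision; the kernel accumulates in fp32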
@@ -176,7 +183,10 @@ def leaky_relu(x):
 def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions!!!"
-    assert a.dtype == b.dtype, "Mixed dtype GEMMs are not supported!!!"
+    assert (a.element_size()
+            >= b.element_size()), "Mixed dtype GEMMs are only supported when data type of a is bigger than b!!!"
+    assert (a.is_floating_point() == b.is_floating_point()
+            ), "GEMMs between float and integer type tensors are not supported!!!"
     M, K = a.shape
     K, N = b.shape
     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
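
The replaced assert now admits mixed dtypes as long as a is at least as wide as b and both operands are floats or both are integers. A sketch of those checks applied to a few pairs; torch.float8_e4m3fn is used as a stand-in for whatever name_to_torch_types maps 'fp8e4' to (on ROCm it may be a different fp8 variant), so the exact dtype name is an assumption:

import torch

def mixed_dtype_ok(dtype_a, dtype_b):
    # Same two conditions as the new asserts in matmul().
    a = torch.empty(0, dtype=dtype_a)
    b = torch.empty(0, dtype=dtype_b)
    return (a.element_size() >= b.element_size()
            and a.is_floating_point() == b.is_floating_point())

print(mixed_dtype_ok(torch.float16, torch.float8_e4m3fn))  # True: fp16 a, fp8 b
print(mixed_dtype_ok(torch.float8_e4m3fn, torch.float16))  # False: a narrower than b
print(mixed_dtype_ok(torch.float16, torch.int8))           # False: float/int mix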
@@ -262,32 +272,39 @@ def get_x_vals():
 
 # Unit tests
 #TODO(vgokhale): Test activation.
+# yapf: disable
 @pytest.mark.parametrize(
-    "M, N, K, in_dtype, out_dtype, col_a, col_b",
-    [(*shape, in_dtype, out_dtype, col_a, col_b)
+    "M, N, K, in_dtype_a, in_dtype_b, out_dtype, col_a, col_b",
+    [(*shape, in_dtype_a, in_dtype_b, out_dtype, col_a, col_b)
      for shape in get_x_vals()
-     for in_dtype, out_dtype in [('fp16', 'fp16'), ('bf16', 'bf16'), ('fp32', 'fp32'), (
-         'fp8e4', 'fp16'), ('fp8e5', 'fp16'), ('int8', 'int8'), ('int8', 'int32')]
+     for in_dtype_a, in_dtype_b, out_dtype in [
+         ('fp16', 'fp16', 'fp16'), ('bf16', 'bf16', 'bf16'), ('fp32', 'fp32', 'fp32'),
+         ('fp8e4', 'fp8e4', 'fp16'), ('fp8e5', 'fp8e5', 'fp16'), ('fp16', 'fp8e4', 'fp16'),
+         ('fp16', 'fp8e5', 'fp16'), ('bf16', 'fp8e4', 'bf16'), ('bf16', 'fp8e5', 'bf16'),
+         ('int8', 'int8', 'int8'), ('int8', 'int8', 'int32')]
      # Defines if a matrix is row or column major.
      for col_a in [True, False]
      for col_b in [True, False]])
-def test_correctness(M, N, K, col_a, col_b, in_dtype, out_dtype):
-    torch_in_dtype = name_to_torch_types[in_dtype]
-    a, a_fp32, a_scale = gen_input(M, K, torch_in_dtype, col_a, 1, device='cuda')
-    b, b_fp32, b_scale = gen_input(K, N, torch_in_dtype, col_b, 2, device='cuda')
+# yapf: enable
+def test_correctness(M, N, K, col_a, col_b, in_dtype_a, in_dtype_b, out_dtype):
+    torch_in_dtype_a = name_to_torch_types[in_dtype_a]
+    torch_in_dtype_b = name_to_torch_types[in_dtype_b]
+    a, a_fp32, a_scale = gen_input(M, K, torch_in_dtype_a, col_a, 1, device='cuda')
+    b, b_fp32, b_scale = gen_input(K, N, torch_in_dtype_b, col_b, 2, device='cuda')
     torch_out_dtype = name_to_torch_types[out_dtype]
     c = torch.empty((M, N), device=a.device, dtype=torch_out_dtype)
     # For 8-bit, we have scaled to the dynamic range of the data type.
     # This requires us to compute in fp32 because for e5m2, the range is same as fp16 (e5m10).
     # If we use fp16 it is possible to return infs from the torch.matmul call.
-    if dtype_is_8_bit(torch_in_dtype):
+    if dtype_is_8_bit(torch_in_dtype_a) or dtype_is_8_bit(torch_in_dtype_b):
         matmul(a, b, c, a_scale, b_scale, scale_a8_b8=True, activation="")
         torch_output = torch.matmul(a_fp32, b_fp32)
-        torch_output = torch_output * a_scale * b_scale
+        # Set a_scale to 1.0 if it is not set
+        torch_output = torch_output * (a_scale or 1.0) * b_scale
     # For other dtypes, use the same torch matmul as the dtype.
     else:
         matmul(a, b, c, a_scale=None, b_scale=None, scale_a8_b8=False, activation="")
-        torch_output = torch.matmul(a.to(torch_in_dtype), b.to(torch_in_dtype))
+        torch_output = torch.matmul(a.to(torch_in_dtype_a), b.to(torch_in_dtype_b))
     if out_dtype == 'int8':
         torch.testing.assert_close(c.to(torch.float32),
                                    torch_output.to(torch.int8).to(torch.float32), atol=1e-3, rtol=1e-2)
@@ -297,7 +314,7 @@ def test_correctness(M, N, K, col_a, col_b, in_dtype, out_dtype):
 
 def get_type(provider):
     res = re.findall(r'\(.*?\)', provider)
-    return res[0][1:-1]
+    return res[0][1:-1].split('/', 1)
 
 
 @triton.testing.perf_report(
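
get_type now returns both input dtype names, split on '/'. A quick standalone check of the parsing (the function body is copied from the diff; the example provider strings match the benchmark's line_vals below):

import re

def get_type(provider):
    res = re.findall(r'\(.*?\)', provider)
    return res[0][1:-1].split('/', 1)

print(get_type('triton(fp16/fp8e4)'))    # ['fp16', 'fp8e4']
print(get_type('hipblaslt(bf16/bf16)'))  # ['bf16', 'bf16']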
@@ -306,39 +323,38 @@ def get_type(provider):
         x_vals=get_x_vals(),
         line_arg='provider',
         line_vals=[
-            'hipblaslt(fp16)', 'hipblaslt(bf16)', 'triton(fp16)', 'triton(bf16)', 'triton(int8)', 'triton(fp8e4)',
-            'triton(fp8e5)'
+            'hipblaslt(fp16/fp16)', 'hipblaslt(bf16/bf16)', 'triton(fp16/fp16)', 'triton(bf16/bf16)',
+            'triton(int8/int8)', 'triton(fp8e4/fp8e4)', 'triton(fp8e5/fp8e5)', 'triton(fp16/fp8e4)',
+            'triton(fp16/fp8e5)'
         ],
         line_names=[
-            "rocBLAS.Fp16", "rocBLAS.Bf16", "Triton.Fp16", "Triton.Bf16", "Triton.Int8", "Triton.Fp8E4", "Triton.Fp8E5"
+            "rocBLAS.Fp16", "rocBLAS.Bf16", "Triton.Fp16", "Triton.Bf16", "Triton.Int8", "Triton.Fp8E4", "Triton.Fp8E5",
+            "Triton.Fp16.Fp8E4", "Triton.Fp16.Fp8E5"
         ],
         ylabel="TFLOPS",
         plot_name="matmul-performance",
         args={},
     ))
 def benchmark(M, N, K, provider, model=None):
-    in_dtype = name_to_torch_types[get_type(provider)]
-    out_dtype = in_dtype
+    in_dtype_a, in_dtype_b = [name_to_torch_types[x] for x in get_type(provider)]
+    out_dtype = in_dtype_a
 
     quantiles = [0.5, 0.2, 0.8]
     if 'hipblaslt' in provider:
-        a = torch.randn((M, K), dtype=in_dtype, device='cuda')
-        b = torch.randn((N, K), dtype=in_dtype, device='cuda')
+        a = torch.randn((M, K), dtype=in_dtype_a, device='cuda')
+        b = torch.randn((N, K), dtype=in_dtype_b, device='cuda')
         b = b.T
 
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     else:  # triton, different data types
         assert "triton" in provider
-        a, _, a_scale = gen_input(M, K, in_dtype, False, 1, device='cuda')
-        b, _, b_scale = gen_input(K, N, in_dtype, True, 2, device='cuda')
+        a, _, a_scale = gen_input(M, K, in_dtype_a, False, 1, device='cuda')
+        b, _, b_scale = gen_input(K, N, in_dtype_b, True, 2, device='cuda')
         # Allocates output.
         c = torch.empty((M, N), device=a.device, dtype=out_dtype)
-
-        if dtype_is_8_bit(in_dtype):
-            a_scale = a_scale.item()
-            b_scale = b_scale.item()
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c, a_scale, b_scale, activation=""),
-                                                     quantiles=quantiles)
+        scale_a8_b8 = dtype_is_8_bit(in_dtype_a) or dtype_is_8_bit(in_dtype_b)
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: matmul(a, b, c, a_scale, b_scale, scale_a8_b8=scale_a8_b8, activation=""), quantiles=quantiles)
     global verbose
     if verbose:
         print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.best_config()})')
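
For a single mixed-precision provider, the benchmark path above reduces to roughly the following sketch; it rests on the same assumption that gemm.py is importable, and the TFLOPS figure uses the usual 2*M*N*K flop count for a GEMM.

import torch
import triton
from gemm import matmul, gen_input, name_to_torch_types, dtype_is_8_bit, get_type

provider = 'triton(fp16/fp8e5)'
in_dtype_a, in_dtype_b = [name_to_torch_types[x] for x in get_type(provider)]

M = N = K = 4096
a, _, a_scale = gen_input(M, K, in_dtype_a, False, 1, device='cuda')
b, _, b_scale = gen_input(K, N, in_dtype_b, True, 2, device='cuda')
c = torch.empty((M, N), device='cuda', dtype=in_dtype_a)

# Use the scaled path whenever either input is 8-bit, as benchmark() does.
scale_a8_b8 = dtype_is_8_bit(in_dtype_a) or dtype_is_8_bit(in_dtype_b)
ms, _, _ = triton.testing.do_bench(
    lambda: matmul(a, b, c, a_scale, b_scale, scale_a8_b8=scale_a8_b8, activation=""),
    quantiles=[0.5, 0.2, 0.8])
print(f'fp16/fp8e5: {2 * M * N * K * 1e-12 / (ms * 1e-3):.1f} TFLOPS')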
