Commit e27d722

Add int8 to gemm w/ addmatrix (#3040)
Update the gemm addmatrix benchmark to support int8 inputs as well as bfloat16. Exclude most int8 shapes from correctness testing because PyTorch matmul does not support int8 on GPU yet; the few that are checked use a CPU int32 reference.
1 parent e8e47af commit e27d722
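
At its core, the change threads the accumulator dtype through the kernels as a tl.constexpr: int8 inputs accumulate into int32 and produce an int32 output and add-matrix, while bfloat16 inputs keep the float32 accumulator. The sketch below (plain PyTorch on CPU, not part of the commit) shows why the wider accumulator is needed.

```python
import torch

# Minimal illustration (not part of this commit): products of two int8 values summed
# over a GEMM's K dimension overflow any 8- or 16-bit accumulator, which is why the
# kernels switch their accumulator to tl.int32 when the inputs are int8.
K = 4096
a = torch.full((1, K), 127, dtype=torch.int8)
b = torch.full((K, 1), 127, dtype=torch.int8)

acc = torch.matmul(a.to(torch.int32), b.to(torch.int32))  # int32 accumulation on CPU
print(acc.item())                    # 66064384
print(torch.iinfo(torch.int16).max)  # 32767 -- already far exceeded
```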

File tree

2 files changed (+70, -25 lines)


.github/workflows/triton-benchmarks.yml

Lines changed: 10 additions & 2 deletions
@@ -206,13 +206,21 @@ jobs:
           source ../../scripts/capture-hw-details.sh
           python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-gelu.csv $REPORTS/gemm-postop-gelu-triton-report.csv --benchmark gemm-postop-gelu --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
 
-      - name: Run Triton GEMM + PostOp (add matrix) kernel benchmark
+      - name: Run Triton GEMM + PostOp (add matrix) kernel benchmark bfloat16
         if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_postop_addmatrix_benchmark.py') }}
         run: |
           cd benchmarks/triton_kernels_benchmark
           python gemm_postop_addmatrix_benchmark.py --reports $REPORTS
           source ../../scripts/capture-hw-details.sh
-          python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix.csv $REPORTS/gemm-postop-addmatrix-triton-report.csv --benchmark gemm-postop-addmatrix --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix-bfloat16.csv $REPORTS/gemm-postop-addmatrix-bfloat16-triton-report.csv --benchmark gemm-postop-addmatrix --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+
+      - name: Run Triton GEMM + PostOp (add matrix) kernel benchmark int8
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_postop_addmatrix_benchmark.py') }}
+        run: |
+          cd benchmarks/triton_kernels_benchmark
+          INT8_ONLY=1 python gemm_postop_addmatrix_benchmark.py --reports $REPORTS
+          source ../../scripts/capture-hw-details.sh
+          python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix-int8.csv $REPORTS/gemm-postop-addmatrix-int8-triton-report.csv --benchmark gemm-postop-addmatrix-int8 --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
 
       - name: Run Triton FA kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_fwd_benchmark.py') }}
benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py

Lines changed: 60 additions & 23 deletions
@@ -5,13 +5,33 @@
 This benchmark is modified from gemm_benchmark.py to add a matrix to the output of the gemm operation.
 
 """
+import os
 
 import torch
 import triton
 import triton.language as tl
 
 import triton_kernels_benchmark as benchmark_suit
 
+INT8_ONLY_OPTION = os.getenv('INT8_ONLY', '0') == '1'
+ALL_DTYPES_OPTION = os.getenv('ALL_DTYPES', '0') == '1'
+
+
+def dtypes():
+    if ALL_DTYPES_OPTION:
+        return [torch.bfloat16, torch.int8]
+    if INT8_ONLY_OPTION:
+        return [torch.int8]
+    return [torch.bfloat16]
+
+
+def suffix():
+    if ALL_DTYPES_OPTION:
+        return 'all'
+    if INT8_ONLY_OPTION:
+        return 'int8'
+    return 'bfloat16'
+
 
 @triton.autotune(
     configs=[
@@ -43,7 +63,8 @@ def matmul_kernel_with_block_pointers(
         stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
         stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
         stride_cm: tl.constexpr, stride_cn: tl.constexpr,  #
-        stride_dm: tl.constexpr, stride_dn: tl.constexpr,
+        stride_dm: tl.constexpr, stride_dn: tl.constexpr,  #
+        ACCUMULATOR_DTYPE: tl.constexpr,
         # Meta-parameters
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
     pid = tl.program_id(axis=0)
@@ -63,7 +84,7 @@ def matmul_kernel_with_block_pointers(
                                     offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                     order=(1, 0))
 
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
     for _ in range(0, K, BLOCK_SIZE_K):
         a = tl.load(a_block_ptr, boundary_check=(0, 1))
         b = tl.load(b_block_ptr, boundary_check=(0, 1))
@@ -117,7 +138,8 @@ def matmul_kernel_with_block_pointers_batched(
         stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
         stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
         stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr,  #
-        stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr,
+        stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr,  #
+        ACCUMULATOR_DTYPE: tl.constexpr,
         # Meta-parameters
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
     bid = tl.program_id(axis=0)
@@ -141,7 +163,7 @@ def matmul_kernel_with_block_pointers_batched(
                                     offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                     order=(1, 0))
 
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
     for _ in range(0, K, BLOCK_SIZE_K):
         a = tl.load(a_block_ptr, boundary_check=(0, 1))
         b = tl.load(b_block_ptr, boundary_check=(0, 1))
@@ -185,7 +207,8 @@ def matmul(a, b, d, c):
            a.stride(0), a.stride(1), a.stride(2),  #
            b.stride(0), b.stride(1), b.stride(2),  #
            c.stride(0), c.stride(1), c.stride(2),  #
-           d.stride(0), d.stride(1), d.stride(2))
+           d.stride(0), d.stride(1), d.stride(2),  #
+           tl.float32 if a.dtype.is_floating_point else tl.int32)
     elif len(a.shape) == 2 and len(b.shape) == 2:
         assert a.shape[1] == b.shape[0], 'Incompatible dimensions'
         assert a.is_contiguous(), 'Matrix A must be contiguous'
@@ -199,7 +222,8 @@ def matmul(a, b, d, c):
            a.stride(0), a.stride(1),  #
            b.stride(0), b.stride(1),  #
            c.stride(0), c.stride(1),  #
-           d.stride(0), d.stride(1))
+           d.stride(0), d.stride(1),  #
+           tl.float32 if a.dtype.is_floating_point else tl.int32)
     else:
         assert False, 'Input matrixs dimensions mismatch'
     return c
@@ -209,10 +233,10 @@ def matmul(a, b, d, c):
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
-        x_names=['B', 'M', 'K', 'N'],
+        x_names=['B', 'M', 'K', 'N', 'dtype'],
         # different possible values for `x_name`
-        x_vals=[[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] +  #
-        [  #
+        x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype] for i in [1, 2, 4, 8] for dtype in dtypes()] +  #
+        [[*shape, dtype] for shape in [  #
            [1, 1, 5120, 13824],  #
            [1, 4, 4096, 12288],  #
            [1, 512, 8192, 8192],  #
@@ -232,8 +256,8 @@ def matmul(a, b, d, c):
            [4, 32768, 4096, 128],  #
            [32, 4096, 4096, 128],  #
            [4096, 8, 128, 16384],  #
-           [4096, 8, 16384, 128]
-        ],
+           [4096, 8, 16384, 128]  #
+        ] for dtype in dtypes()],
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
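
For reference, the reshaped x_vals pairs every shape with every dtype returned by dtypes(), so a single run can sweep bfloat16 and int8 together when ALL_DTYPES=1. A tiny sketch of that cross-product construction (stand-in strings and a shortened, hypothetical shape list):

```python
# Illustration only: each shape appears once per dtype, exactly as in the new
# x_vals comprehensions.
dts = ['bfloat16', 'int8']  # what dtypes() returns when ALL_DTYPES=1
squares = [[1, 1024 * i, 1024 * i, 1024 * i, dt] for i in [1, 2, 4, 8] for dt in dts]
explicit = [[*shape, dt] for shape in [[1, 1, 5120, 13824], [1, 4, 4096, 12288]] for dt in dts]
print(len(squares), len(explicit))  # 8 4
print(explicit[1])                  # [1, 1, 5120, 13824, 'int8']
```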
@@ -243,33 +267,46 @@ def matmul(a, b, d, c):
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
-        plot_name='matmul-performance-postop-addmatrix',
+        plot_name='matmul-performance-postop-addmatrix' + '-' + suffix(),
         # name for the plot. Used also as a file name for saving the plot.
         args={},
     ))
-def benchmark(B, M, N, K, provider):
+def benchmark(B, M, N, K, dtype, provider):
+    res_dtype = torch.float32 if dtype.is_floating_point else torch.int32
+    if dtype.is_floating_point:
+        rand = lambda shape, dtype: torch.rand(shape, device='xpu', dtype=dtype)
+    else:
+        rand = lambda shape, dtype: torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype)
     if B == 1:
-        a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
-        b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)
-        d = torch.rand((M, N), device='xpu', dtype=torch.float32)
+        a = rand((M, K), dtype)
+        b = rand((K, N), dtype)
+        d = rand((M, N), res_dtype)
     else:
-        a = torch.rand((B, M, K), device='xpu', dtype=torch.bfloat16)
-        b = torch.rand((B, K, N), device='xpu', dtype=torch.bfloat16)
-        d = torch.rand((B, M, N), device='xpu', dtype=torch.float32)
+        a = rand((B, M, K), dtype)
+        b = rand((B, K, N), dtype)
+        d = rand((B, M, N), res_dtype)
 
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
-            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+            c = torch.empty((B, M, N), device='xpu', dtype=res_dtype)
         else:
             assert len(a.shape) == 2, 'Expecting shape of length 2'
-            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+            c = torch.empty((M, N), device='xpu', dtype=res_dtype)
         triton_fn = lambda: matmul(a, b, d, c)
-        torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
+        # Torch does not support integer calculation in matmul
+        torch_device = 'xpu' if dtype.is_floating_point else 'cpu'
+        torch_dtype = dtype if dtype.is_floating_point else res_dtype
+        torch_fn = lambda: torch.matmul(a.to(device=torch_device, dtype=torch_dtype),
+                                        b.to(device=torch_device, dtype=torch_dtype)).to(device='xpu', dtype=res_dtype
+                                                                                         ) + d
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        if dtype.is_floating_point or [B, M, N, K] in [[1, 1024, 1024, 1024], [1, 2048, 2048, 2048],
+                                                       [1, 512, 8192, 32768], [4, 32768, 4096, 128]]:
+            # torch int8 matmul on GPU is not supported. only check a few int8 shapes to reduce runtime
+            benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     else:
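
Because torch.matmul cannot run int8 on the XPU, the reference path in benchmark() moves the operands to the CPU and widens them to int32 before comparing, and only a handful of int8 shapes are verified to keep runtime down. A standalone sketch of that reference computation, with tiny placeholder shapes in place of the benchmark's x_vals:

```python
import torch

# Standalone sketch (tiny placeholder shapes, CPU only) of the int8 reference path
# used by the correctness check: widen to int32, matmul on the CPU, then add the
# int32 post-op matrix. The benchmark does the same and moves the result back to
# 'xpu' before comparing it against the Triton output.
M, K, N = 64, 64, 64
a = torch.randint(-127, 128, (M, K), dtype=torch.int8)
b = torch.randint(-127, 128, (K, N), dtype=torch.int8)
d = torch.randint(-127, 128, (M, N), dtype=torch.int32)

ref = torch.matmul(a.to(torch.int32), b.to(torch.int32)) + d
print(ref.dtype, ref.shape)  # torch.int32 torch.Size([64, 64])
```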
