Commit cac829d

Add microbenchmark for A@B^t (#2408)
This PR adds a microbenchmark for GEMM with A@B^t; it closes #2414.
1 parent c4ed65a commit cac829d
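For context, the operation being benchmarked is C = A @ B^t with bfloat16 inputs and a float32 result, where B is stored as (N, K) and transposed on the fly. A minimal PyTorch reference of what the new script times (a sketch only; it assumes an Intel XPU device is available, as the benchmark below does, and uses one hypothetical shape out of the many the script sweeps):

import torch

# One hypothetical (M, K, N) shape; the benchmark sweeps many (B, M, K, N) combinations.
M, K, N = 1024, 1024, 1024
a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
b = torch.rand((N, K), device='xpu', dtype=torch.bfloat16)  # B is laid out as (N, K)
c = torch.matmul(a, torch.transpose(b, -1, -2)).to(torch.float32)  # C = A @ B^t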

File tree

2 files changed: +303 −0 lines changed


.github/workflows/triton-benchmarks.yml

Lines changed: 11 additions & 0 deletions
@@ -159,6 +159,17 @@ jobs:
           TAG=${{ inputs.tag || 'ci' }}-adv
           python ../../scripts/build_report.py $REPORTS/matmul-performance-adv-path.csv $REPORTS/gemm-triton-advanced-report.csv --benchmark gemm --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

+      - name: Run Triton GEMM (A@B^t) kernel benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
+        run: |
+          cd benchmarks/triton_kernels_benchmark
+          python gemm_bt_benchmark.py --reports $REPORTS
+          mv $REPORTS/matmul-performance.csv $REPORTS/matmul-performance-bt.csv
+          source ../../scripts/capture-hw-details.sh
+
+          TAG=${{ inputs.tag || 'ci' }}
+          python ../../scripts/build_report.py $REPORTS/matmul-performance-bt.csv $REPORTS/gemm-bt-triton-report.csv --benchmark gemm-bt --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+
       - name: Run Triton GEMM (stream-k) kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         run: |
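Note that gemm_bt_benchmark.py keeps plot_name='matmul-performance' (inherited from gemm_benchmark.py; see the new file below), so this step renames the generated CSV to matmul-performance-bt.csv before calling build_report.py, presumably so it does not collide with the CSV produced by the base GEMM benchmark step.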
benchmarks/triton_kernels_benchmark/gemm_bt_benchmark.py

Lines changed: 292 additions & 0 deletions
@@ -0,0 +1,292 @@
"""
Gemm with A@B^t benchmark
====================================

This benchmark is modified from gemm_benchmark.py with added transpose.
"""

import torch
import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
    import intel_extension_for_pytorch  # type: ignore # noqa: F401

@triton.autotune(
    configs=[
        triton.Config(
            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [1, 2, 3]
    ] + [
        triton.Config(
            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [2, 3]
    ] + [
        triton.Config(
            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [2]
    ] + [
        triton.Config(
            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [2, 3]
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel_with_block_pointers(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        # Stride variables
        stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
        stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
        stride_cm: tl.constexpr, stride_cn: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
                                    order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                    order=(1, 0))

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_SIZE_K):
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        accumulator += tl.dot(a, b)
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
    c = accumulator.to(tl.float32)

    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
    tl.store(c_block_ptr, c, boundary_check=(0, 1))

# pylint: disable=unused-argument
@triton.autotune(
    configs=[
        triton.Config(
            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [2, 3]
    ] + [
        triton.Config(
            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [2]
    ] + [
        triton.Config(
            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [2]
    ] + [
        triton.Config(
            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [2]
    ] + [
        triton.Config(
            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
            num_stages=s, num_warps=4) for s in [2]
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        B: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        # Stride variables
        stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
        stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
        stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    bid = tl.program_id(axis=0)
    pid = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offset_a = bid.to(tl.int64) * stride_az
    offset_b = bid.to(tl.int64) * stride_bz

    a_block_ptr = tl.make_block_ptr(base=a_ptr + offset_a, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
                                    order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr + offset_b, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                    order=(1, 0))

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_SIZE_K):
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        accumulator += tl.dot(a, b)
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
    c = accumulator.to(tl.float32)

    offset_c = bid.to(tl.int64) * stride_cz
    c_block_ptr = tl.make_block_ptr(base=c_ptr + offset_c, shape=(M, N), strides=(stride_cm, stride_cn),
                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
    tl.store(c_block_ptr, c, boundary_check=(0, 1))

# We can now create a convenience wrapper function that takes the two input tensors and the output tensor,
# and (1) checks any shape constraint; (2) launches the appropriate kernel above.
def matmul(a, b, c):
    # Check constraints.
    if len(a.shape) == 3 and len(b.shape) == 3:
        assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
        assert a.shape[2] == b.shape[2], 'Incompatible dimensions'
        assert a.is_contiguous(), 'Matrix A must be contiguous'
        assert b.is_contiguous(), 'Matrix B must be contiguous'
        B, M, K = a.shape
        B, N, K = b.shape
        # 2D launch grid: one axis for the batch, one for the output tiles; each tile gets its own program.
        grid = lambda META: (
            B,
            triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
        )
        matmul_kernel_with_block_pointers_batched[grid](
            a, b, c,  #
            B, M, N, K,  #
            a.stride(0), a.stride(1), a.stride(2),  #
            # B is stored as (B, N, K); its K/N strides are passed swapped so the kernel reads B^t in place.
            b.stride(0), b.stride(2), b.stride(1),  #
            c.stride(0), c.stride(1), c.stride(2))
    elif len(a.shape) == 2 and len(b.shape) == 2:
        assert a.shape[1] == b.shape[1], 'Incompatible dimensions'
        assert a.is_contiguous(), 'Matrix A must be contiguous'
        assert b.is_contiguous(), 'Matrix B must be contiguous'
        M, K = a.shape
        N, K = b.shape
        grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
        matmul_kernel_with_block_pointers[grid](
            a, b, c,  #
            M, N, K,  #
            a.stride(0), a.stride(1),  #
            # B is stored as (N, K); its strides are passed swapped so the kernel reads B^t in place.
            b.stride(1), b.stride(0),  #
            c.stride(0), c.stride(1))
    else:
        assert False, 'Input matrices dimensions mismatch'
    return c

# Benchmark Performance
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['B', 'M', 'K', 'N'],
        # different possible values for `x_name`
        x_vals=[[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] +  #
        [  #
            [1, 1, 5120, 13824],  #
            [1, 4, 4096, 12288],  #
            [1, 512, 8192, 8192],  #
            [1, 512, 8192, 32768],  #
            [1, 512, 32768, 8192],  #
            [1, 1024, 16384, 8192],  #
            [1, 1024, 28672, 8192],  #
            [1, 3072, 4096, 3072],  # FIXME: Remove this case when gemm_streamk_benchmark works
            [1, 4096, 16384, 8192],  #
            [1, 8192, 16384, 1024],  #
            [1, 8192, 16384, 4096],  #
            [1, 16384, 1024, 8192],  #
            [1, 16384, 4096, 8192],  #
            [1, 16384, 8192, 1024],  #
            [1, 16384, 8192, 4096],  #
            [4, 32768, 128, 4096],  #
            [4, 32768, 4096, 128],  #
            [32, 4096, 4096, 128],  #
            [4096, 8, 128, 16384],  #
            [4096, 8, 16384, 128]
        ],
        line_arg='provider',
        # argument name whose value corresponds to a different line in the plot
        # possible values for `line_arg`
        line_vals=['triton'],
        # label name for the lines
        line_names=['Triton'],
        # line styles
        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
        ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
        plot_name='matmul-performance',
        # name for the plot. Used also as a file name for saving the plot.
        args={},
    ))
def benchmark(B, M, N, K, provider):
    if B == 1:
        a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
        b = torch.rand((N, K), device='xpu', dtype=torch.bfloat16)
    else:
        a = torch.rand((B, M, K), device='xpu', dtype=torch.bfloat16)
        b = torch.rand((B, N, K), device='xpu', dtype=torch.bfloat16)

    quantiles = [0.5, 0.0, 1.0]

    if provider == 'onednn':
        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, torch.transpose(b, -1, -2)),
                                                                 warmup=10, rep=10, quantiles=quantiles,
                                                                 fast_flush=False)
    elif provider == 'triton':
        assert len(a.shape) == len(b.shape), 'Incompatible sizes'
        if len(a.shape) == 3:
            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
        else:
            assert len(a.shape) == 2, 'Expecting shape of length 2'
            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
        triton_fn = lambda: matmul(a, b, c)
        torch_fn = lambda: torch.matmul(a, torch.transpose(b, -1, -2)).to(torch.float32)
        rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
                                                                 fast_flush=False)
    elif provider == 'xetla':
        if B == 1:
            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
            acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
            cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
        else:
            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
            acc = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
            cnt = torch.empty((B, M, N), device='xpu', dtype=torch.int32)
        name = f'gemm_shape_{B}_{M}_{K}_{N}'
        func = getattr(xetla_kernel, name)
        xetla_fn = lambda: func(a, torch.transpose(b, -1, -2), c, acc, cnt)
        torch_fn = lambda: torch.matmul(a, torch.transpose(b, -1, -2)).to(torch.float32)
        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
                                                                 fast_flush=False)
    else:
        raise NotImplementedError(f'Unsupported provider {provider}')

    tflops = lambda ms: 2 * B * M * N * K * (1e-12) / (ms * 1e-3)
    gbps = lambda ms: B * (2 * (M * K + K * N) + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3)

    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv

if __name__ == '__main__':
    benchmark.run(show_plots=False, print_data=True)
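A note on how the transpose is realized: the matmul wrapper never materializes B^t. B is allocated as (N, K) (or (B, N, K)), and its strides are passed to the kernel swapped (b.stride(1), b.stride(0) for stride_bk, stride_bn), so the block pointer walks B in transposed order while A and C stay row-major. A minimal sketch of checking that equivalence outside the benchmark harness (assumptions: an XPU device is available and the file is importable under the module name gemm_bt_benchmark inferred from the workflow step above; shapes are hypothetical):

import torch
from gemm_bt_benchmark import matmul  # assumed module name; matmul is the wrapper defined above

M, K, N = 512, 512, 512                # hypothetical shapes
a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
b = torch.rand((N, K), device='xpu', dtype=torch.bfloat16)  # row-major (N, K), never transposed in memory
c = torch.empty((M, N), device='xpu', dtype=torch.float32)

out = matmul(a, b, c)                  # Triton kernel reads B through swapped strides
ref = torch.matmul(a, torch.transpose(b, -1, -2)).to(torch.float32)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-2)

The reported metrics follow the same convention as gemm_benchmark.py: TFLOPS counts 2·B·M·N·K floating-point operations, and GB/s assumes 2 bytes per bfloat16 element read for A and B plus 4 bytes per float32 element written for C.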

0 commit comments
