
Commit b12d0dd

Add A^t@B benchmark (#2430)
Based on this feedback: #2408 (review).

Changed the GEMM benchmark to include the transposed-matrix cases.

Closes #2424
Relates to #1795

The A@B^t case is important because the weight matrix is often stored in [M, K] format, for example in https://pytorch.org/docs/stable/generated/torch.nn.Linear.html. Right now we are about 1.5 times slower on XPU than raw torch for that case.

The A^t@B case is important because it is part of the matmul backward pass. Right now we are about 4 times slower on XPU than raw torch for that case.
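For context, a minimal sketch (hypothetical shapes, plain CPU tensors) of where the two transposed GEMMs show up for torch.nn.Linear; the layout and weight-gradient formula below are standard PyTorch behaviour, not part of this commit:

import torch

M, K, N = 8, 16, 4                          # batch, in_features, out_features
x = torch.rand(M, K)
linear = torch.nn.Linear(K, N, bias=False)

# Forward: the weight is stored as (N, K), so y = x @ W^t  -> the A@B^t case.
y = linear(x)
torch.testing.assert_close(y, x @ linear.weight.T)

# Backward w.r.t. the weight: dW = dy^t @ x  -> the A^t@B case.
dy = torch.ones_like(y)
y.backward(dy)
torch.testing.assert_close(linear.weight.grad, dy.T @ x)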
1 parent 2202ca7 commit b12d0dd

File tree

3 files changed: +67 -320 lines

.github/workflows/triton-benchmarks.yml

Lines changed: 12 additions & 1 deletion
@@ -163,13 +163,24 @@ jobs:
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         run: |
           cd benchmarks/triton_kernels_benchmark
-          python gemm_bt_benchmark.py --reports $REPORTS
+          TRANSPOSE_B=1 python gemm_benchmark.py --reports $REPORTS
           mv $REPORTS/matmul-performance.csv $REPORTS/matmul-performance-bt.csv
           source ../../scripts/capture-hw-details.sh
 
           TAG=${{ inputs.tag || 'ci' }}
           python ../../scripts/build_report.py $REPORTS/matmul-performance-bt.csv $REPORTS/gemm-bt-triton-report.csv --benchmark gemm-bt --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
 
+      - name: Run Triton GEMM (A^t@B) kernel benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
+        run: |
+          cd benchmarks/triton_kernels_benchmark
+          TRANSPOSE_A=1 python gemm_benchmark.py --reports $REPORTS
+          mv $REPORTS/matmul-performance.csv $REPORTS/matmul-performance-at.csv
+          source ../../scripts/capture-hw-details.sh
+
+          TAG=${{ inputs.tag || 'ci' }}
+          python ../../scripts/build_report.py $REPORTS/matmul-performance-at.csv $REPORTS/gemm-at-triton-report.csv --benchmark gemm-at --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+
       - name: Run Triton GEMM (stream-k) kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         run: |
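For reference, a hypothetical way to reproduce the new CI step locally (a Python stand-in for the shell step above; assumes the repository root as the working directory and an XPU-enabled PyTorch/Triton install; the reports path is illustrative):

import os
import subprocess

reports_dir = os.path.abspath("reports")      # stand-in for $REPORTS
os.makedirs(reports_dir, exist_ok=True)
subprocess.run(
    ["python", "gemm_benchmark.py", "--reports", reports_dir],
    cwd="benchmarks/triton_kernels_benchmark",
    env={**os.environ, "TRANSPOSE_A": "1"},   # use TRANSPOSE_B=1 for the A@B^t case
    check=True,
)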

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 55 additions & 27 deletions
@@ -6,6 +6,7 @@
 To compare the performance to XeTLA kernel.
 
 """
+import os
 
 import torch
 import triton
@@ -17,6 +18,10 @@
 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch  # type: ignore  # noqa: F401
 
+TRANSPOSE_A = os.getenv('TRANSPOSE_A', '0') == '1'
+TRANSPOSE_B = os.getenv('TRANSPOSE_B', '0') == '1'
+use_xetla = not (TRANSPOSE_A or TRANSPOSE_B)
+
 
 @triton.autotune(
     configs=[
@@ -158,15 +163,22 @@ def matmul_kernel_with_block_pointers_batched(
 
 # We can now create a convenience wrapper function that only takes two input tensors,
 # and (1) checks any shape constraint; (2) launches the above kernel.
-def matmul(a, b, c):
+def matmul(a, b, c, transpose_a=False, transpose_b=False):
+    a_major, a_minor = -2, -1
+    if transpose_a:
+        a_major, a_minor = a_minor, a_major
+    b_minor, b_major = -2, -1
+    if transpose_b:
+        b_major, b_minor = b_minor, b_major
+
+    assert a.shape[a_minor] == b.shape[b_minor], 'Incompatible dimensions'
+    assert a.is_contiguous(), 'Matrix A must be contiguous'
+    assert b.is_contiguous(), 'Matrix B must be contiguous'
+    M, N, K = a.shape[a_major], b.shape[b_major], a.shape[a_minor]
     # Check constraints.
     if len(a.shape) == 3 and len(b.shape) == 3:
         assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
-        assert a.shape[2] == b.shape[1], 'Incompatible dimensions'
-        assert a.is_contiguous(), 'Matrix A must be contiguous'
-        assert b.is_contiguous(), 'Matrix B must be contiguous'
-        B, M, K = a.shape
-        B, K, N = b.shape
+        B = a.shape[0]
         # 1D launch kernel where each block gets its own program.
         grid = lambda META: (
             B,
@@ -175,27 +187,37 @@ def matmul(a, b, c):
         matmul_kernel_with_block_pointers_batched[grid](
             a, b, c,  #
             B, M, N, K,  #
-            a.stride(0), a.stride(1), a.stride(2),  #
-            b.stride(0), b.stride(1), b.stride(2),  #
+            a.stride(0), a.stride(a_major), a.stride(a_minor),  #
+            b.stride(0), b.stride(b_minor), b.stride(b_major),  #
             c.stride(0), c.stride(1), c.stride(2))
     elif len(a.shape) == 2 and len(b.shape) == 2:
-        assert a.shape[1] == b.shape[0], 'Incompatible dimensions'
-        assert a.is_contiguous(), 'Matrix A must be contiguous'
-        assert b.is_contiguous(), 'Matrix B must be contiguous'
-        M, K = a.shape
-        K, N = b.shape
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
         matmul_kernel_with_block_pointers[grid](
             a, b, c,  #
             M, N, K,  #
-            a.stride(0), a.stride(1),  #
-            b.stride(0), b.stride(1),  #
+            a.stride(a_major), a.stride(a_minor),  #
+            b.stride(b_minor), b.stride(b_major),  #
             c.stride(0), c.stride(1))
     else:
         assert False, 'Input matrixs dimensions mismatch'
     return c
 
 
+def get_shapes(B, M, N, K, transpose_a, transpose_b):
+    a_shape = (M, K)
+    if transpose_a:
+        a_shape = (K, M)
+
+    b_shape = (K, N)
+    if transpose_b:
+        b_shape = (N, K)
+
+    if B != 1:
+        a_shape = (B, *a_shape)
+        b_shape = (B, *b_shape)
+    return a_shape, b_shape
+
+
 # Benchmark Performance
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
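The wrapper above never materializes a transposed copy: it only swaps which stride is passed as the major/minor stride when the block pointers are built. A minimal sketch of the same idea with plain torch views (CPU tensors, hypothetical sizes; an illustration only, not the Triton kernel itself):

import torch

K, M, N = 64, 32, 16
a_t = torch.rand(K, M)                      # A stored as A^t, contiguous (K, M)
b = torch.rand(K, N)

a_view = a_t.transpose(-2, -1)              # logical (M, K); only strides change
assert a_view.stride() == (a_t.stride(-1), a_t.stride(-2))
assert not a_view.is_contiguous()           # no data movement happened

expected = a_view.contiguous() @ b          # materialized transpose for comparison
torch.testing.assert_close(a_view @ b, expected)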
@@ -228,9 +250,9 @@ def matmul(a, b, c):
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
-        line_vals=['triton', 'xetla'],
+        line_vals=['triton'] + (['xetla'] if use_xetla else []),
         # label name for the lines
-        line_names=['Triton', 'XeTLA'],
+        line_names=['Triton'] + (['XeTLA'] if use_xetla else []),
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
@@ -239,27 +261,33 @@
         args={},
     ))
 def benchmark(B, M, N, K, provider):
-    if B == 1:
-        a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
-        b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)
-    else:
-        a = torch.rand((B, M, K), device='xpu', dtype=torch.bfloat16)
-        b = torch.rand((B, K, N), device='xpu', dtype=torch.bfloat16)
+    a_shape, b_shape = get_shapes(B, M, N, K, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
+
+    a = torch.rand(a_shape, device='xpu', dtype=torch.bfloat16)
+    b = torch.rand(b_shape, device='xpu', dtype=torch.bfloat16)
 
     quantiles = [0.5, 0.0, 1.0]
 
+    torch_a = a
+    if TRANSPOSE_A:
+        torch_a = torch.transpose(torch_a, -2, -1)
+
+    torch_b = b
+    if TRANSPOSE_B:
+        torch_b = torch.transpose(torch_b, -2, -1)
+
     if provider == 'onednn':
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
-                                                                 quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10,
+                                                                 rep=10, quantiles=quantiles)
     elif provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
             c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
         else:
             assert len(a.shape) == 2, 'Expecting shape of length 2'
             c = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        triton_fn = lambda: matmul(a, b, c)
-        torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
+        triton_fn = lambda: matmul(a, b, c, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
+        torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
