Skip to content

Commit cf98f3a

Browse files
authored
[CI][microbenchmarks] Add onednn backend to GEMM with transposed matrices (#2445)
Without `BENCHMARKING_METHOD="ELAPSED_TIME"`, the onednn measurements report an implausible GeoMean of ~6000 TFLOPS. Closes #2456
1 parent 6138f11 commit cf98f3a

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

.github/workflows/triton-benchmarks.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ jobs:
167167
source ../../scripts/capture-hw-details.sh
168168
169169
python ../../scripts/build_report.py $REPORTS/matmul-performance-bt.csv $REPORTS/gemm-bt-triton-report.csv --benchmark gemm-bt --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
170+
# Write the onednn results to a dedicated report file; reusing the Triton report path would overwrite the Triton report generated from the same CSV.
python ../../scripts/build_report.py $REPORTS/matmul-performance-bt.csv $REPORTS/gemm-bt-onednn-report.csv --benchmark gemm-bt --compiler onednn --param_cols "B,M,K,N" --tflops_col onednn-TFlops --hbm_col "onednn-GB/s" --tag $TAG
170171
171172
- name: Run Triton GEMM (A^t@B) kernel benchmark
172173
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
@@ -177,6 +178,7 @@ jobs:
177178
source ../../scripts/capture-hw-details.sh
178179
179180
python ../../scripts/build_report.py $REPORTS/matmul-performance-at.csv $REPORTS/gemm-at-triton-report.csv --benchmark gemm-at --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
181+
# Write the onednn results to a dedicated report file; reusing the Triton report path would overwrite the Triton report generated from the same CSV.
python ../../scripts/build_report.py $REPORTS/matmul-performance-at.csv $REPORTS/gemm-at-onednn-report.csv --benchmark gemm-at --compiler onednn --param_cols "B,M,K,N" --tflops_col onednn-TFlops --hbm_col "onednn-GB/s" --tag $TAG
180182
181183
- name: Run Triton GEMM (stream-k) kernel benchmark
182184
if: ${{ steps.install.outcome == 'success' && !cancelled() }}

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import triton.language as tl
1414

1515
import triton_kernels_benchmark as benchmark_suit
16+
from triton_kernels_benchmark.benchmark_testing import do_bench_elapsed_time, BENCHMARKING_METHOD
17+
1618
import xetla_kernel
1719

1820
if benchmark_suit.USE_IPEX_OPTION:
@@ -250,9 +252,9 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
250252
line_arg='provider',
251253
# argument name whose value corresponds to a different line in the plot
252254
# possible values for `line_arg`
253-
line_vals=['triton'] + (['xetla'] if use_xetla else []),
255+
line_vals=['triton'] + (['xetla'] if use_xetla else ['onednn']),
254256
# label name for the lines
255-
line_names=['Triton'] + (['XeTLA'] if use_xetla else []),
257+
line_names=['Triton'] + (['XeTLA'] if use_xetla else ['onednn']),
256258
# line styles
257259
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
258260
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
@@ -277,8 +279,12 @@ def benchmark(B, M, N, K, provider):
277279
torch_b = torch.transpose(torch_b, -2, -1)
278280

279281
if provider == 'onednn':
280-
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10,
281-
rep=10, quantiles=quantiles)
282+
do_bench = benchmark_suit.do_bench
283+
if BENCHMARKING_METHOD == 'PYTORCH_LEGACY_PROFILER_USING_IPEX':
284+
# The legacy profiler reports an implausible ~6000 TFLOPS GeoMean for onednn, so fall back to the more reliable elapsed-time method
285+
do_bench = do_bench_elapsed_time
286+
_, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10, rep=10,
287+
quantiles=quantiles, kernel_name='gemm_kernel')
282288
elif provider == 'triton':
283289
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
284290
if len(a.shape) == 3:

0 commit comments

Comments
 (0)