Skip to content

Commit d344dd3

Browse files
Add Stream K and Split K to regular CI (#2313)
CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/11129147870 --------- Co-authored-by: Whitney Tsang <[email protected]>
1 parent 2b57f07 commit d344dd3

File tree

4 files changed

+28
-11
lines changed

4 files changed

+28
-11
lines changed

.github/workflows/triton-benchmarks.yml

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ jobs:
115115
cd benchmarks/triton_kernels_benchmark
116116
python gemm_benchmark.py --reports $REPORTS
117117
mv $REPORTS/matmul-performance.csv $REPORTS/matmul-performance-base.csv
118-
source ../../scripts/capture-hw-details.sh
119118
119+
source ../../scripts/capture-hw-details.sh
120120
TAG=${{ inputs.tag || 'ci' }}
121121
python ../../scripts/build_report.py $REPORTS/matmul-performance-base.csv $REPORTS/gemm-triton-report.csv --benchmark gemm --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
122122
python ../../scripts/build_report.py $REPORTS/matmul-performance-base.csv $REPORTS/gemm-xetla-report.csv --benchmark gemm --compiler xetla --param_cols "B,M,K,N" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
@@ -133,8 +133,8 @@ jobs:
133133
python gemm_benchmark.py --reports $REPORTS
134134
mv $REPORTS/matmul-performance.csv $REPORTS/matmul-performance-default-path.csv
135135
136-
TAG=${{ inputs.tag || 'ci' }}-dflt
137136
source ../../scripts/capture-hw-details.sh
137+
TAG=${{ inputs.tag || 'ci' }}-dflt
138138
python ../../scripts/build_report.py $REPORTS/matmul-performance-default-path.csv $REPORTS/gemm-triton-default-report.csv --benchmark gemm --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
139139
140140
- name: Run Triton GEMM kernel benchmark - advanced path
@@ -149,10 +149,28 @@ jobs:
149149
python gemm_benchmark.py --reports $REPORTS
150150
mv $REPORTS/matmul-performance.csv $REPORTS/matmul-performance-adv-path.csv
151151
152-
TAG=${{ inputs.tag || 'ci' }}-adv
153152
source ../../scripts/capture-hw-details.sh
153+
TAG=${{ inputs.tag || 'ci' }}-adv
154154
python ../../scripts/build_report.py $REPORTS/matmul-performance-adv-path.csv $REPORTS/gemm-triton-advanced-report.csv --benchmark gemm --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
155155
156+
- name: Run Triton GEMM (stream-k) kernel benchmark
157+
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
158+
run: |
159+
cd benchmarks/triton_kernels_benchmark
160+
python gemm_streamk_benchmark.py --reports $REPORTS
161+
source ../../scripts/capture-hw-details.sh
162+
TAG=${{ inputs.tag || 'ci' }}
163+
python ../../scripts/build_report.py $REPORTS/matmul-streamk-performance.csv $REPORTS/gemm-streamk-triton-report.csv --benchmark gemm-streamk --compiler triton --param_cols "M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
164+
165+
- name: Run Triton GEMM (split-k) kernel benchmark
166+
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
167+
run: |
168+
cd benchmarks/triton_kernels_benchmark
169+
python gemm_splitk_benchmark.py --reports $REPORTS
170+
source ../../scripts/capture-hw-details.sh
171+
TAG=${{ inputs.tag || 'ci' }}
172+
python ../../scripts/build_report.py $REPORTS/matmul-splitk-performance.csv $REPORTS/gemm-splitk-triton-report.csv --benchmark gemm-splitk --compiler triton --param_cols "M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
173+
156174
- name: Run Triton GEMM + PreOp (exp) kernel benchmark
157175
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
158176
run: |

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def matmul(a, b, c):
211211
[1, 512, 32768, 8192], #
212212
[1, 1024, 16384, 8192], #
213213
[1, 1024, 28672, 8192], #
214-
[1, 3072, 4096, 3072], # FIXME: Remove this case when gemm_streamk_benchmark works
214+
[1, 3072, 4096, 3072], # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
215215
[1, 4096, 16384, 8192], #
216216
[1, 8192, 16384, 1024], #
217217
[1, 8192, 16384, 4096], #

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,8 @@ def forward(ctx, a, b, c, acc_dtype=None):
125125
x_names=['M', 'K', 'N'],
126126
x_vals=[
127127
[512, 32768, 8192],
128-
[3072, 4096, 3072],
129-
[4096, 4096, 4096],
130128
[1024, 28672, 8192],
129+
[3072, 4096, 3072],
131130
],
132131
line_arg='provider',
133132
# argument name whose value corresponds to a different line in the plot

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,22 +271,22 @@ def benchmark(M, N, K, provider):
271271
quantiles = [0.5, 0.0, 1.0]
272272

273273
if provider == 'onednn':
274-
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
275-
quantiles=quantiles)
274+
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
275+
quantiles=quantiles)
276276
elif provider == 'triton':
277277
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
278278
triton_fn = lambda: matmul(a, b, c)
279279
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
280280
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
281-
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
282-
kernel_name=['first_wave', 'full_tiles'])
281+
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
282+
kernel_name=['first_wave', 'full_tiles'])
283283
else:
284284
raise NotImplementedError(f'Unsupported provider {provider}')
285285

286286
tflops = lambda mean: 2 * M * N * K * (1e-12) / (mean * 1e-3)
287287
gbps = lambda mean: 2 * (M * K + K * N) + 4.0 * (M * N) * (1e-9) / (mean * 1e-3)
288288

289-
return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
289+
return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
290290

291291

292292
if __name__ == '__main__':

0 commit comments

Comments
 (0)