Skip to content

Commit f916e47

Browse files
[GEMM] Stop running XeTLA (#4248)
As continued development of `XeTLA` has been stopped, the team decided to use other implementations as reference, e.g., `oneDNN` and `CUTLASS`. Signed-off-by: Whitney Tsang <[email protected]>
1 parent dca7748 commit f916e47

File tree

4 files changed

+5
-29
lines changed

4 files changed

+5
-29
lines changed

.github/workflows/triton-benchmarks.yml

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -133,17 +133,6 @@ jobs:
133133
python build_report.py $REPORTS/softmax-performance.csv $REPORTS/softmax-xetla-report.csv --benchmark softmax --compiler xetla --param_cols "N" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
134134
135135
- name: Run Triton GEMM kernel benchmark
136-
if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_benchmark.py') }}
137-
run: |
138-
cd benchmarks/triton_kernels_benchmark
139-
NEW_SHAPES=0 python gemm_benchmark.py --reports $REPORTS --n_runs $N_RUNS
140-
source ../../scripts/capture-hw-details.sh
141-
python build_report.py $REPORTS/matmul-performance.csv $REPORTS/gemm-triton-report.csv --benchmark gemm-legacy --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
142-
python build_report.py $REPORTS/matmul-performance.csv $REPORTS/gemm-xetla-report.csv --benchmark gemm-legacy --compiler xetla --param_cols "B,M,K,N" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
143-
python build_report.py $REPORTS/matmul-performance.csv $REPORTS/gemm-onednn-report.csv --benchmark gemm-legacy --compiler onednn --param_cols "B,M,K,N" --tflops_col OneDNN-TFlops --hbm_col "OneDNN-GB/s" --tag $TAG
144-
python build_report.py $REPORTS/matmul-performance.csv $REPORTS/gemm-cutlass-report.csv --benchmark gemm-legacy --compiler cutlass --param_cols "B,M,K,N" --tflops_col CUTLASS-TFlops --hbm_col "CUTLASS-GB/s" --tag $TAG
145-
146-
- name: Run Triton GEMM kernel benchmark - new shapes
147136
if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_benchmark.py_newshapes')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_benchmark.py_newshapes') }}
148137
run: |
149138
cd benchmarks/triton_kernels_benchmark

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -232,15 +232,13 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
232232
return a_shape, b_shape
233233

234234

235-
NEW_X_VALS = [ #
235+
X_VALS = [ #
236+
[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]
237+
] + [ #
236238
[1, m, n, 4096] for m in [1, 8] for n in [1024, 4096, 6144, 14336, 28672, 128256]
237239
] + [ #
238240
[1, m, 4096, 14336] for m in [1, 8]
239241
] + [ #
240-
[1, 8192, 4096, 4096] #
241-
]
242-
243-
X_VALS = [[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [
244242
[1, 1, 13824, 5120],
245243
[1, 4, 12288, 4096],
246244
[1, 512, 8192, 8192],
@@ -261,6 +259,7 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
261259
[32, 4096, 128, 4096],
262260
[4096, 8, 128, 16384],
263261
[4096, 8, 16384, 128],
262+
[1, 8192, 4096, 4096],
264263
]
265264

266265
DEVICE_NAME = torch.xpu.get_device_name()
@@ -281,16 +280,13 @@ def is_enough_memory(x_val):
281280
return enough_memory
282281

283282

284-
if os.getenv('NEW_SHAPES', '1') == '1':
285-
X_VALS += NEW_X_VALS
286283
X_VALS = [x_val for x_val in X_VALS if is_enough_memory(x_val)]
287284

288285

289286
def get_benchmark(
290287
providers_filter: Optional[list[str]] = None,
291288
transpose_a=False,
292289
transpose_b=False,
293-
new_shapes=False,
294290
matmul_kernel=matmul_kernel_with_block_pointers,
295291
matmul_kernel_batched=matmul_kernel_with_block_pointers_batched,
296292
plot_name='matmul-performance',
@@ -303,10 +299,8 @@ def get_benchmark(
303299
'triton': 'Triton',
304300
'onednn': 'OneDNN',
305301
}
306-
# use_xetla and use_cutlass
302+
# use_cutlass
307303
if not (transpose_a or transpose_b):
308-
if not new_shapes:
309-
supported_providers['xetla'] = 'XeTLA'
310304
supported_providers['cutlass'] = 'CUTLASS'
311305
providers = benchmark_suite.filter_providers(supported_providers, providers_filter)
312306

@@ -457,6 +451,5 @@ def cutlass_invoker():
457451
_benchmark = get_benchmark(
458452
transpose_a=(os.getenv('TRANSPOSE_A', '0') == '1'),
459453
transpose_b=(os.getenv('TRANSPOSE_B', '0') == '1'),
460-
new_shapes=(os.getenv('NEW_SHAPES', '1') == '1'),
461454
)
462455
_benchmark.run(show_plots=False, print_data=True)

benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ def get_benchmark(
117117
providers_filter: Optional[List[str]] = None,
118118
transpose_a=False,
119119
transpose_b=False,
120-
new_shapes=True,
121120
):
122121
return gemm_benchmark.get_benchmark(
123122
providers_filter=providers_filter,
@@ -126,14 +125,12 @@ def get_benchmark(
126125
plot_name='matmul-tensor-desc-performance',
127126
transpose_a=transpose_a,
128127
transpose_b=transpose_b,
129-
new_shapes=new_shapes,
130128
)
131129

132130

133131
if __name__ == '__main__':
134132
_benchmark = get_benchmark(
135133
transpose_a=(os.getenv('TRANSPOSE_A', '0') == '1'),
136134
transpose_b=(os.getenv('TRANSPOSE_B', '0') == '1'),
137-
new_shapes=(os.getenv('NEW_SHAPES', '1') == '1'),
138135
)
139136
_benchmark.run(show_plots=False, print_data=True)

benchmarks/triton_kernels_benchmark/gemm_tensor_of_ptr_benchmark.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def get_benchmark(
124124
providers_filter: Optional[List[str]] = None,
125125
transpose_a=False,
126126
transpose_b=False,
127-
new_shapes=True,
128127
):
129128
return gemm_benchmark.get_benchmark(
130129
providers_filter=providers_filter,
@@ -133,14 +132,12 @@ def get_benchmark(
133132
plot_name='matmul-tensor-of-ptr-performance',
134133
transpose_a=transpose_a,
135134
transpose_b=transpose_b,
136-
new_shapes=new_shapes,
137135
)
138136

139137

140138
if __name__ == '__main__':
141139
_benchmark = get_benchmark(
142140
transpose_a=(os.getenv('TRANSPOSE_A', '0') == '1'),
143141
transpose_b=(os.getenv('TRANSPOSE_B', '0') == '1'),
144-
new_shapes=(os.getenv('NEW_SHAPES', '1') == '1'),
145142
)
146143
_benchmark.run(show_plots=False, print_data=True)

0 commit comments

Comments
 (0)