Skip to content

Commit 50df392

Browse files
[GEMM] Cleanup advanced path code (#3947)
Starting with #3724, GEMM is no longer run on the advanced path. This PR cleans up the code that was added for the advanced path. Signed-off-by: Whitney Tsang <[email protected]>
1 parent e9608d8 commit 50df392

File tree

2 files changed

+6
-21
lines changed

2 files changed

+6
-21
lines changed

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 3 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,6 @@
1616
import triton_kernels_benchmark as benchmark_suite
1717
from triton_kernels_benchmark import xetla_kernel
1818

19-
SMALL_GRF = os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0'
20-
2119

2220
@triton.autotune(
2321
configs=[
@@ -26,18 +24,14 @@
2624
num_stages=s, num_warps=32) for s in [1, 2, 3]
2725
] + [
2826
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
29-
num_stages=s, num_warps=w)
30-
for s in [2, 3, 4]
31-
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
27+
num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in ([('large', 32), ('small', 64)])
3228
] + [
3329
triton.Config(
3430
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
3531
num_stages=s, num_warps=32) for s in [2]
3632
] + [
3733
triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m},
38-
num_stages=s, num_warps=w)
39-
for s in [2, 3]
40-
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
34+
num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in ([('large', 32), ('small', 64)])
4135
],
4236
key=['M', 'N', 'K'],
4337
)
@@ -93,9 +87,7 @@ def matmul_kernel_with_block_pointers(
9387
num_stages=s, num_warps=32) for s in [2, 3]
9488
] + [
9589
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
96-
num_stages=s, num_warps=w)
97-
for s in [2]
98-
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
90+
num_stages=s, num_warps=w) for s in [2] for (m, w) in ([('large', 32), ('small', 64)])
9991
] + [
10092
triton.Config(
10193
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},

benchmarks/triton_kernels_benchmark/gemm_tensor_of_ptr_benchmark.py

Lines changed: 3 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818
TRANSPOSE_A = os.getenv('TRANSPOSE_A', '0') == '1'
1919
TRANSPOSE_B = os.getenv('TRANSPOSE_B', '0') == '1'
2020
use_xetla = not (TRANSPOSE_A or TRANSPOSE_B)
21-
SMALL_GRF = os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0'
2221

2322

2423
@triton.autotune(
@@ -28,18 +27,14 @@
2827
num_stages=s, num_warps=32) for s in [1, 2, 3]
2928
] + [
3029
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
31-
num_stages=s, num_warps=w)
32-
for s in [2, 3, 4]
33-
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
30+
num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in ([('large', 32), ('small', 64)])
3431
] + [
3532
triton.Config(
3633
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
3734
num_stages=s, num_warps=32) for s in [2]
3835
] + [
3936
triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m},
40-
num_stages=s, num_warps=w)
41-
for s in [2, 3]
42-
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
37+
num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in ([('large', 32), ('small', 64)])
4338
],
4439
key=['M', 'N', 'K'],
4540
)
@@ -97,9 +92,7 @@ def matmul_kernel(
9792
num_stages=s, num_warps=32) for s in [2, 3]
9893
] + [
9994
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
100-
num_stages=s, num_warps=w)
101-
for s in [2]
102-
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
95+
num_stages=s, num_warps=w) for s in [2] for (m, w) in ([('large', 32), ('small', 64)])
10396
] + [
10497
triton.Config(
10598
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},

0 commit comments

Comments
 (0)