|
18 | 18 | TRANSPOSE_A = os.getenv('TRANSPOSE_A', '0') == '1' |
19 | 19 | TRANSPOSE_B = os.getenv('TRANSPOSE_B', '0') == '1' |
20 | 20 | use_xetla = not (TRANSPOSE_A or TRANSPOSE_B) |
| 21 | +SMALL_GRF = os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' and not TRANSPOSE_A |
21 | 22 |
|
22 | 23 |
|
23 | 24 | @triton.autotune( |
|
27 | 28 | num_stages=s, num_warps=32) for s in [1, 2, 3] |
28 | 29 | ] + [ |
29 | 30 | triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m}, |
30 | | - num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in |
31 | | - ([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)]) |
| 31 | + num_stages=s, num_warps=w) |
| 32 | + for s in [2, 3, 4] |
| 33 | + for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)]) |
32 | 34 | ] + [ |
33 | 35 | triton.Config( |
34 | 36 | {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
35 | 37 | num_stages=s, num_warps=32) for s in [2] |
36 | 38 | ] + [ |
37 | 39 | triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m}, |
38 | | - num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in |
39 | | - ([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)]) |
| 40 | + num_stages=s, num_warps=w) |
| 41 | + for s in [2, 3] |
| 42 | + for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)]) |
40 | 43 | ], |
41 | 44 | key=['M', 'N', 'K'], |
42 | 45 | ) |
@@ -92,8 +95,9 @@ def matmul_kernel_with_block_pointers( |
92 | 95 | num_stages=s, num_warps=32) for s in [2, 3] |
93 | 96 | ] + [ |
94 | 97 | triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m}, |
95 | | - num_stages=s, num_warps=w) for s in [2] for (m, w) in |
96 | | - ([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)]) |
| 98 | + num_stages=s, num_warps=w) |
| 99 | + for s in [2] |
| 100 | + for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)]) |
97 | 101 | ] + [ |
98 | 102 | triton.Config( |
99 | 103 | {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
|
0 commit comments