|
26 | 26 | {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
27 | 27 | num_stages=s, num_warps=32) for s in [1, 2, 3] |
28 | 28 | ] + [ |
29 | | - triton.Config( |
30 | | - {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
31 | | - num_stages=s, num_warps=32) for s in [2, 3, 4] |
| 29 | + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m}, |
| 30 | + num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in |
| 31 | + ([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)]) |
32 | 32 | ] + [ |
33 | 33 | triton.Config( |
34 | 34 | {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
35 | 35 | num_stages=s, num_warps=32) for s in [2] |
36 | 36 | ] + [ |
37 | | - triton.Config( |
38 | | - {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'}, |
39 | | - num_stages=s, num_warps=32) for s in [2, 3] |
| 37 | + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m}, |
| 38 | + num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in |
| 39 | + ([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)]) |
40 | 40 | ], |
41 | 41 | key=['M', 'N', 'K'], |
42 | 42 | ) |
@@ -91,9 +91,9 @@ def matmul_kernel_with_block_pointers( |
91 | 91 | {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
92 | 92 | num_stages=s, num_warps=32) for s in [2, 3] |
93 | 93 | ] + [ |
94 | | - triton.Config( |
95 | | - {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
96 | | - num_stages=s, num_warps=32) for s in [2] |
| 94 | + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m}, |
| 95 | + num_stages=s, num_warps=w) for s in [2] for (m, w) in |
| 96 | + ([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)]) |
97 | 97 | ] + [ |
98 | 98 | triton.Config( |
99 | 99 | {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
|
0 commit comments