|
90 | 90 | # Final Result |
91 | 91 | # ------------ |
92 | 92 |
|
| 93 | +import os |
| 94 | + |
93 | 95 | import torch |
94 | 96 |
|
95 | 97 | import triton |
96 | 98 | import triton.language as tl |
97 | 99 |
|
| 100 | +SMALL_GRF = os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' |
| 101 | + |
98 | 102 |
|
99 | 103 | @triton.autotune( |
100 | 104 | configs=[ |
101 | | - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2, |
102 | | - num_warps=32), |
103 | | - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=3, |
104 | | - num_warps=32), |
105 | | - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2, |
106 | | - num_warps=32), |
107 | | - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2, |
108 | | - num_warps=32), |
| 105 | + triton.Config( |
| 106 | + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
| 107 | + num_stages=s, num_warps=32) for s in [1, 2, 3] |
| 108 | + ] + [ |
| 109 | + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m}, |
| 110 | + num_stages=s, num_warps=w) |
| 111 | + for s in [2, 3, 4] |
| 112 | + for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)]) |
| 113 | + ] + [ |
| 114 | + triton.Config( |
| 115 | + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
| 116 | + num_stages=s, num_warps=32) for s in [2] |
| 117 | + ] + [ |
| 118 | + triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m}, |
| 119 | + num_stages=s, num_warps=w) |
| 120 | + for s in [2, 3] |
| 121 | + for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)]) |
109 | 122 | ], |
110 | 123 | key=['M', 'N', 'K'], |
111 | 124 | ) |
|
0 commit comments