|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | # type: ignore |
| 16 | +import sys |
| 17 | + |
16 | 18 | import torch |
17 | 19 | import triton |
18 | 20 | import triton.language as tl |
| 21 | +from absl import logging |
19 | 22 |
|
20 | 23 |
|
21 | 24 | try: |
@@ -76,23 +79,30 @@ def matmul_tma_set_block_size_hook(nargs: dict) -> None: |
76 | 79 | nargs["d_t_desc"].block_shape = [TILE_N, TILE_M] |
77 | 80 |
|
78 | 81 |
|
| 82 | +_CONFIGS = [ |
| 83 | + triton.Config( |
| 84 | + {"TILE_M": tm, "TILE_N": tn, "TILE_K": tk, "GROUP_SIZE_M": gm}, |
| 85 | + num_warps=nw, |
| 86 | + num_stages=ns, |
| 87 | + num_ctas=nc, |
| 88 | + pre_hook=matmul_tma_set_block_size_hook, |
| 89 | + ) |
| 90 | + for tm in (64, 128, 256) |
| 91 | + for tn in (64, 128, 256) |
| 92 | + for tk in (64, 128, 256) |
| 93 | + for gm in (2, 4, 8) |
| 94 | + for nw in (4, 8) |
| 95 | + for ns in (2, 3, 4) |
| 96 | + for nc in (1,) |
| 97 | +] |
| 98 | + |
| 99 | +if "absl.testing" in sys.modules.keys(): |
| 100 | + logging.warning("Running in absl.testing mode, disable autotune for triton.") |
| 101 | + _CONFIGS = _CONFIGS[:1] |
| 102 | + |
| 103 | + |
79 | 104 | @triton.autotune( |
80 | | - configs=[ |
81 | | - triton.Config( |
82 | | - {"TILE_M": tm, "TILE_N": tn, "TILE_K": tk, "GROUP_SIZE_M": gm}, |
83 | | - num_warps=nw, |
84 | | - num_stages=ns, |
85 | | - num_ctas=nc, |
86 | | - pre_hook=matmul_tma_set_block_size_hook, |
87 | | - ) |
88 | | - for tm in (64, 128, 256) |
89 | | - for tn in (64, 128, 256) |
90 | | - for tk in (64, 128, 256) |
91 | | - for gm in (2, 4, 8) |
92 | | - for nw in (4, 8) |
93 | | - for ns in (2, 3, 4) |
94 | | - for nc in (1,) |
95 | | - ], |
| 105 | + configs=_CONFIGS, |
96 | 106 | key=["N", "K", "TRANS", "WARP_SPECIALIZE"], |
97 | 107 | prune_configs_by={"early_config_prune": prune_invalid_configs}, |
98 | 108 | ) |
|
0 commit comments