Skip to content

Commit cd33eea

Browse files
committed
skip auto tuning for test
Signed-off-by: Hao Wu <[email protected]>
1 parent 4528425 commit cd33eea

File tree

1 file changed

+26
-16
lines changed
  • emerging_optimizers/triton_kernels

1 file changed

+26
-16
lines changed

emerging_optimizers/triton_kernels/syrk.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
# type: ignore
16+
import sys
17+
1618
import torch
1719
import triton
1820
import triton.language as tl
21+
from absl import logging
1922

2023

2124
try:
@@ -76,23 +79,30 @@ def matmul_tma_set_block_size_hook(nargs: dict) -> None:
7679
nargs["d_t_desc"].block_shape = [TILE_N, TILE_M]
7780

7881

82+
_CONFIGS = [
83+
triton.Config(
84+
{"TILE_M": tm, "TILE_N": tn, "TILE_K": tk, "GROUP_SIZE_M": gm},
85+
num_warps=nw,
86+
num_stages=ns,
87+
num_ctas=nc,
88+
pre_hook=matmul_tma_set_block_size_hook,
89+
)
90+
for tm in (64, 128, 256)
91+
for tn in (64, 128, 256)
92+
for tk in (64, 128, 256)
93+
for gm in (2, 4, 8)
94+
for nw in (4, 8)
95+
for ns in (2, 3, 4)
96+
for nc in (1,)
97+
]
98+
99+
if "absl.testing" in sys.modules.keys():
100+
logging.warning("Running in absl.testing mode, disable autotune for triton.")
101+
_CONFIGS = _CONFIGS[:1]
102+
103+
79104
@triton.autotune(
80-
configs=[
81-
triton.Config(
82-
{"TILE_M": tm, "TILE_N": tn, "TILE_K": tk, "GROUP_SIZE_M": gm},
83-
num_warps=nw,
84-
num_stages=ns,
85-
num_ctas=nc,
86-
pre_hook=matmul_tma_set_block_size_hook,
87-
)
88-
for tm in (64, 128, 256)
89-
for tn in (64, 128, 256)
90-
for tk in (64, 128, 256)
91-
for gm in (2, 4, 8)
92-
for nw in (4, 8)
93-
for ns in (2, 3, 4)
94-
for nc in (1,)
95-
],
105+
configs=_CONFIGS,
96106
key=["N", "K", "TRANS", "WARP_SPECIALIZE"],
97107
prune_configs_by={"early_config_prune": prune_invalid_configs},
98108
)

0 commit comments

Comments
 (0)