Skip to content

Commit 7edd183

Browse files
committed
Update triton_config method.
(cherry picked from commit dfc1579) (cherry picked from commit 8f60456)
1 parent 3d716eb commit 7edd183

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1630,6 +1630,9 @@ def triton_config(
16301630
num_stages=1,
16311631
num_elements_per_warp=256,
16321632
min_elem_per_thread=0,
1633+
num_warps=None,
1634+
matrix_instr=None,
1635+
waves_per_eu=None,
16331636
) -> Config:
16341637
"""
16351638
Construct a pointwise triton config with some adjustment heuristics
@@ -1686,9 +1689,11 @@ def triton_config(
16861689
):
16871690
z *= 2
16881691

1689-
num_warps = _num_warps(
1690-
conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
1691-
)
1692+
# Calculate num_warps if they are not hard passed to config
1693+
if num_warps is None:
1694+
num_warps = _num_warps(
1695+
conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
1696+
)
16921697
# we are going to arrive at 2 warps only if bs was too small due to
16931698
# numel being too small. However to workaround some ptx bugs we still
16941699
# want at least 4 warps if there's enough elements per thread
@@ -1718,7 +1723,15 @@ def triton_config(
17181723
cfg["ZBLOCK"] = z
17191724
check_max_block(cfg)
17201725
check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
1721-
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
1726+
config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
1727+
1728+
if torch.version.hip:
1729+
if matrix_instr is not None:
1730+
config.kwargs["matrix_instr_nonkdim"] = matrix_instr
1731+
if waves_per_eu is not None:
1732+
config.kwargs["waves_per_eu"] = waves_per_eu
1733+
1734+
return config
17221735

17231736

17241737
def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:

0 commit comments

Comments
 (0)