Commit c7e5db0

Triton perf improvements, added poi tuning config
1 parent 9db1b12 commit c7e5db0

3 files changed: 53 additions & 9 deletions

torch/_inductor/codegen/triton.py

Lines changed: 3 additions & 3 deletions
@@ -1002,11 +1002,11 @@ def relu(x):

     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        return f"tl.minimum({a}, {b})"

     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        return f"tl.maximum({a}, {b})"

     @staticmethod
     def where(a, b, c):

@@ -1202,7 +1202,7 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        return f"tl.rsqrt({x})"

     @staticmethod
     @maybe_upcast_float32()
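These overrides control the source text Inductor emits into its generated Triton kernels; after this change the generated code calls Triton's built-in tl.minimum, tl.maximum, and tl.rsqrt directly instead of routing through the triton_helpers / libdevice wrappers. The hand-written kernel below is only a sketch of what such emitted calls look like; the kernel name, shapes, and constants are illustrative assumptions, not part of the commit.

import torch
import triton
import triton.language as tl


@triton.jit
def clamp_rsqrt_kernel(in_ptr, out_ptr, numel, XBLOCK: tl.constexpr):
    # One program handles XBLOCK contiguous elements, as in Inductor's pointwise kernels.
    offsets = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)
    mask = offsets < numel
    x = tl.load(in_ptr + offsets, mask=mask)
    # The codegen now emits tl.maximum / tl.minimum / tl.rsqrt for these ops.
    clamped = tl.minimum(tl.maximum(x, 0.25), 4.0)
    tl.store(out_ptr + offsets, tl.rsqrt(clamped), mask=mask)


inp = torch.rand(4096, device="cuda") + 0.1
out = torch.empty_like(inp)
grid = (triton.cdiv(inp.numel(), 256),)
clamp_rsqrt_kernel[grid](inp, out, inp.numel(), XBLOCK=256)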

torch/_inductor/runtime/hints.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 49 additions & 5 deletions
@@ -1587,6 +1587,9 @@ def triton_config(
     num_stages=1,
     num_elements_per_warp=256,
     min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None
 ) -> Config:
     """
     Construct a pointwise triton config with some adjustment heuristics
@@ -1643,9 +1646,11 @@ def triton_config(
     ):
         z *= 2

-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Calculate num_warps if it was not passed explicitly in the config
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
     # we are going to arrive at 2 warps only if bs was too small due to
     # numel being too small. However to workaround some ptx bugs we still
     # want at least 4 warps if there's enough elements per thread
@@ -1675,7 +1680,15 @@ def triton_config(
     cfg["ZBLOCK"] = z
     check_max_block(cfg)
     check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config


 def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
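On ROCm builds (torch.version.hip is truthy) the two new keywords are copied into the Config's kwargs dict, which is where Triton's AMD backend picks up waves_per_eu (an occupancy hint) and matrix_instr_nonkdim (an MFMA tile hint). Below is a rough sketch of the object triton_config would now return for a ROCm pointwise kernel; the size hint and values are illustrative, not taken from the commit.

from triton import Config

# Roughly what triton_config(size_hints={"x": 8192}, x=8192, num_warps=8,
# matrix_instr=0, waves_per_eu=2) would return on a ROCm build.
cfg = Config({"XBLOCK": 8192}, num_warps=8, num_stages=1)
cfg.kwargs["matrix_instr_nonkdim"] = 0  # only attached when torch.version.hip
cfg.kwargs["waves_per_eu"] = 2          # occupancy hint for the AMD backend
print(cfg)  # shows XBLOCK plus the extra ROCm kwargs alongside num_warps / num_stages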
@@ -1723,6 +1736,7 @@ def triton_config_reduction(
     num_stages=1,
     num_warps=None,
     register_intensive=False,
+    waves_per_eu=None
 ) -> Config:
     """
     Construct a reduction triton config with some adjustment heuristics
@@ -1766,7 +1780,13 @@ def total_numel() -> int:
     cfg = _get_config({"x": x, **rnumels})
     check_max_block(cfg)
     check_config(cfg, xnumel=size_hints["x"])
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config


 def _get_config(numels: dict[str, int]) -> dict[str, int]:
@@ -1854,6 +1874,12 @@ def pointwise(
             triton_config_with_settings(
                 size_hints, bs // 2, num_elements_per_warp=64
             ),
+            # triton_config_with_settings(
+            #     size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
+            # ),
+            triton_config_with_settings(
+                size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+            ),
             *hinted_configs,
         ]
     if len(size_hints) == 2:
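The extra pointwise candidate only matters when the autotuner actually benchmarks several configs for a generated kernel, for example under max-autotune. A minimal usage sketch follows; the function being compiled is an arbitrary illustration and assumes a GPU build with Triton available.

import torch

def f(a, b):
    return torch.rsqrt(torch.clamp(a * b, min=0.25, max=4.0))

compiled = torch.compile(f, mode="max-autotune")
x = torch.rand(1 << 20, device="cuda")
y = torch.rand(1 << 20, device="cuda")
# The pointwise autotuner benchmarks its candidate configs for the fused kernel,
# now including the TRITON_MAX_BLOCK["X"]-wide, waves_per_eu=2 entry for 1-D kernels.
out = compiled(x, y)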
@@ -1985,6 +2011,24 @@ def reduction(
     assert triton_meta is not None

     configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
+
+    # Additional tuning configs for ROCm builds
+    # Add checks for reduction autotuning bools
+    # if torch.version.hip and inductor_meta.get("max_autotune"):
+    #     configs = [
+    #         triton_config_with_settings(size_hints, bs, num_elements_per_warp=256),
+    #         triton_config_with_settings(
+    #             size_hints, bs // 2, num_elements_per_warp=64
+    #         ),
+    #         # triton_config_with_settings(
+    #         #     size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
+    #         # ),
+    #         triton_config_with_settings(
+    #             size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+    #         ),
+    #         *hinted_configs,
+    #     ]
+
     return cached_autotune(
         size_hints,
         configs=configs,
