Commit 5f9e611

Triton perf improvements, added poi tuning config
1 parent bd74018 · commit 5f9e611

3 files changed, +53 −9 lines


torch/_inductor/codegen/triton.py

Lines changed: 3 additions & 3 deletions
@@ -1027,11 +1027,11 @@ def relu(x):
 
     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        return f"tl.minimum({a}, {b})"
 
     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        return f"tl.maximum({a}, {b})"
 
     @staticmethod
     def where(a, b, c):

@@ -1217,7 +1217,7 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        return f"tl.rsqrt({x})"
 
     @staticmethod
     @maybe_upcast_float32()
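
For reference, these overrides are plain string emitters: each one returns the Triton expression that ends up in the generated kernel source, so the change simply routes minimum/maximum/rsqrt to Triton's built-in tl.* intrinsics instead of the triton_helpers / libdevice wrappers. A minimal standalone sketch of the emitters after this commit (plain functions here, not the real override class):

# Standalone sketch of the three emitters after this change; the real methods
# are @staticmethods on the inductor Triton overrides class.
def minimum(a: str, b: str) -> str:
    return f"tl.minimum({a}, {b})"

def maximum(a: str, b: str) -> str:
    return f"tl.maximum({a}, {b})"

def rsqrt(x: str) -> str:
    return f"tl.rsqrt({x})"

# The generated kernel source now contains e.g. "tl.minimum(tmp0, tmp1)".
print(minimum("tmp0", "tmp1"))
print(rsqrt("tmp2"))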

torch/_inductor/runtime/hints.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 49 additions & 5 deletions
@@ -1987,6 +1987,9 @@ def triton_config(
     num_stages=1,
     num_elements_per_warp=256,
     min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None
 ) -> Config:
     """
     Construct a pointwise triton config with some adjustment heuristics

@@ -2043,9 +2046,11 @@ def triton_config(
     ):
         z *= 2
 
-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Calculate num_warps if they are not hard passed to config
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
     # we are going to arrive at 2 warps only if bs was too small due to
     # numel being too small. However to workaround some ptx bugs we still
     # want at least 4 warps if there's enough elements per thread

@@ -2075,7 +2080,15 @@ def triton_config(
         cfg["ZBLOCK"] = z
     check_max_block(cfg)
     check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:

@@ -2123,6 +2136,7 @@ def triton_config_reduction(
     num_stages=1,
     num_warps=None,
     register_intensive=False,
+    waves_per_eu=None
 ) -> Config:
     """
     Construct a reduction triton config with some adjustment heuristics

@@ -2166,7 +2180,13 @@ def total_numel() -> int:
     cfg = _get_config({"x": x, **rnumels})
     check_max_block(cfg)
     check_config(cfg, xnumel=size_hints["x"])
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_config(numels: dict[str, int]) -> dict[str, int]:

@@ -2259,6 +2279,12 @@ def pointwise(
            triton_config_with_settings(
                size_hints, bs // 2, num_elements_per_warp=64
            ),
+            # triton_config_with_settings(
+            #     size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
+            # ),
+            triton_config_with_settings(
+                size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+            ),
            *hinted_configs,
        ]
    if len(size_hints) == 2:

@@ -2491,6 +2517,24 @@ def reduction(
     assert triton_meta is not None
 
     configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
+
+    # Additional tuning configs for ROCm builds
+    # Add checks for reduction autotuning bools
+    # if torch.version.hip and inductor_meta.get("max_autotune"):
+    #     configs = [
+    #         triton_config_with_settings(size_hints, bs, num_elements_per_warp=256),
+    #         triton_config_with_settings(
+    #             size_hints, bs // 2, num_elements_per_warp=64
+    #         ),
+    #         # triton_config_with_settings(
+    #         #     size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
+    #         # ),
+    #         triton_config_with_settings(
+    #             size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+    #         ),
+    #         *hinted_configs,
+    #     ]
+
     return cached_autotune(
         size_hints,
         configs=configs,
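
Taken together, the heuristics changes let a caller pin num_warps and attach the ROCm-only compiler hints waves_per_eu and matrix_instr_nonkdim to a generated config. A self-contained sketch that mirrors, but does not reproduce, the edited triton_config() logic (the warp heuristic below is a stand-in assumption, and triton.Config is assumed to be importable):

# Sketch mirroring the new triton_config() behaviour, not the real helper.
import torch
from triton import Config  # inductor builds Triton's Config objects

def make_pointwise_config(xblock: int, num_warps=None, matrix_instr=None, waves_per_eu=None) -> Config:
    if num_warps is None:
        # Stand-in for the real _num_warps(conditional_product(...) // num_elements_per_warp)
        num_warps = max(1, min(8, xblock // 256))
    config = Config({"XBLOCK": xblock}, num_warps=num_warps, num_stages=1)
    # ROCm-only compiler hints are attached as extra kernel kwargs
    if torch.version.hip:
        if matrix_instr is not None:
            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
        if waves_per_eu is not None:
            config.kwargs["waves_per_eu"] = waves_per_eu
    return config

# e.g. the new tuning candidate: XBLOCK = TRITON_MAX_BLOCK["X"] with waves_per_eu=2
cfg = make_pointwise_config(8192, waves_per_eu=2)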
