
Commit 0a9dcdd

Reduction heuristics improvements for ROCm
(cherry picked from commit 9534cbd) (cherry picked from commit 189481e)
1 parent c1f8e99 commit 0a9dcdd

3 files changed: 259 additions & 29 deletions

torch/_inductor/codegen/triton.py

Lines changed: 14 additions & 4 deletions
@@ -1002,11 +1002,17 @@ def relu(x):
 
     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.minimum({a}, {b})"
+        else:
+            return f"triton_helpers.minimum({a}, {b})"
 
     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.maximum({a}, {b})"
+        else:
+            return f"triton_helpers.maximum({a}, {b})"
 
     @staticmethod
     def where(a, b, c):
@@ -1202,7 +1208,10 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        if torch.version.hip:
+            return f"tl.rsqrt({x})"
+        else:
+            return f"libdevice.rsqrt({x})"
 
     @staticmethod
     @maybe_upcast_float32()
@@ -3227,8 +3236,9 @@ def codegen_body(self):
                 loop_end = (
                     "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
                 )
+                num_stages = ", num_stages = 2" if torch.version.hip else ""
                 self.body.writeline(
-                    f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+                    f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):"
                 )
                 with self.body.indent(offset=level + 1):
                     self.iteration_ranges_codegen_header(tree, self.body)
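
Note: for intuition, here is a minimal hand-written Triton kernel (a sketch, not code generated by Inductor or taken from this commit) that exercises the primitives the ROCm path now emits: tl.minimum / tl.rsqrt instead of the triton_helpers / libdevice wrappers, and a tl.range loop carrying the num_stages software-pipelining hint. It assumes a Triton version that supports tl.range(..., num_stages=...).

import torch
import triton
import triton.language as tl


@triton.jit
def min_rsqrt_kernel(in_ptr, out_ptr, numel, BLOCK: tl.constexpr):
    acc = tl.full([BLOCK], float("inf"), tl.float32)
    # tl.range accepts a num_stages hint for pipelining the reduction loop
    for offset in tl.range(0, numel, BLOCK, num_stages=2):
        idx = offset + tl.arange(0, BLOCK)
        mask = idx < numel
        x = tl.load(in_ptr + idx, mask=mask, other=float("inf"))
        acc = tl.minimum(acc, x)  # tl.minimum instead of triton_helpers.minimum
    # tl.rsqrt instead of libdevice.rsqrt
    tl.store(out_ptr, tl.rsqrt(tl.min(acc, axis=0)))


if torch.cuda.is_available():  # also true on ROCm builds
    x = torch.rand(4096, device="cuda") + 1.0
    out = torch.empty(1, device="cuda", dtype=torch.float32)
    min_rsqrt_kernel[(1,)](x, out, x.numel(), BLOCK=1024)
    torch.testing.assert_close(out[0], x.min().rsqrt())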

torch/_inductor/config.py

Lines changed: 1 addition & 1 deletion
@@ -1135,7 +1135,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
     # Raise the threshold to 16 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32 if torch.version.hip else 16
 
     # Generate code containing the newer tl.make_block_ptr() API for loads/store
     use_block_ptr = False
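
Note: the change only raises the default on ROCm; like other Inductor knobs, spill_threshold can still be overridden per run. A hedged sketch (the value 64 is an arbitrary illustration, not a recommendation from this commit):

import torch._inductor.config as inductor_config

# Hypothetical tuning override: tolerate more register spilling before a
# candidate config is rejected during autotuning.
inductor_config.triton.spill_threshold = 64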

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 244 additions & 24 deletions
@@ -615,7 +615,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
         # for some (complicated) custom Triton kernels, a register-spilling
         # config may yield the best latency.
         if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
-            "spill_threshold", 16
+            "spill_threshold", 32 if torch.version.hip else 16
         ):
             log.debug(
                 "Skip config %s because of register spilling: %d",
@@ -1779,6 +1779,8 @@ def triton_config_reduction(
     num_stages=1,
     num_warps=None,
     register_intensive=False,
+    waves_per_eu=None,
+    dynamic_scale_rblock=True,
 ) -> Config:
     """
     Construct a reduction triton config with some adjustment heuristics
@@ -1822,7 +1824,18 @@ def total_numel() -> int:
     cfg = _get_config({"x": x, **rnumels})
     check_max_block(cfg)
     check_config(cfg, xnumel=size_hints["x"])
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = InductorConfig(
+        cfg,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        dynamic_scale_rblock=dynamic_scale_rblock,
+    )
+
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_config(numels: dict[str, int]) -> dict[str, int]:
@@ -1833,7 +1846,9 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:
     return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()}
 
 
-def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1):
+def triton_config_tiled_reduction(
+    size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
+):
     """
     Construct a tile reduction triton config with some adjustment
     heuristics based on size_hints. Size_hints is a tuple of numels in
@@ -1866,7 +1881,11 @@ def total_numel() -> int:
     num_warps = _num_warps(total_numel() // 256, min_num_warps=1)
     check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1])
     check_max_block(cfg)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+    return config
 
 
 def pointwise(
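
Note: waves_per_eu is an AMD-specific occupancy hint (waves per execution unit) that Triton's ROCm backend reads from the kernel's meta-parameters, so carrying it in Config.kwargs forwards it to the kernel at launch time just like the block sizes. A small standalone sketch of that mechanism (the dict contents are made up for illustration):

from triton import Config

# A triton.Config stores extra meta-parameters in .kwargs; on ROCm the backend
# picks up waves_per_eu from there as an occupancy hint.
cfg = Config({"XBLOCK": 64, "R0_BLOCK": 8}, num_warps=4, num_stages=1)
cfg.kwargs["waves_per_eu"] = 2
print(cfg.kwargs)  # {'XBLOCK': 64, 'R0_BLOCK': 8, 'waves_per_eu': 2}
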
@@ -1992,6 +2011,11 @@ def _reduction_configs(
     # Convert reductions to 1D, to simplify heuristics.
     rnumel = get_total_reduction_numel(size_hints)
 
+    # Is max autotune enabled
+    max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get(
+        "max_autotune_pointwise"
+    )
+
     register_intensive = False
     MAX_R0_BLOCK = 2048
     if (
@@ -2014,8 +2038,90 @@ def _reduction_configs(
         MAX_R0_BLOCK = 1024
         register_intensive = True
 
-    contiguous_config = triton_config_reduction(
-        size_hints,
+    def make_config(
+        x,
+        r,
+        num_warps=None,
+        num_stages=1,
+        register_intensive=False,
+        dynamic_scale_rblock=True,
+        waves_per_eu=None,
+    ):
+        # For 3D case with tiling scores, create an adapted version
+        if "y" in size_hints:
+            assert "tiling_scores" in inductor_meta
+            return adapt_config_for_tiling(
+                size_hints,
+                inductor_meta["tiling_scores"],
+                x,
+                r,
+                num_warps=num_warps,
+                num_stages=num_stages,
+                register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
+            )
+        else:
+            # For other cases, use the original function
+            return triton_config_reduction(
+                size_hints,
+                x,
+                r,
+                num_warps=num_warps,
+                num_stages=num_stages,
+                register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
+                dynamic_scale_rblock=dynamic_scale_rblock,
+            )
+
+    def outer_config_opt():
+        # Default to 64 for vectorized loads
+        max_x_block, x_block = 256, 64
+        load_factor = inductor_meta.get("num_load", 0)
+        x = size_hints["x"]
+        num_warps = None
+
+        # Try to use all SMs with small x
+        if x <= 1024:
+            x_block = max(min(x // 128, 8), 2)
+            outer_r_block = min(rnumel, 64)
+        # Lower bound x = 1024, 1024 // 16 = 128 around # of SMs
+        elif x // 4096 <= 8:
+            x_block = 16
+            outer_r_block = 512 // x_block
+        elif num_dynamic > 1:
+            # Lots of compute with multiple dynamic shape per loop iteration
+            # Larger RBLOCK minimizes loop iteration
+            outer_r_block = max(min((rnumel // 64), 64), 8)
+        elif num_dynamic == 1:
+            # Dynamic shapes introduce a lot register pressure for indexing
+            outer_r_block = (
+                1
+                if load_factor >= 3
+                else min(next_power_of_2(max(rnumel, 128) // 128), 8)
+            )
+        else:
+            x_block = max(min(max_x_block, next_power_of_2(x // 4096)), x_block)
+            if load_factor < 4 or rnumel <= 128:
+                outer_r_block = 512 // x_block
+            else:
+                # Heavier reductions contain a lot more overhead per loop iteration
+                # We minimize the overhead by enlarging r block
+                if rnumel >= 2048:
+                    outer_r_block = 64
+                else:
+                    outer_r_block = 32
+                x_block = min(x_block, 32)
+                num_warps = 4
+
+        # Set register intensive to true by default as we try to maximize tiles with heuristic
+        return make_config(
+            x_block,
+            outer_r_block,
+            num_warps=num_warps,
+            register_intensive=register_intensive,
+        )
+
+    contiguous_config = make_config(
         1,
         rnumel if 256 <= rnumel < MAX_R0_BLOCK else MAX_R0_BLOCK,
         register_intensive=register_intensive,
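
Note: a hedged worked example of outer_config_opt (numbers chosen for illustration, not taken from the commit): with size_hints["x"] = 16384 and rnumel = 4096, the first branch is skipped (x > 1024) and the second fires because 16384 // 4096 = 4 <= 8, so x_block = 16 and outer_r_block = 512 // 16 = 32; the OUTER candidate therefore becomes make_config(16, 32), i.e. roughly XBLOCK=16 and R0_BLOCK=32 before any further adjustment inside triton_config_reduction.
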
@@ -2029,27 +2135,141 @@ def _reduction_configs(
         min(rnumel, MAX_R0_BLOCK),
         register_intensive=register_intensive,
     )
-    if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"):
-        pass  # skip all these cases
+
+    outer_config = make_config(64, 8, register_intensive=register_intensive)
+    # TODO (paulzhan): Test heuristic on AMD and internal testing
+    # for correctness
+    if not torch.version.hip and not is_fbcode():
+        outer_config = outer_config_opt()
+
+    configs = []
+
+    if inductor_meta.get("add_persistent_rblock") and loads_and_red <= 8:
+        xnumel = max(4096 // rnumel, 1)
+        c = make_config(
+            xnumel,
+            rnumel,
+            register_intensive=register_intensive,
+            dynamic_scale_rblock=False,
+        )
+        configs.append(c)
+
+    result_configs = []
+
+    # For 3d tiling, default to more autotuning initially
+    if "y" in size_hints:
+        pass
+    elif max_autotune_enabled:
+        pass  # skip all these cases
     elif reduction_hint == ReductionHint.INNER:
-        return [contiguous_config]
+        return configs + [contiguous_config]
     elif reduction_hint == ReductionHint.OUTER:
-        return [outer_config]
+        return configs + [outer_config]
     elif reduction_hint == ReductionHint.OUTER_TINY:
-        return [tiny_config]
-    if disable_pointwise_autotuning(inductor_meta):
-        return [triton_config_reduction(size_hints, 32, 128)]
-    return [
-        contiguous_config,
-        outer_config,
-        tiny_config,
-        triton_config_reduction(size_hints, 64, 64),
-        triton_config_reduction(size_hints, 8, 512),
-        # halve the XBLOCK/Rn_BLOCK compared to outer_config
-        # TODO: this may only be beneficial when each iteration of the reduction
-        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
-        triton_config_reduction(size_hints, 64, 4, num_warps=8),
-    ]
+        return configs + [tiny_config]
+
+    # We continue here under the following conditions:
+    # - max_autotune_enabled is True
+    # - max_autotune_enabled is False and reduction_hint is NOT one of the above cases
+    result_configs = configs + [
+        contiguous_config,
+        outer_config,
+        tiny_config,
+        make_config(64, 64),
+        make_config(8, 512),
+        # halve the XBLOCK/Rn_BLOCK compared to outer_config
+        # TODO: this may only be beneficial when each iteration of the reduction
+        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
+        make_config(64, 4, num_warps=8),
+    ]
+
+    if torch.version.hip:
+        result_configs.extend(
+            [
+                make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
+                make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
+            ]
+        )
+
+    return result_configs
+
+
+def match_target_block_product(
+    size_hints, tiling_scores, target_block_product, min_block_size=1
+):
+    """
+    Distribute block sizes across dimensions according to tiling scores,
+    aiming to match a target product of block sizes.
+    """
+    total_score = sum(tiling_scores.values())
+    if total_score == 0:
+        # just assume even score with no minimum block size
+        min_block_size = 1
+        tiling_scores = dict.fromkeys(tiling_scores.keys(), target_block_product)
+
+    # First, give each coalescing dimension at least min_block_size
+    block_sizes = {}
+    relative_scores = {}
+    curr_block_product = 1
+
+    for dim, score in tiling_scores.items():
+        if score == 0:
+            block_sizes[dim] = 1
+            continue
+
+        block_sizes[dim] = min_block_size
+        curr_block_product *= min_block_size
+        relative_scores[dim] = score / total_score
+
+    # Scale up dimensions by their relative scores until we reach the target
+    while curr_block_product < target_block_product and len(relative_scores):
+        dim, score = max(relative_scores.items(), key=lambda item: item[1])
+
+        # Check if we've hit the max for this dimension
+        if (
+            block_sizes[dim] >= TRITON_MAX_BLOCK[dim.capitalize()]
+            or block_sizes[dim] >= size_hints[dim]
+        ):
+            del relative_scores[dim]
+            continue
+
+        block_sizes[dim] *= 2
+        relative_scores[dim] /= 2
+        curr_block_product *= 2
+
+    return block_sizes
+
+
+def adapt_config_for_tiling(
+    size_hints,
+    tiling_scores,
+    original_x,
+    original_r,
+    num_warps=None,
+    num_stages=1,
+    register_intensive=False,
+    persistent_reduction=False,
+    waves_per_eu=None,
+) -> Config:
+    """
+    Create an adapted configuration based on tiling scores,
+    redistributing the same total block size (x * r) according to tiling scores.
+    """
+    assert all(s in tiling_scores for s in size_hints)
+    target_block_product = original_x * original_r
+    block_sizes = match_target_block_product(
+        size_hints, tiling_scores, target_block_product
+    )
+
+    return triton_config_tiled_reduction(
+        size_hints,
+        block_sizes["x"],
+        block_sizes["y"],
+        block_sizes["r0_"],
+        num_stages=num_stages,
+        register_intensive=register_intensive,
+        waves_per_eu=waves_per_eu,
+    )
 
 
 def reduction(
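
Note: adapt_config_for_tiling keeps the original x * r block budget and lets match_target_block_product redistribute it across the tiled dimensions by repeatedly doubling the dimension with the highest remaining relative score. A hedged usage sketch, assuming a PyTorch build that already contains this commit (import path as added to triton_heuristics.py; the inputs are made up for illustration):

# Hedged usage sketch of the new helper.
from torch._inductor.runtime.triton_heuristics import match_target_block_product

size_hints = {"x": 1024, "y": 1024}
tiling_scores = {"x": 3, "y": 1}  # x coalesces 3x more accesses than y
# With a budget of 16 (e.g. an original XBLOCK * R0_BLOCK product), the loop
# doubles x, x, y, then x again, halving the chosen score each time.
print(match_target_block_product(size_hints, tiling_scores, 16))
# -> {'x': 8, 'y': 2}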
