
Commit 189481e

naromero77amd authored and jataylo committed
Reduction heuristics improvements for ROCm
(cherry picked from commit 9534cbd)
1 parent 0de435f commit 189481e
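
Summary (from the diff below): on ROCm, Inductor's Triton codegen now emits the built-ins tl.minimum / tl.maximum / tl.rsqrt in place of the triton_helpers / libdevice wrappers, and generates reduction loops with tl.range(..., num_stages=2). The register-spill threshold used to filter autotuning configs is raised from 16 to 32 on HIP, and a waves_per_eu occupancy hint is plumbed through the reduction-config helpers, with two ROCm-specific reduction configs appended to the autotuning candidates.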

File tree

3 files changed: +73 −35 lines


torch/_inductor/codegen/triton.py

Lines changed: 14 additions & 4 deletions
@@ -1101,11 +1101,17 @@ def relu(x):
 
     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.minimum({a}, {b})"
+        else:
+            return f"triton_helpers.minimum({a}, {b})"
 
     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.maximum({a}, {b})"
+        else:
+            return f"triton_helpers.maximum({a}, {b})"
 
     @staticmethod
     def where(a, b, c):
@@ -1291,7 +1297,10 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        if torch.version.hip:
+            return f"tl.rsqrt({x})"
+        else:
+            return f"libdevice.rsqrt({x})"
 
     @staticmethod
     @maybe_upcast_float32()
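
For context: tl.minimum, tl.maximum, and tl.rsqrt are Triton language built-ins, whereas triton_helpers is Inductor's helper module (its min/max wrappers add extra comparison logic, e.g. for NaN propagation) and libdevice routes through NVIDIA's libdevice math library; emitting the built-ins on ROCm avoids those indirections. A toy kernel (illustrative only, not from this commit) showing the call forms the HIP branch emits:

```python
import triton
import triton.language as tl

# Hypothetical kernel illustrating the ROCm codegen's call forms:
# clamp x to [lo, hi], then take the reciprocal square root.
@triton.jit
def clamp_rsqrt_kernel(x_ptr, out_ptr, lo, hi, n_elements, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.maximum(tl.minimum(x, hi), lo)  # built-ins lower to native min/max
    tl.store(out_ptr + offsets, tl.rsqrt(y), mask=mask)
```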
@@ -3788,8 +3797,9 @@ def codegen_body(self):
         loop_end = (
             "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
         )
+        num_stages = ", num_stages = 2" if torch.version.hip else ""
         self.body.writeline(
-            f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+            f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):"
         )
         with self.body.indent(offset=level + 1):
            self.iteration_ranges_codegen_header(tree, self.body)
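
tl.range is Triton's explicit loop iterator; its num_stages argument hints the compiler's software pipeliner to overlap the memory loads of one iteration with the compute of the previous one. A minimal sketch of the generated loop's before/after shape (identifiers follow Inductor's naming; the surrounding kernel body is elided):

```python
# Before: a plain Python range, no pipelining hint.
for r0_offset in range(0, r0_numel, R0_BLOCK):
    ...  # load a tile, accumulate into the reduction

# After, on ROCm: tl.range with a two-stage pipeline hint, so Triton
# may prefetch the next tile while the current one is being reduced.
for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages=2):
    ...
```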

torch/_inductor/config.py

Lines changed: 1 addition & 1 deletion
@@ -1391,7 +1391,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
     # Raise the threshold to 16 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32 if torch.version.hip else 16
 
     # Generate code containing the newer tl.make_block_ptr() API for loads/store
     use_block_ptr = False
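
spill_threshold is read by the autotuner's bench() (first hunk of triton_heuristics.py below): candidate configs whose compiled kernels spill more registers than the threshold are skipped during benchmarking. If you want to experiment with the knob yourself, a sketch of the standard override path (the value 64 is arbitrary):

```python
import torch._inductor.config as inductor_config

# Keep even more register-hungry configs in the autotuning search space
# (hypothetical experiment, not part of this commit).
inductor_config.triton.spill_threshold = 64
```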

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 58 additions & 30 deletions
@@ -838,7 +838,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
         # for some (complicated) custom Triton kernels, a register-spilling
         # config may yield the best latency.
         if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
-            "spill_threshold", 16
+            "spill_threshold", 32 if torch.version.hip else 16
         ):
             log.debug(
                 "Skip config %s because of register spilling: %d",
@@ -2312,6 +2312,7 @@ def triton_config_reduction(
     num_stages=1,
     num_warps=None,
     register_intensive=False,
+    waves_per_eu=None,
     dynamic_scale_rblock=True,
 ) -> Config:
     """
@@ -2356,13 +2357,19 @@ def total_numel() -> int:
     cfg = _get_config({"x": x, **rnumels})
     check_max_block(cfg)
     check_config(cfg, xnumel=size_hints["x"])
-    return InductorConfig(
+    config = InductorConfig(
         cfg,
         num_warps=num_warps,
         num_stages=num_stages,
         dynamic_scale_rblock=dynamic_scale_rblock,
     )
 
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
+
 
 def _get_config(numels: dict[str, int]) -> dict[str, int]:
     """
@@ -2373,7 +2380,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:
 
 
 def triton_config_tiled_reduction(
-    size_hints, x, y, r, num_stages=1, register_intensive=False
+    size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
 ):
     """
     Construct a tile reduction triton config with some adjustment
@@ -2410,7 +2417,11 @@ def total_numel() -> int:
     )
     check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
     check_max_block(cfg)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+    return config
 
 
 def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]):
@@ -2584,6 +2595,11 @@ def _reduction_configs(
     # Convert reductions to 1D, to simplify heuristics.
     rnumel = get_total_reduction_numel(size_hints)
 
+    # Is max autotune enabled
+    max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get(
+        "max_autotune_pointwise"
+    )
+
     register_intensive = False
     MAX_R0_BLOCK = 2048
     loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get(
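
Hoisting the max-autotune check into max_autotune_enabled lets the restructured branching later in the function (see the last large hunk) reuse it. The flags arrive via inductor_meta and correspond to the usual Inductor knobs, e.g. (a sketch of the standard setup; my_model is a placeholder):

```python
import torch
import torch._inductor.config as inductor_config

# Either of these enables the wider autotuning branch below.
inductor_config.max_autotune = True
# or, for a single compile call:
compiled = torch.compile(my_model, mode="max-autotune")
```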
@@ -2612,6 +2628,7 @@ def make_config(
         num_stages=1,
         register_intensive=False,
         dynamic_scale_rblock=True,
+        waves_per_eu=None,
     ):
         # For 3D case with tiling scores, create an adapted version
         if "y" in size_hints:
@@ -2624,6 +2641,7 @@
                 num_warps=num_warps,
                 num_stages=num_stages,
                 register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
             )
         else:
             # For other cases, use the original function
@@ -2634,6 +2652,7 @@
                 num_warps=num_warps,
                 num_stages=num_stages,
                 register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
                 dynamic_scale_rblock=dynamic_scale_rblock,
             )
 
@@ -2714,33 +2733,40 @@ def outer_config_opt():
         )
         configs.append(c)
 
+    result_configs = []
+
     # For 3d tiling, default to more autotuning initially
-    if "y" in size_hints:
-        pass
-    elif inductor_meta.get("max_autotune") or inductor_meta.get(
-        "max_autotune_pointwise"
-    ):
-        pass  # skip all these cases
-    elif reduction_hint == ReductionHint.INNER:
-        return configs + [contiguous_config]
-    elif reduction_hint == ReductionHint.OUTER:
-        return configs + [outer_config]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        return configs + [tiny_config]
-    if disable_pointwise_autotuning(inductor_meta):
-        return configs + [make_config(32, 128)]
-
-    return configs + [
-        contiguous_config,
-        outer_config,
-        tiny_config,
-        make_config(64, 64),
-        make_config(8, 512),
-        # halve the XBLOCK/Rn_BLOCK compared to outer_config
-        # TODO: this may only be beneficial when each iteration of the reduction
-        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
-        make_config(64, 4, num_warps=8),
-    ]
+    if not (max_autotune_enabled or "y" in size_hints):
+        if reduction_hint == ReductionHint.INNER:
+            result_configs = configs + [contiguous_config]
+        elif reduction_hint == ReductionHint.OUTER:
+            result_configs = configs + [outer_config]
+        elif reduction_hint == ReductionHint.OUTER_TINY:
+            result_configs = configs + [tiny_config]
+        else:
+            result_configs = configs + [make_config(32, 128)]
+    else:
+        result_configs = configs + [
+            contiguous_config,
+            outer_config,
+            tiny_config,
+            make_config(64, 64),
+            make_config(8, 512),
+            # halve the XBLOCK/Rn_BLOCK compared to outer_config
+            # TODO: this may only be beneficial when each iteration of the reduction
+            # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
+            make_config(64, 4, num_warps=8),
+        ]
+
+    if torch.version.hip:
+        result_configs.extend(
+            [
+                make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
+                make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
+            ]
+        )
+
+    return result_configs
 
 
 def match_target_block_product(
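
Net effect of this hunk: the early returns are collapsed into a single result_configs list so the two ROCm-only candidates can be appended on every path before returning (note the old disable_pointwise_autotuning early-out is gone; without a reduction hint and without max-autotune, the fallback is now make_config(32, 128)). Roughly what the two appended entries expand to (a sketch; exact block-size keys come from make_config and the size hints):

```python
from triton import Config

# Wide x-blocks with shallow reduction blocks, each carrying the AMD
# occupancy hint via kwargs (see the triton_config_reduction hunk above).
Config({"XBLOCK": 1024, "R0_BLOCK": 8}, num_warps=4, num_stages=1)  # waves_per_eu=2
Config({"XBLOCK": 512, "R0_BLOCK": 8}, num_warps=4, num_stages=1)   # waves_per_eu=1
```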
@@ -2798,6 +2824,7 @@ def adapt_config_for_tiling(
     num_stages=1,
     register_intensive=False,
     persistent_reduction=False,
+    waves_per_eu=None,
 ) -> Config:
     """
     Create an adapted configuration based on tiling scores,
@@ -2816,6 +2843,7 @@ def adapt_config_for_tiling(
         block_sizes["r0_"],
         num_stages=num_stages,
         register_intensive=register_intensive,
+        waves_per_eu=waves_per_eu,
     )