
Commit d417cd0

[Perf branch] Add support for 2d reduction and bug fix (#2629)
(cherry picked from commit d81b7e9)
1 parent 6c068ee · commit d417cd0

File tree: 1 file changed (+44 −49 lines)

torch/_inductor/runtime/triton_heuristics.py

@@ -2282,7 +2282,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:
 
 
 def triton_config_tiled_reduction(
-    size_hints, x, y, r, num_stages=1, register_intensive=False
+    size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
 ):
     """
     Construct a tile reduction triton config with some adjustment
@@ -2319,7 +2319,11 @@ def total_numel() -> int:
     )
     check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
     check_max_block(cfg)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+    return config
 
 
 def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]):
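The HIP gate above is the whole mechanism: `waves_per_eu` (an AMD-specific occupancy hint) only ever lands in the config's kwargs on a ROCm build, because `torch.version.hip` is `None` everywhere else. A minimal, self-contained sketch of that pattern follows; `SimpleConfig` is a stand-in for `triton.Config`, used purely for illustration:

import torch


class SimpleConfig:
    """Illustrative stand-in for triton.Config; not the real class."""

    def __init__(self, kwargs, num_warps=4, num_stages=1):
        self.kwargs = dict(kwargs)  # block sizes, e.g. {"XBLOCK": 64}
        self.num_warps = num_warps
        self.num_stages = num_stages


def build_reduction_config(cfg, num_warps, num_stages, waves_per_eu=None):
    config = SimpleConfig(cfg, num_warps=num_warps, num_stages=num_stages)
    # torch.version.hip is None on CUDA and CPU builds, so the AMD-only
    # occupancy hint is attached exclusively on ROCm.
    if torch.version.hip and waves_per_eu is not None:
        config.kwargs["waves_per_eu"] = waves_per_eu
    return config

`triton_config_tiled_reduction` threads the new parameter the same way, which is why the later hunks only need to forward `waves_per_eu` through `make_config` and `adapt_config_for_tiling`.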
@@ -2469,6 +2473,9 @@ def _reduction_configs(
     # Convert reductions to 1D, to simplify heuristics.
     rnumel = get_total_reduction_numel(size_hints)
 
+    # Is max autotune enabled
+    max_autotune = inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
+
     register_intensive = False
     MAX_R0_BLOCK = 2048
     if (
@@ -2491,7 +2498,7 @@ def _reduction_configs(
         MAX_R0_BLOCK = 1024
         register_intensive = True
 
-    def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
+    def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, waves_per_eu=None):
        # For 3D case with tiling scores, create an adapted version
        if "y" in size_hints:
            assert "tiling_scores" in inductor_meta
@@ -2503,6 +2510,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
                num_warps=num_warps,
                num_stages=num_stages,
                register_intensive=register_intensive,
+               waves_per_eu=waves_per_eu
            )
        else:
            # For other cases, use the original function
@@ -2513,6 +2521,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
                num_warps=num_warps,
                num_stages=num_stages,
                register_intensive=register_intensive,
+               waves_per_eu=waves_per_eu
            )
 
     contiguous_config = make_config(
@@ -2526,54 +2535,38 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
         min(rnumel, MAX_R0_BLOCK),
         register_intensive=register_intensive,
     )
-    # For 3d tiling, default to more autotuning initially
-    if "y" in size_hints:
-        pass
-    elif inductor_meta.get("max_autotune") or inductor_meta.get(
-        "max_autotune_pointwise"
-    ):
 
-        result_configs = []
-
-        # Extra ROCm tuning
-        if torch.version.hip:
-            result_configs.append(triton_config_reduction(
-                size_hints,
-                1024,
-                8,
-                num_warps=4,
-                num_stages=1,
-                waves_per_eu=2
-            ))
-            result_configs.append(triton_config_reduction(
-                size_hints,
-                512,
-                8,
-                num_warps=4,
-                num_stages=1,
-                waves_per_eu=1
-            ))
-
-    elif reduction_hint == ReductionHint.INNER:
-        result_configs = [contiguous_config]
-    elif reduction_hint == ReductionHint.OUTER:
-        result_configs = [outer_config]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        result_configs = [tiny_config]
-    if disable_pointwise_autotuning(inductor_meta):
-        result_configs = [make_config(32, 128)]
-    result_configs = [
-        contiguous_config,
-        outer_config,
-        tiny_config,
-        make_config(64, 64),
-        make_config(8, 512),
-        # halve the XBLOCK/Rn_BLOCK compared to outer_config
-        # TODO: this may only be beneficial when each iteration of the reduction
-        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
-        make_config(64, 4, num_warps=8),
-    ]
+    result_configs = []
+
+    if not (max_autotune or "y" in size_hints):
+        if reduction_hint == ReductionHint.INNER:
+            result_configs = [contiguous_config]
+        elif reduction_hint == ReductionHint.OUTER:
+            result_configs = [outer_config]
+        elif reduction_hint == ReductionHint.OUTER_TINY:
+            result_configs = [tiny_config]
+        else:
+            result_configs = [make_config(32, 128)]
+    else:
+        result_configs = [
+            contiguous_config,
+            outer_config,
+            tiny_config,
+            make_config(64, 64),
+            make_config(8, 512),
+            # halve the XBLOCK/Rn_BLOCK compared to outer_config
+            # TODO: this may only be beneficial when each iteration of the reduction
+            # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
+            make_config(64, 4, num_warps=8),
+        ]
 
+    # Add ROCm-specific configs when autotuning
+    if torch.version.hip:
+        result_configs.extend([
+            make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
+            make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1)
+        ])
+
     return result_configs
 
 

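Reading both sides of this hunk together shows the bug being fixed: in the removed code the final `result_configs = [...]` assignment sat at function level, so it unconditionally overwrote whatever the hint-specific branches (and the old ROCm `waves_per_eu` candidates) had selected. The replacement makes the split explicit. A distilled sketch of the new flow; the function name, string placeholders, and `is_hip` flag are illustrative only, not the real API:

def pick_reduction_configs(max_autotune, size_hints, reduction_hint,
                           hinted, pool, rocm_extras, is_hip):
    # Illustrative only: strings stand in for real triton.Config objects.
    if not (max_autotune or "y" in size_hints):
        # No autotuning and a 1D reduction: emit one hint-matched config,
        # falling back to a small default (make_config(32, 128) above).
        configs = [hinted.get(reduction_hint, "default_32x128")]
    else:
        # Max-autotune or a 2D ("y") reduction: sweep the candidate pool.
        configs = list(pool)
    if is_hip:
        # The two waves_per_eu candidates are appended on both paths.
        configs.extend(rocm_extras)
    return configs

Note that, as written, the ROCm extension runs on both paths, not only under max-autotune.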
@@ -2632,6 +2625,7 @@ def adapt_config_for_tiling(
    num_stages=1,
    register_intensive=False,
    persistent_reduction=False,
+   waves_per_eu=None
 ) -> Config:
    """
    Create an adapted configuration based on tiling scores,
@@ -2650,6 +2644,7 @@ def adapt_config_for_tiling(
        block_sizes["r0_"],
        num_stages=num_stages,
        register_intensive=register_intensive,
+       waves_per_eu=waves_per_eu
    )
 
 

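End to end, the effect on a ROCm build is that some generated reduction configs carry an extra `waves_per_eu` entry in `Config.kwargs`, which the AMD Triton backend reads as an occupancy hint (target wavefronts per execution unit). A sketch of what the first new candidate conceptually contains, assuming the `XBLOCK`/`R0_BLOCK` key names these heuristics use and ignoring any clamping against `size_hints`:

# Conceptual contents of make_config(1024, 8, num_warps=4, num_stages=1,
# waves_per_eu=2) on ROCm; the real object is a triton.Config, and the
# block-size keys/values here are assumptions for illustration.
rocm_candidate = {
    "kwargs": {"XBLOCK": 1024, "R0_BLOCK": 8, "waves_per_eu": 2},
    "num_warps": 4,
    "num_stages": 1,
}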