93 changes: 44 additions & 49 deletions torch/_inductor/runtime/triton_heuristics.py
@@ -2282,7 +2282,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:


def triton_config_tiled_reduction(
size_hints, x, y, r, num_stages=1, register_intensive=False
size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
):
"""
Construct a tile reduction triton config with some adjustment
@@ -2319,7 +2319,11 @@ def total_numel() -> int:
)
check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
check_max_block(cfg)
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
if torch.version.hip:
if waves_per_eu is not None:
config.kwargs["waves_per_eu"] = waves_per_eu
return config


def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]):
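Note on the first hunk: `waves_per_eu` only makes sense on AMD hardware, so the new keyword is attached to the Triton `Config` only under a ROCm build (`torch.version.hip` is truthy there and `None` on CUDA builds). A minimal sketch of that guarded-kwarg pattern, assuming `triton.Config` exposes a mutable `kwargs` dict as in current Triton releases; the helper name is illustrative, not part of the PR:

```python
import torch
from triton import Config


def rocm_aware_config(block_sizes, num_warps=4, num_stages=1, waves_per_eu=None):
    # block_sizes is a plain dict such as {"XBLOCK": 64, "R0_BLOCK": 64}.
    cfg = Config(block_sizes, num_warps=num_warps, num_stages=num_stages)
    if torch.version.hip and waves_per_eu is not None:
        # Occupancy hint consumed only by the ROCm Triton backend; on CUDA
        # builds torch.version.hip is None, so the kwarg is never attached.
        cfg.kwargs["waves_per_eu"] = waves_per_eu
    return cfg
```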
@@ -2469,6 +2473,9 @@ def _reduction_configs(
# Convert reductions to 1D, to simplify heuristics.
rnumel = get_total_reduction_numel(size_hints)

# Is max autotune enabled
max_autotune = inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")

register_intensive = False
MAX_R0_BLOCK = 2048
if (
@@ -2491,7 +2498,7 @@
MAX_R0_BLOCK = 1024
register_intensive = True

def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, waves_per_eu=None):
# For 3D case with tiling scores, create an adapted version
if "y" in size_hints:
assert "tiling_scores" in inductor_meta
@@ -2503,6 +2510,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
num_warps=num_warps,
num_stages=num_stages,
register_intensive=register_intensive,
waves_per_eu=waves_per_eu
)
else:
# For other cases, use the original function
@@ -2513,6 +2521,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
num_warps=num_warps,
num_stages=num_stages,
register_intensive=register_intensive,
waves_per_eu=waves_per_eu
)

contiguous_config = make_config(
@@ -2526,54 +2535,38 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
min(rnumel, MAX_R0_BLOCK),
register_intensive=register_intensive,
)
# For 3d tiling, default to more autotuning initially
if "y" in size_hints:
pass
elif inductor_meta.get("max_autotune") or inductor_meta.get(
"max_autotune_pointwise"
):

result_configs = []

# Extra ROCm tuning
if torch.version.hip:
result_configs.append(triton_config_reduction(
size_hints,
1024,
8,
num_warps=4,
num_stages=1,
waves_per_eu=2
))
result_configs.append(triton_config_reduction(
size_hints,
512,
8,
num_warps=4,
num_stages=1,
waves_per_eu=1
))

elif reduction_hint == ReductionHint.INNER:
result_configs = [contiguous_config]
elif reduction_hint == ReductionHint.OUTER:
result_configs = [outer_config]
elif reduction_hint == ReductionHint.OUTER_TINY:
result_configs = [tiny_config]
if disable_pointwise_autotuning(inductor_meta):
result_configs = [make_config(32, 128)]
result_configs = [
contiguous_config,
outer_config,
tiny_config,
make_config(64, 64),
make_config(8, 512),
# halve the XBLOCK/Rn_BLOCK compared to outer_config
# TODO: this may only be beneficial when each iteration of the reduction
# is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
make_config(64, 4, num_warps=8),
]
result_configs = []

if not (max_autotune or "y" in size_hints):
if reduction_hint == ReductionHint.INNER:
result_configs = [contiguous_config]
elif reduction_hint == ReductionHint.OUTER:
result_configs = [outer_config]
elif reduction_hint == ReductionHint.OUTER_TINY:
result_configs = [tiny_config]
else:
result_configs = [make_config(32, 128)]
else:
result_configs = [
contiguous_config,
outer_config,
tiny_config,
make_config(64, 64),
make_config(8, 512),
# halve the XBLOCK/Rn_BLOCK compared to outer_config
# TODO: this may only be beneficial when each iteration of the reduction
# is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
make_config(64, 4, num_warps=8),
]

# Add extra ROCm-specific configs
if torch.version.hip:
result_configs.extend([
make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1)
])

return result_configs
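To summarize the restructured `_reduction_configs` flow above: when neither max-autotune nor a second tiling dimension is in play, a single config is picked from the reduction hint (with a 32x128 fallback); otherwise the broader candidate list is autotuned; and on ROCm two `waves_per_eu` variants are appended in either case. A simplified, self-contained sketch of that branching — the numeric stand-ins for `contiguous_config`, `outer_config`, and `tiny_config` are illustrative, not the values the heuristic actually computes:

```python
from enum import Enum, auto


class ReductionHint(Enum):
    INNER = auto()
    OUTER = auto()
    OUTER_TINY = auto()
    DEFAULT = auto()


def pick_reduction_tiles(hint, max_autotune, has_y_dim, is_hip):
    """Return (XBLOCK, R0_BLOCK) candidates mirroring the branching above."""
    if not (max_autotune or has_y_dim):
        # Hint-directed path: one config, with a generic fallback.
        hinted = {
            ReductionHint.INNER: [(1, 2048)],      # stand-in for contiguous_config
            ReductionHint.OUTER: [(64, 8)],        # stand-in for outer_config
            ReductionHint.OUTER_TINY: [(1, 512)],  # stand-in for tiny_config
        }
        candidates = hinted.get(hint, [(32, 128)])
    else:
        # Max-autotune (or 3D tiling) path: search a wider set.
        candidates = [(1, 2048), (64, 8), (1, 512), (64, 64), (8, 512), (64, 4)]
    if is_hip:
        # ROCm appends the two waves_per_eu-tagged shapes added by this PR.
        candidates = candidates + [(1024, 8), (512, 8)]
    return candidates
```

For example, `pick_reduction_tiles(ReductionHint.OUTER, max_autotune=False, has_y_dim=False, is_hip=True)` yields the outer shape plus the two ROCm extras.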


@@ -2632,6 +2625,7 @@ def adapt_config_for_tiling(
num_stages=1,
register_intensive=False,
persistent_reduction=False,
waves_per_eu=None
) -> Config:
"""
Create an adapted configuration based on tiling scores,
@@ -2650,6 +2644,7 @@
block_sizes["r0_"],
num_stages=num_stages,
register_intensive=register_intensive,
waves_per_eu=waves_per_eu
)


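The last two hunks simply thread the new keyword through `adapt_config_for_tiling` into `triton_config_tiled_reduction`, so tiled (3D) reductions pick up the same hint. On ROCm, `waves_per_eu` acts as an occupancy hint, telling the compiler how many wavefronts per execution unit to target, trading register usage for latency hiding. A hedged end-to-end sketch of the parameter threading only (real signatures carry several more arguments):

```python
def triton_config_tiled_reduction_sketch(block_sizes, num_stages=1, waves_per_eu=None):
    # Stand-in for the real constructor: builds the kwargs and attaches the
    # ROCm-only hint when one is supplied.
    cfg = {"kwargs": dict(block_sizes), "num_stages": num_stages}
    if waves_per_eu is not None:
        cfg["kwargs"]["waves_per_eu"] = waves_per_eu
    return cfg


def adapt_config_for_tiling_sketch(block_sizes, num_stages=1, waves_per_eu=None):
    # Mirrors the last hunk: the keyword is forwarded unchanged.
    return triton_config_tiled_reduction_sketch(
        block_sizes, num_stages=num_stages, waves_per_eu=waves_per_eu
    )
```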