8 changes: 4 additions & 4 deletions torch/_inductor/codegen/triton.py
@@ -1002,11 +1002,11 @@ def relu(x):

@staticmethod
def minimum(a, b):
return f"triton_helpers.minimum({a}, {b})"
return f"tl.minimum({a}, {b})"

@staticmethod
def maximum(a, b):
return f"triton_helpers.maximum({a}, {b})"
return f"tl.maximum({a}, {b})"

@staticmethod
def where(a, b, c):
@@ -1202,7 +1202,7 @@ def load_seed(name, offset):
@staticmethod
@maybe_upcast_float32()
def rsqrt(x):
return f"libdevice.rsqrt({x})"
return f"tl.rsqrt({x})"

@staticmethod
@maybe_upcast_float32()
@@ -3222,7 +3222,7 @@ def codegen_body(self):
"rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
)
self.body.writeline(
f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK, num_stages = 2):"
)
with self.body.indent(offset=level + 1):
self.iteration_ranges_codegen_header(tree, self.body)
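Taken together, the codegen changes above replace the triton_helpers/libdevice wrappers with tl.* builtins and pipeline the generated reduction loop. A rough sketch of how that looks in an emitted kernel (hand-written for illustration, not actual Inductor output; masking and the real indexing are omitted):

import triton
import triton.language as tl


@triton.jit
def example_reduction_kernel(in_ptr, out_ptr, r0_numel, XBLOCK: tl.constexpr, R0_BLOCK: tl.constexpr):
    xindex = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    acc = tl.zeros([XBLOCK, R0_BLOCK], dtype=tl.float32)
    # The loop is now emitted via tl.range(..., num_stages=2) so Triton can
    # software-pipeline the global loads against the arithmetic.
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages=2):
        r0_index = r0_offset + tl.arange(0, R0_BLOCK)[None, :]
        val = tl.load(in_ptr + xindex * r0_numel + r0_index)
        # minimum/maximum/rsqrt now lower to tl.* builtins instead of
        # triton_helpers.* and libdevice.* calls.
        val = tl.maximum(tl.minimum(val, 1.0), -1.0)
        acc += tl.rsqrt(val + 2.0)
    tl.store(out_ptr + xindex, tl.sum(acc, 1)[:, None])

The extra pipeline stage overlaps loads with compute at the cost of some registers, which lines up with the relaxed spill_threshold later in this PR.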
2 changes: 1 addition & 1 deletion torch/_inductor/config.py
@@ -1138,7 +1138,7 @@ class triton:
# So far we see a fixed 8 spilled registers for kernels using sin/cos.
# Raise the threshold to 16 to be safe.
# We should revisit this once we understand more of the source of register spills.
- spill_threshold: int = 16
+ spill_threshold: int = 32

# Generate code containing the newer tl.make_block_ptr() API for loads/store
use_block_ptr = False
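If the stricter behavior is needed for a particular run, the threshold can also be overridden at runtime instead of editing this default; a minimal sketch, assuming the standard torch._inductor config patching:

import torch._inductor.config as inductor_config

# Restore the previous, stricter register-spill threshold for this process.
inductor_config.triton.spill_threshold = 16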
4 changes: 3 additions & 1 deletion torch/_inductor/kernel/mm_scaled.py
@@ -469,7 +469,9 @@ def scaled_mm_options( # type: ignore[no-untyped-def]
f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
)
return dict(
- GROUP_M=8,
+ # This change is incompatible with vLLM, so it can't go into our release;
+ # it should be fixed on the vLLM side.
+ # GROUP_M=8,
EVEN_K=even_k_symbolic,
ACC_TYPE="tl.float32",
USE_FAST_ACCUM=use_fast_accum,
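For context on what is being dropped: GROUP_M controls the grouped program-id ordering used by the Triton matmul templates to improve L2 reuse, and leaving it out presumably falls back to the template's default. A stripped-down sketch of that swizzle, following the standard Triton grouped-ordering pattern (illustrative names, not the template's actual code):

import triton
import triton.language as tl


@triton.jit
def grouped_pid(pid, grid_m, grid_n, GROUP_M: tl.constexpr):
    # Programs walk GROUP_M rows of output tiles at a time, so the tiles
    # they share stay resident in L2.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n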
2 changes: 1 addition & 1 deletion torch/_inductor/runtime/hints.py
@@ -13,7 +13,7 @@
# The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
# NOTE: if these fail asserts submit a PR to increase them
TRITON_MAX_BLOCK = {
"X": 4096,
"X": 8192,
"Y": 1024,
"Z": 1024,
"R0_": 4096 * 16, # * 16 is multi-kernel only
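The X bump pairs with the new pointwise candidate added later in this PR, which uses TRITON_MAX_BLOCK["X"] directly as a block size. These caps are enforced by check_max_block() in triton_heuristics.py; a minimal sketch of that kind of guard (the real helper may differ in details):

from torch._inductor.runtime.hints import TRITON_MAX_BLOCK


def check_max_block_sketch(cfg: dict[str, int]) -> None:
    # Every *BLOCK entry of a candidate config must stay within the table above.
    for name, value in cfg.items():
        if name.endswith("BLOCK"):
            prefix = name[: -len("BLOCK")]
            assert value <= TRITON_MAX_BLOCK[prefix], (
                f"{name}={value} exceeds TRITON_MAX_BLOCK[{prefix!r}]"
            )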
59 changes: 48 additions & 11 deletions torch/_inductor/runtime/triton_heuristics.py
@@ -577,7 +577,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
# for some (complicated) custom Triton kernels, a register-spilling
# config may yield the best latency.
if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
"spill_threshold", 16
"spill_threshold", 32
):
log.debug(
"Skip config %s because of register spilling: %d",
@@ -1587,6 +1587,9 @@ def triton_config(
num_stages=1,
num_elements_per_warp=256,
min_elem_per_thread=0,
+ num_warps=None,
+ matrix_instr=None,
+ waves_per_eu=None,
) -> Config:
"""
Construct a pointwise triton config with some adjustment heuristics
@@ -1643,9 +1646,11 @@
):
z *= 2

- num_warps = _num_warps(
-     conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
- )
+ # Only derive num_warps when it was not passed in explicitly.
+ if num_warps is None:
+     num_warps = _num_warps(
+         conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+     )
# we are going to arrive at 2 warps only if bs was too small due to
# numel being too small. However to workaround some ptx bugs we still
# want at least 4 warps if there's enough elements per thread
@@ -1675,7 +1680,15 @@
cfg["ZBLOCK"] = z
check_max_block(cfg)
check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
- return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+ config = Config(cfg, num_warps=num_warps, num_stages=num_stages)

+ # ROCm-only: forward AMD-specific compilation hints to Triton.
+ if torch.version.hip:
+     if matrix_instr is not None:
+         config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+     if waves_per_eu is not None:
+         config.kwargs["waves_per_eu"] = waves_per_eu

+ return config


def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
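A hypothetical call illustrating the new ROCm-only knobs (the sizes and values below are made up, and triton_config_reduction gains waves_per_eu in the same way in the next hunk):

# On ROCm (torch.version.hip is set) the extra kwargs land in Config.kwargs next
# to the block sizes, where Triton's AMD backend picks them up as compilation
# hints; on CUDA builds they are ignored.
cfg = triton_config(
    {"x": 4096, "y": 64},  # size_hints
    x=64,
    y=64,
    num_warps=8,
    matrix_instr=16,
    waves_per_eu=2,
)
# cfg.kwargs now also contains {"matrix_instr_nonkdim": 16, "waves_per_eu": 2}.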
@@ -1723,6 +1736,7 @@ def triton_config_reduction(
num_stages=1,
num_warps=None,
register_intensive=False,
+ waves_per_eu=None,
) -> Config:
"""
Construct a reduction triton config with some adjustment heuristics
@@ -1766,7 +1780,13 @@ def total_numel() -> int:
cfg = _get_config({"x": x, **rnumels})
check_max_block(cfg)
check_config(cfg, xnumel=size_hints["x"])
- return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+ config = Config(cfg, num_warps=num_warps, num_stages=num_stages)

+ if torch.version.hip:
+     if waves_per_eu is not None:
+         config.kwargs["waves_per_eu"] = waves_per_eu

+ return config


def _get_config(numels: dict[str, int]) -> dict[str, int]:
@@ -1854,6 +1874,9 @@ def pointwise(
triton_config_with_settings(
size_hints, bs // 2, num_elements_per_warp=64
),
+ triton_config_with_settings(
+     size_hints, TRITON_MAX_BLOCK["X"]
+ ),
*hinted_configs,
]
if len(size_hints) == 2:
@@ -1949,14 +1972,14 @@ def _reduction_configs(
if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"):
pass # skip all these cases
elif reduction_hint == ReductionHint.INNER:
- return [contiguous_config]
+ result_configs = [contiguous_config]
elif reduction_hint == ReductionHint.OUTER:
- return [outer_config]
+ result_configs = [outer_config]
elif reduction_hint == ReductionHint.OUTER_TINY:
- return [tiny_config]
+ result_configs = [tiny_config]
if disable_pointwise_autotuning(inductor_meta):
- return [triton_config_reduction(size_hints, 32, 128)]
- return [
+ result_configs = [triton_config_reduction(size_hints, 32, 128)]
+ result_configs = [
contiguous_config,
outer_config,
tiny_config,
@@ -1968,6 +1991,19 @@
triton_config_reduction(size_hints, 64, 4, num_warps=8),
]

+ # Additional reduction candidate appended for ROCm builds
+ if torch.version.hip:
+     result_configs.append(
+         triton_config_reduction(size_hints, 1024, 8, num_warps=4, num_stages=1)
+     )

+ return result_configs


def reduction(
size_hints,
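The shape of the _reduction_configs change above is: stop returning early per reduction hint, collect candidates in result_configs, and let the ROCm branch append one more before a single return. A stripped-down sketch of the pattern (names and tuples are illustrative, not the real heuristic):

def pick_candidates(hint: str, on_rocm: bool) -> list[tuple[int, int]]:
    # Choose a base set of (XBLOCK, R0_BLOCK)-style candidates without
    # returning immediately...
    if hint == "inner":
        candidates = [(32, 128)]
    elif hint == "outer":
        candidates = [(64, 8)]
    else:
        candidates = [(32, 128), (64, 8), (1, 512)]
    # ...so platform-specific extras can be appended in one place.
    if on_rocm:
        candidates.append((1024, 8))
    return candidates

One thing worth double-checking in the real diff: the unconditional result_configs = [...] assignment after the hint branches appears to overwrite whatever those branches selected, so the hint shortcuts become dead code unless that is intended.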
@@ -1985,6 +2021,7 @@ def reduction(
assert triton_meta is not None

configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)

return cached_autotune(
size_hints,
configs=configs,