diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 7d65354e7a2f4..883a73b3a4df4 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1027,11 +1027,11 @@ def relu(x):
 
     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        return f"tl.minimum({a}, {b})"
 
     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        return f"tl.maximum({a}, {b})"
 
     @staticmethod
     def where(a, b, c):
@@ -1217,7 +1217,7 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        return f"tl.rsqrt({x})"
 
     @staticmethod
     @maybe_upcast_float32()
@@ -3296,7 +3296,7 @@ def codegen_body(self):
                 "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
             )
             self.body.writeline(
-                f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+                f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK, num_stages=2):"
            )
             with self.body.indent(offset=level + 1):
                 self.iteration_ranges_codegen_header(tree, self.body)
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 9ebf118a5f643..41d04ce01593a 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1261,7 +1261,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
     # Raise the threshold to 16 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32
 
     # Generate code containing the newer tl.make_block_ptr() API for loads/store
     use_block_ptr = False
diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
index e559eaa1a31d4..1b4fae120142f 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
@@ -13,7 +13,7 @@
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index a79d6cd41b7cc..95f58213397dd 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -765,7 +765,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
         # for some (complicated) custom Triton kernels, a register-spilling
         # config may yield the best latency.
         if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
-            "spill_threshold", 16
+            "spill_threshold", 32
         ):
             log.debug(
                 "Skip config %s because of register spilling: %d",
@@ -1987,6 +1987,9 @@ def triton_config(
     num_stages=1,
     num_elements_per_warp=256,
     min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None,
 ) -> Config:
     """
     Construct a pointwise triton config with some adjustment heuristics
@@ -2043,9 +2046,11 @@ def triton_config(
     ):
         z *= 2
 
-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Calculate num_warps only if it was not passed in explicitly
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
     # we are going to arrive at 2 warps only if bs was too small due to
     # numel being too small. However to workaround some ptx bugs we still
     # want at least 4 warps if there's enough elements per thread
@@ -2075,7 +2080,15 @@ def triton_config(
         cfg["ZBLOCK"] = z
     check_max_block(cfg)
     check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
@@ -2123,6 +2136,7 @@ def triton_config_reduction(
     num_stages=1,
     num_warps=None,
     register_intensive=False,
+    waves_per_eu=None,
 ) -> Config:
     """
     Construct a reduction triton config with some adjustment heuristics
@@ -2166,7 +2180,13 @@ def total_numel() -> int:
     cfg = _get_config({"x": x, **rnumels})
     check_max_block(cfg)
     check_config(cfg, xnumel=size_hints["x"])
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_config(numels: dict[str, int]) -> dict[str, int]:
@@ -2178,7 +2198,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:
 
 
 def triton_config_tiled_reduction(
-    size_hints, x, y, r, num_stages=1, register_intensive=False
+    size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
 ):
     """
     Construct a tile reduction triton config with some adjustment
@@ -2215,7 +2235,13 @@ def total_numel() -> int:
     )
     check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
     check_max_block(cfg)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def pointwise(
@@ -2259,11 +2285,26 @@ def pointwise(
                 triton_config_with_settings(
                     size_hints, bs // 2, num_elements_per_warp=64
                 ),
+                triton_config_with_settings(
+                    size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+                ),
+                triton_config_with_settings(
+                    size_hints, 4096  # wrt: better than the max_block for some kernels
+                ),
                 *hinted_configs,
             ]
+            # Additional pointwise config appended for ROCm builds
+            if torch.version.hip:
+                configs.append(triton_config_with_settings(
+                    size_hints,
+                    2048,
+                    num_warps=8,
+                    num_stages=2,
+                    waves_per_eu=1
+                ))  # 20% improvement
     if len(size_hints) == 2:
         if (
-            disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
+            disable_pointwise_autotuning(inductor_meta)  # or tile_hint == TileHint.SQUARE
         ) and not (
             inductor_meta.get("max_autotune")
             or inductor_meta.get("max_autotune_pointwise")
@@ -2272,9 +2313,13 @@ def pointwise(
         else:
             configs = [
                 triton_config_with_settings(size_hints, 32, 32),
+                triton_config_with_settings(size_hints, 64, 32),  # wrt: better for some kernels
                 triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
                 triton_config_with_settings(size_hints, 256, 16),
                 triton_config_with_settings(size_hints, 16, 256),
+                triton_config_with_settings(size_hints, 128, 16),  # wrt: +10% for some kernels
+                triton_config_with_settings(size_hints, 128, 32),  # wrt: an additional ~10% for some kernels
+                triton_config_with_settings(size_hints, 32, 512),  # wrt: +30% for some kernels
                 triton_config_with_settings(size_hints, bs, 1),
                 triton_config_with_settings(size_hints, 1, bs),
                 *hinted_configs,
@@ -2314,6 +2359,12 @@ def _reduction_configs(
     # Convert reductions to 1D, to simplify heuristics.
     rnumel = get_total_reduction_numel(size_hints)
 
+    # Autotune when pointwise autotuning is not disabled, or when max-autotune is requested
+    max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
+        inductor_meta.get("max_autotune")
+        or inductor_meta.get("max_autotune_pointwise")
+    )
+
     register_intensive = False
     MAX_R0_BLOCK = 2048
     if (
@@ -2336,7 +2387,7 @@ def _reduction_configs(
         MAX_R0_BLOCK = 1024
         register_intensive = True
 
-    def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
+    def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, waves_per_eu=None):
         # For 3D case with tiling scores, create an adapted version
         if "y" in size_hints:
             assert "tiling_scores" in inductor_meta
@@ -2348,6 +2399,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
                 num_warps=num_warps,
                 num_stages=num_stages,
                 register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
             )
         else:
             # For other cases, use the original function
@@ -2358,6 +2410,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
                 num_warps=num_warps,
                 num_stages=num_stages,
                 register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
             )
 
     contiguous_config = make_config(
@@ -2371,32 +2424,39 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
         min(rnumel, MAX_R0_BLOCK),
         register_intensive=register_intensive,
     )
-    # For 3d tiling, default to more autotuning initially
-    if "y" in size_hints:
-        pass
-    elif inductor_meta.get("max_autotune") or inductor_meta.get(
-        "max_autotune_pointwise"
-    ):
-        pass  # skip all these cases
-    elif reduction_hint == ReductionHint.INNER:
-        return [contiguous_config]
-    elif reduction_hint == ReductionHint.OUTER:
-        return [outer_config]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        return [tiny_config]
-    if disable_pointwise_autotuning(inductor_meta):
-        return [make_config(32, 128)]
-    return [
-        contiguous_config,
-        outer_config,
-        tiny_config,
-        make_config(64, 64),
-        make_config(8, 512),
-        # halve the XBLOCK/Rn_BLOCK compared to outer_config
-        # TODO: this may only be beneficial when each iteration of the reduction
-        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
-        make_config(64, 4, num_warps=8),
-    ]
+
+    result_configs = []
+
+    if not (max_autotune_enabled or "y" in size_hints):
+        if reduction_hint == ReductionHint.INNER:
+            result_configs = [contiguous_config]
+        elif reduction_hint == ReductionHint.OUTER:
+            result_configs = [outer_config]
+        elif reduction_hint == ReductionHint.OUTER_TINY:
+            result_configs = [tiny_config]
+        else:
+            result_configs = [make_config(32, 128)]
+    else:
+        result_configs = [
+            contiguous_config,
+            outer_config,
+            tiny_config,
+            make_config(64, 64),
+            make_config(8, 512),
+            # halve the XBLOCK/Rn_BLOCK compared to outer_config
+            # TODO: this may only be beneficial when each iteration of the reduction
+            # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
+            make_config(64, 4, num_warps=8),
+        ]
+
+    # Additional ROCm-specific configs
+    if torch.version.hip:
+        result_configs.extend([
+            make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
+            make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1)
+        ])
+
+    return result_configs
 
 
 def match_target_block_product(
@@ -2454,6 +2514,7 @@ def adapt_config_for_tiling(
     num_stages=1,
     register_intensive=False,
     persistent_reduction=False,
+    waves_per_eu=None,
 ) -> Config:
     """
     Create an adapted configuration based on tiling scores,
@@ -2472,6 +2533,7 @@ def adapt_config_for_tiling(
         block_sizes["r0_"],
         num_stages=num_stages,
        register_intensive=register_intensive,
+        waves_per_eu=waves_per_eu,
     )
 
 
@@ -2491,6 +2553,24 @@ def reduction(
     assert triton_meta is not None
 
     configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
+
+    # Additional tuning configs for ROCm builds
+    # Add checks for reduction autotuning bools
+    # if torch.version.hip and inductor_meta.get("max_autotune"):
+    #     configs = [
+    #         triton_config_with_settings(size_hints, bs, num_elements_per_warp=256),
+    #         triton_config_with_settings(
+    #             size_hints, bs // 2, num_elements_per_warp=64
+    #         ),
+    #         # triton_config_with_settings(
+    #         #     size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
+    #         # ),
+    #         triton_config_with_settings(
+    #             size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+    #         ),
+    #         *hinted_configs,
+    #     ]
+
     return cached_autotune(
         size_hints,
         configs=configs,
@@ -2564,15 +2644,15 @@ def _persistent_reduction_configs(
     if "y" not in size_hints:
         configs = [
             triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
-            for xblock in (1, 8, 32, 128)
+            for xblock in (1, 4, 8, 16, 32, 64, 128, 256, 512)
             if xblock == 1
-            or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
+            or (xblock <= xnumel and rnumel * xblock <= 4096)
         ]
     else:
         configs = []
         assert "tiling_scores" in inductor_meta
         x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
-        for target_block_size in (1, 8, 32, 64, 128):
+        for target_block_size in (1, 4, 8, 16, 32, 64, 128, 256, 512):
             if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL:
                 continue
 
@@ -2607,6 +2687,22 @@ def _persistent_reduction_configs(
         for conf in tiny_configs:
             if conf not in configs:
                 configs.append(conf)
+
+        # Expand each config to also try half and double the warp count
+        expanded_configs = []
+        for conf in configs:
+            num_warps = conf.num_warps
+            max_warps = 8 if torch.version.hip else 16
+            small_conf = copy.deepcopy(conf)
+            large_conf = copy.deepcopy(conf)
+            small_conf.num_warps = max(small_conf.num_warps // 2, 1)
+            large_conf.num_warps = min(large_conf.num_warps * 2, max_warps)
+            expanded_configs.append(conf)
+            expanded_configs.append(small_conf)
+            expanded_configs.append(large_conf)
+
+        configs = expanded_configs
+
     elif reduction_hint == ReductionHint.OUTER_TINY:
         configs = tiny_configs