diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 9039fa2b51a1..9af0c91c0ddd 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1002,11 +1002,11 @@ def relu(x):
 
     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        return f"tl.minimum({a}, {b})"
 
     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        return f"tl.maximum({a}, {b})"
 
     @staticmethod
     def where(a, b, c):
@@ -1202,7 +1202,7 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        return f"tl.rsqrt({x})"
 
     @staticmethod
     @maybe_upcast_float32()
@@ -3222,7 +3222,7 @@ def codegen_body(self):
                 "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
             )
             self.body.writeline(
-                f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+                f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK, num_stages=2):"
             )
             with self.body.indent(offset=level + 1):
                 self.iteration_ranges_codegen_header(tree, self.body)
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index ed930ad296a9..6e11426b76c2 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1138,7 +1138,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
     # Raise the threshold to 16 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32
 
     # Generate code containing the newer tl.make_block_ptr() API for loads/store
     use_block_ptr = False
diff --git a/torch/_inductor/kernel/mm_scaled.py b/torch/_inductor/kernel/mm_scaled.py
index aa917e120168..ad0029e21e59 100644
--- a/torch/_inductor/kernel/mm_scaled.py
+++ b/torch/_inductor/kernel/mm_scaled.py
@@ -469,7 +469,9 @@ def scaled_mm_options(  # type: ignore[no-untyped-def]
            f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
        )
    return dict(
-        GROUP_M=8,
+        # This change is incompatible with vLLM, so it can't make it into our release;
+        # it should be fixed on the vLLM side.
+        # GROUP_M=8,
        EVEN_K=even_k_symbolic,
        ACC_TYPE="tl.float32",
        USE_FAST_ACCUM=use_fast_accum,
diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
index 3bc8df35a838..2a045f1167e9 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
@@ -13,7 +13,7 @@
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 50565690fd96..f3a12a2a3b3e 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -577,7 +577,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
         # for some (complicated) custom Triton kernels, a register-spilling
         # config may yield the best latency.
         if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
-            "spill_threshold", 16
+            "spill_threshold", 32
         ):
             log.debug(
                 "Skip config %s because of register spilling: %d",
@@ -1587,6 +1587,9 @@ def triton_config(
     num_stages=1,
     num_elements_per_warp=256,
     min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None,
 ) -> Config:
     """
     Construct a pointwise triton config with some adjustment heuristics
@@ -1643,9 +1646,11 @@ def triton_config(
     ):
         z *= 2
 
-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Calculate num_warps only if it was not passed in explicitly
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
     # we are going to arrive at 2 warps only if bs was too small due to
     # numel being too small. However to workaround some ptx bugs we still
     # want at least 4 warps if there's enough elements per thread
@@ -1675,7 +1680,15 @@ def triton_config(
         cfg["ZBLOCK"] = z
     check_max_block(cfg)
     check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
@@ -1723,6 +1736,7 @@ def triton_config_reduction(
     num_stages=1,
     num_warps=None,
     register_intensive=False,
+    waves_per_eu=None,
 ) -> Config:
     """
     Construct a reduction triton config with some adjustment heuristics
@@ -1766,7 +1780,13 @@ def total_numel() -> int:
     cfg = _get_config({"x": x, **rnumels})
     check_max_block(cfg)
     check_config(cfg, xnumel=size_hints["x"])
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_config(numels: dict[str, int]) -> dict[str, int]:
@@ -1854,6 +1874,9 @@ def pointwise(
             triton_config_with_settings(
                 size_hints, bs // 2, num_elements_per_warp=64
             ),
+            triton_config_with_settings(
+                size_hints, TRITON_MAX_BLOCK["X"]
+            ),
             *hinted_configs,
         ]
     if len(size_hints) == 2:
@@ -1949,14 +1972,14 @@ def _reduction_configs(
     if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"):
         pass  # skip all these cases
     elif reduction_hint == ReductionHint.INNER:
-        return [contiguous_config]
+        result_configs = [contiguous_config]
     elif reduction_hint == ReductionHint.OUTER:
-        return [outer_config]
+        result_configs = [outer_config]
     elif reduction_hint == ReductionHint.OUTER_TINY:
-        return [tiny_config]
+        result_configs = [tiny_config]
     if disable_pointwise_autotuning(inductor_meta):
-        return [triton_config_reduction(size_hints, 32, 128)]
-    return [
+        result_configs = [triton_config_reduction(size_hints, 32, 128)]
+    result_configs = [
         contiguous_config,
         outer_config,
         tiny_config,
@@ -1968,6 +1991,19 @@ def _reduction_configs(
         triton_config_reduction(size_hints, 64, 4, num_warps=8),
     ]
 
+    # Additional reduction configs appended for ROCm builds
+    if torch.version.hip:
+        # XBLOCK=1024, R=8, 4 warps
+        result_configs.append(triton_config_reduction(
+            size_hints,
+            1024,
+            8,
+            num_warps=4,
+            num_stages=1,
+        ))
+
+    return result_configs
+
 
 def reduction(
     size_hints,
@@ -1985,6 +2021,7 @@ def reduction(
     assert triton_meta is not None
 
     configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
+
     return cached_autotune(
         size_hints,
         configs=configs,
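The codegen change in triton.py replaces the plain Python `range(...)` in generated reduction loops with `tl.range(..., num_stages=2)`, which asks Triton to software-pipeline the loop body. Below is a hand-written sketch of what the emitted loop roughly amounts to; it assumes a Triton version whose `tl.range` accepts `num_stages` and is an illustration, not actual Inductor output.

```python
import triton
import triton.language as tl


@triton.jit
def sum_rows_kernel(x_ptr, out_ptr, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    acc = tl.zeros([XBLOCK, RBLOCK], dtype=tl.float32)
    # num_stages=2 lets the compiler double-buffer the loads across loop iterations.
    for roffset in tl.range(0, rnumel, RBLOCK, num_stages=2):
        rindex = roffset + tl.arange(0, RBLOCK)
        rmask = rindex < rnumel
        vals = tl.load(
            x_ptr + xindex[:, None] * rnumel + rindex[None, :],
            mask=xmask[:, None] & rmask[None, :],
            other=0.0,
        )
        acc += vals.to(tl.float32)
    tl.store(out_ptr + xindex, tl.sum(acc, axis=1), mask=xmask)
```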
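The triton_heuristics.py changes thread two ROCm-only tuning knobs (`waves_per_eu`, `matrix_instr`) into the generated configs by writing them straight into `triton.Config.kwargs`, which Triton forwards to the AMD backend at compile time. A minimal sketch of that mechanism follows, using only public `torch`/`triton` APIs; the `rocm_config` helper is illustrative and not part of the patch.

```python
import torch
import triton


def rocm_config(xblock, num_warps=4, num_stages=1, waves_per_eu=None, matrix_instr=None):
    """Build a pointwise-style config and attach ROCm-only kwargs, mirroring the patched helpers."""
    cfg = triton.Config({"XBLOCK": xblock}, num_warps=num_warps, num_stages=num_stages)
    # torch.version.hip is None on CUDA builds, so the extra kwargs are only
    # attached when compiling for ROCm.
    if torch.version.hip:
        if waves_per_eu is not None:
            cfg.kwargs["waves_per_eu"] = waves_per_eu
        if matrix_instr is not None:
            cfg.kwargs["matrix_instr_nonkdim"] = matrix_instr
    return cfg
```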