diff --git a/test/inductor/test_async_compile.py b/test/inductor/test_async_compile.py
index 5a61ea851eae0..cc94c4c95e01a 100644
--- a/test/inductor/test_async_compile.py
+++ b/test/inductor/test_async_compile.py
@@ -74,7 +74,14 @@ def f(a, b):
             return (a @ b).to(torch.float32).sum(dim=1)

         # Fake name to make sure the lookup table is name agnostic
-        func_def = """
+        # When codegen/triton.py is changed, func_def must be updated
+        loop_header = (
+            "for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):"
+            if torch.version.hip
+            else "for r0_offset in range(0, r0_numel, R0_BLOCK):"
+        )
+
+        func_def = f"""
 def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
     xnumel = 1024
     r0_numel = 11776
@@ -87,7 +94,7 @@ def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.cons
     rbase = r0_base
     x0 = xindex
     _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
-    for r0_offset in range(0, r0_numel, R0_BLOCK):
+    {loop_header}
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 3848fc3355e49..17a336cc3cf2e 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1101,11 +1101,17 @@ def relu(x):

     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)"
+        else:
+            return f"triton_helpers.minimum({a}, {b})"

     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)"
+        else:
+            return f"triton_helpers.maximum({a}, {b})"

     @staticmethod
     def where(a, b, c):
@@ -1291,7 +1297,10 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        if torch.version.hip:
+            return f"tl.rsqrt({x})"
+        else:
+            return f"libdevice.rsqrt({x})"

     @staticmethod
     @maybe_upcast_float32()
@@ -3788,8 +3797,9 @@ def codegen_body(self):
             loop_end = (
                 "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
             )
+            num_stages = ", num_stages = 2" if torch.version.hip else ""
             self.body.writeline(
-                f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+                f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):"
             )
             with self.body.indent(offset=level + 1):
                 self.iteration_ranges_codegen_header(tree, self.body)
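Note on the minimum/maximum lowering above: on ROCm the codegen switches from the
triton_helpers wrappers to tl.minimum/tl.maximum with tl.PropagateNan.ALL, which should
keep the NaN-propagating semantics the helpers provide. A minimal plain-Python sketch of
that contract (illustrative only, not the actual Triton lowering):

    import math

    def nan_propagating_min(a: float, b: float) -> float:
        # If either operand is NaN the result is NaN; otherwise the smaller value.
        # Both lowerings above are expected to follow this contract.
        if math.isnan(a) or math.isnan(b):
            return math.nan
        return min(a, b)

    assert nan_propagating_min(1.0, 2.0) == 1.0
    assert math.isnan(nan_propagating_min(float("nan"), 2.0))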
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index f6921a057ba0f..857272df14c94 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1391,7 +1391,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
     # Raise the threshold to 16 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32 if torch.version.hip else 16

     # Generate code containing the newer tl.make_block_ptr() API for loads/store
     use_block_ptr = False
diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
index 15b86b1b3d1ae..a1a0a792c9b84 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
@@ -13,7 +13,7 @@
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 46832167622b1..3aae58f2aa428 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -838,7 +838,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
         # for some (complicated) custom Triton kernels, a register-spilling
         # config may yield the best latency.
         if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
-            "spill_threshold", 16
+            "spill_threshold", 32 if torch.version.hip else 16
         ):
             log.debug(
                 "Skip config %s because of register spilling: %d",
@@ -2163,6 +2163,9 @@ def triton_config(
     num_stages=1,
     num_elements_per_warp=256,
     min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None,
 ) -> Config:
     """
     Construct a pointwise triton config with some adjustment heuristics
@@ -2219,9 +2222,11 @@ def triton_config(
     ):
         z *= 2

-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Calculate num_warps only if it was not passed in explicitly
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
     # we are going to arrive at 2 warps only if bs was too small due to
     # numel being too small. However to workaround some ptx bugs we still
     # want at least 4 warps if there's enough elements per thread
@@ -2251,7 +2256,15 @@ def triton_config(
         cfg["ZBLOCK"] = z
     check_max_block(cfg)
     check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config


 def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
@@ -2299,6 +2312,7 @@ def triton_config_reduction(
     num_stages=1,
     num_warps=None,
     register_intensive=False,
+    waves_per_eu=None,
     dynamic_scale_rblock=True,
 ) -> Config:
     """
@@ -2343,13 +2357,19 @@ def total_numel() -> int:
     cfg = _get_config({"x": x, **rnumels})
     check_max_block(cfg)
     check_config(cfg, xnumel=size_hints["x"])
-    return InductorConfig(
+    config = InductorConfig(
         cfg,
         num_warps=num_warps,
         num_stages=num_stages,
         dynamic_scale_rblock=dynamic_scale_rblock,
     )
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
+

 def _get_config(numels: dict[str, int]) -> dict[str, int]:
     """
@@ -2360,7 +2380,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:


 def triton_config_tiled_reduction(
-    size_hints, x, y, r, num_stages=1, register_intensive=False
+    size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
 ):
     """
     Construct a tile reduction triton config with some adjustment
@@ -2397,7 +2417,11 @@ def total_numel() -> int:
     )
     check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
     check_max_block(cfg)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+    return config


 def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]):
@@ -2486,11 +2510,37 @@ def pointwise(
             triton_config_with_settings(
                 size_hints, bs // 2, num_elements_per_warp=64
             ),
+            triton_config_with_settings(
+                size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+            ),
+            triton_config_with_settings(
+                size_hints,
+                4096,  # wrt: better than the max_block for some kernels
+            ),
             *hinted_configs,
         ]
+        # Additional pointwise configs appended for ROCm builds
+        if torch.version.hip:
+            configs.append(
+                triton_config_with_settings(
+                    size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1
+                )
+            )  # 20% improvement
+            configs += [
+                triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1),  # 20% improvement
+                triton_config_with_settings(size_hints, 4096),  # wrt1: better than the max_block for some kernels
+                triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1),
+                # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
+                # triton_poi_fused_index_put_new_zeros_45
+                # triton_poi_fused_index_put_new_zeros_49
+                # triton_poi_fused_index_put_new_zeros_54
+            ]

     if len(size_hints) == 2:
+        # Only skip tuning on TileHint.SQUARE for non-ROCm builds;
+        # ROCm has observed improvements from diverging here
         if (
-            disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
+            disable_pointwise_autotuning(inductor_meta)
+            or (torch.version.hip is None and tile_hint == TileHint.SQUARE)
         ) and not (
             inductor_meta.get("max_autotune")
             or inductor_meta.get("max_autotune_pointwise")
@@ -2499,13 +2549,34 @@ def pointwise(
         else:
             configs = [
                 triton_config_with_settings(size_hints, 32, 32),
+                triton_config_with_settings(
+                    size_hints, 64, 32
+                ),  # better for some kernels
                 triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
-                triton_config_with_settings(size_hints, 256, 16),
+                triton_config_with_settings(size_hints, 256, 16),
                 triton_config_with_settings(size_hints, 16, 256),
+                triton_config_with_settings(
+                    size_hints, 128, 16
+                ),  # +10% for some kernels
+                triton_config_with_settings(size_hints, 128, 32),  # additional 10% more
+                triton_config_with_settings(
+                    size_hints, 32, 512
+                ),  # +30% for some kernels
                 triton_config_with_settings(size_hints, bs, 1),
                 triton_config_with_settings(size_hints, 1, bs),
                 *hinted_configs,
             ]
+        if torch.version.hip:
+            configs += [  # placeholder for future ROCm-specific additions
+            ]
+            # bypass triton_config_with_settings -> triton_config logic
+            if "x" in size_hints and "y" in size_hints:
+                configs += [
+                    Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8),  # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
+                    Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4),  # wrt2: 570us : triton_poi_fused_add_transpose_view_52
+                    Config({"XBLOCK": 64, "YBLOCK": 32}, num_warps=8),  # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
+                ]
+
     if len(size_hints) == 3:
         if disable_pointwise_autotuning(inductor_meta):
             configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
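Note on the num_warps / matrix_instr / waves_per_eu plumbing added to the config
constructors above: the ROCm-only knobs are attached to a triton.Config as extra kernel
kwargs and are simply omitted on other backends. A minimal standalone sketch of the
pattern (hypothetical helper name; assumes triton is importable — the real logic lives in
triton_config and the related constructors):

    import torch
    from triton import Config

    def rocm_aware_config(block_sizes, num_warps=4, num_stages=1,
                          waves_per_eu=None, matrix_instr=None):
        # Hypothetical helper mirroring the pattern used by triton_config above.
        cfg = Config(dict(block_sizes), num_warps=num_warps, num_stages=num_stages)
        if torch.version.hip:
            # ROCm-only launch hints ride along as extra kernel kwargs.
            if waves_per_eu is not None:
                cfg.kwargs["waves_per_eu"] = waves_per_eu
            if matrix_instr is not None:
                cfg.kwargs["matrix_instr_nonkdim"] = matrix_instr
        return cfg

    # e.g. rocm_aware_config({"XBLOCK": 2048}, num_warps=8, num_stages=2, waves_per_eu=1)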
@@ -2544,6 +2615,11 @@ def _reduction_configs(
     # Convert reductions to 1D, to simplify heuristics.
     rnumel = get_total_reduction_numel(size_hints)

+    # Whether max autotune is enabled
+    max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get(
+        "max_autotune_pointwise"
+    )
+
     register_intensive = False
     MAX_R0_BLOCK = 2048
     loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get(
@@ -2572,6 +2648,7 @@ def make_config(
         num_stages=1,
         register_intensive=False,
         dynamic_scale_rblock=True,
+        waves_per_eu=None,
     ):
         # For 3D case with tiling scores, create an adapted version
         if "y" in size_hints:
@@ -2584,6 +2661,7 @@ def make_config(
                 num_warps=num_warps,
                 num_stages=num_stages,
                 register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
             )
         else:
             # For other cases, use the original function
@@ -2594,6 +2672,7 @@ def make_config(
                 num_warps=num_warps,
                 num_stages=num_stages,
                 register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
                 dynamic_scale_rblock=dynamic_scale_rblock,
             )

@@ -2674,33 +2753,45 @@ def outer_config_opt():
         )
         configs.append(c)

+    result_configs = []
+
     # For 3d tiling, default to more autotuning initially
-    if "y" in size_hints:
-        pass
-    elif inductor_meta.get("max_autotune") or inductor_meta.get(
-        "max_autotune_pointwise"
-    ):
-        pass  # skip all these cases
-    elif reduction_hint == ReductionHint.INNER:
-        return configs + [contiguous_config]
-    elif reduction_hint == ReductionHint.OUTER:
-        return configs + [outer_config]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        return configs + [tiny_config]
-    if disable_pointwise_autotuning(inductor_meta):
-        return configs + [make_config(32, 128)]
-
-    return configs + [
-        contiguous_config,
-        outer_config,
-        tiny_config,
-        make_config(64, 64),
-        make_config(8, 512),
-        # halve the XBLOCK/Rn_BLOCK compared to outer_config
-        # TODO: this may only be beneficial when each iteration of the reduction
-        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
-        make_config(64, 4, num_warps=8),
-    ]
+    if not (max_autotune_enabled or "y" in size_hints):
+        if reduction_hint == ReductionHint.INNER:
+            result_configs = configs + [contiguous_config]
+        elif reduction_hint == ReductionHint.OUTER:
+            result_configs = configs + [outer_config]
+        elif reduction_hint == ReductionHint.OUTER_TINY:
+            result_configs = configs + [tiny_config]
+        else:
+            result_configs = configs + [make_config(32, 128)]
+    else:
+        result_configs = configs + [
+            contiguous_config,
+            outer_config,
+            tiny_config,
+            make_config(64, 64),
+            make_config(8, 512),
+            # halve the XBLOCK/Rn_BLOCK compared to outer_config
+            # TODO: this may only be beneficial when each iteration of the reduction
+            # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
+            make_config(64, 4, num_warps=8),
+        ]
+
+    if torch.version.hip:
+        result_configs.extend(
+            [
+                make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
+                make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
+                make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1),  # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
+                make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1),  # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
+                make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1),  # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
+                make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1),  # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
+                make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1)  # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
+            ]
+        )
+
+    return result_configs


 def match_target_block_product(
@@ -2758,6 +2849,7 @@ def adapt_config_for_tiling(
     num_stages=1,
     register_intensive=False,
     persistent_reduction=False,
+    waves_per_eu=None,
 ) -> Config:
     """
     Create an adapted configuration based on tiling scores,
@@ -2776,6 +2868,7 @@ def adapt_config_for_tiling(
         block_sizes["r0_"],
         num_stages=num_stages,
         register_intensive=register_intensive,
+        waves_per_eu=waves_per_eu,
     )


@@ -2868,17 +2961,25 @@ def _persistent_reduction_configs(
 ):
     xnumel = size_hints["x"]
     rnumel = get_total_reduction_numel(size_hints)
+    loads_and_stores = inductor_meta.get("num_load", 0) + inductor_meta.get(
+        "num_store", 0
+    )

     MAX_PERSISTENT_BLOCK_NUMEL = 4096
+
     max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
-        inductor_meta.get("max_autotune")
-        or inductor_meta.get("max_autotune_pointwise")
+        inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
     )

+    if torch.version.hip:
+        xblock_vals = [1, 4, 8, 16, 32, 64, 128, 256]
+    else:
+        xblock_vals = [1, 8, 32, 128]
+
     if "y" not in size_hints:
         configs = [
             triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
-            for xblock in (1, 8, 32, 128)
+            for xblock in xblock_vals
             if xblock == 1
             or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
         ]
@@ -2886,7 +2987,7 @@ def _persistent_reduction_configs(
         configs = []
         assert "tiling_scores" in inductor_meta
         x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
-        for target_block_size in (1, 8, 32, 64, 128):
+        for target_block_size in xblock_vals:
             if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL:
                 continue

@@ -2899,16 +3000,6 @@ def _persistent_reduction_configs(
                 )
             )

-    # defer to more autotuning, initially
-    if "y" in size_hints:
-        pass
-    # TODO(jansel): we should be able to improve these heuristics
-    if not max_autotune_enabled:  # Don't filter if tuning enabled
-        if reduction_hint == ReductionHint.INNER and rnumel >= 256:
-            configs = configs[:1]
-        elif reduction_hint == ReductionHint.OUTER:
-            configs = configs[-1:]
-
     tiny_configs = [
         triton_config_reduction(
             size_hints,
@@ -2917,12 +3008,40 @@ def _persistent_reduction_configs(
         )
     ]

-    if max_autotune_enabled:
+    # defer to more autotuning, initially
+    if "y" in size_hints:
+        pass
+    # TODO(jansel): we should be able to improve these heuristics
+    elif not max_autotune_enabled:  # Only filter configs when max autotune is disabled
+        if reduction_hint == ReductionHint.INNER:
+            if rnumel > 1024:
+                configs = configs[:1]
+            else:
+                x_block = 8
+                if xnumel // x_block < 128 or (loads_and_stores >= 5 and rnumel >= 256):
+                    # Five or more loads/stores imply a lot of register pressure;
+                    # rnumel < 256 means no vectorized loads if we split up the r dim,
+                    # so in that case xblock still needs to be larger
+                    x_block = 1
+
+                configs = [
+                    triton_config_reduction(
+                        size_hints,
+                        x_block,
+                        rnumel,
+                        register_intensive=True,
+                    )
+                ]
+
+        elif reduction_hint == ReductionHint.OUTER:
+            configs = configs[-1:]
+        elif reduction_hint == ReductionHint.OUTER_TINY:
+            configs = tiny_configs
+    else:
+        # When max autotune is enabled, append the tiny configs as extra candidates
         for conf in tiny_configs:
             if conf not in configs:
                 configs.append(conf)
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        configs = tiny_configs

     for c in configs:
         # we don't need Rn_BLOCK for persistent reduction
@@ -3120,9 +3239,10 @@ def foreach(triton_meta, filename=None, inductor_meta=None):
     Compile a triton foreach kernel
     """
     configs = []
+
+    # Naive autotuning path for num_warps
     if disable_pointwise_autotuning(inductor_meta) and not (
-        inductor_meta.get("max_autotune") or
-        inductor_meta.get("max_autotune_pointwise")
+        inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
     ):
         configs.append(triton.Config({}, num_stages=1, num_warps=8))
     else:
@@ -3138,6 +3258,7 @@ def foreach(triton_meta, filename=None, inductor_meta=None):
         filename=filename,
     )

+
 @dataclasses.dataclass
 class GridExpr:
     """Generate code for grid size expressions in launcher"""
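For reference, the reduction-loop header selection that the updated test above has to
mirror can be sketched as a standalone helper (hypothetical function; in codegen_body the
same string is emitted through self.body.writeline):

    import torch

    def reduction_loop_header(prefix: str, loop_start: str, loop_end: str) -> str:
        # On ROCm, request 2-stage software pipelining of the reduction loop via
        # tl.range(..., num_stages=2); elsewhere emit tl.range without a stage hint.
        num_stages = ", num_stages = 2" if torch.version.hip else ""
        return (
            f"for {prefix}offset in tl.range({loop_start}, {loop_end}, "
            f"{prefix.upper()}BLOCK{num_stages}):"
        )

    # On a ROCm build:
    #   reduction_loop_header("r0_", "0", "r0_numel")
    #   -> "for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):"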