diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 152d2ef36197f..98908d7cb696a 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -1873,6 +1873,7 @@ def __init__(
         self.compute = IndentedBuffer()
         self.stores = IndentedBuffer()
+        self.atomic_add_found = False
 
         self.num_load = 0
         self.num_reduction = 0
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index e9c5b910ba02f..d8c5b35fe3972 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1002,11 +1002,17 @@ def relu(x):
 
     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)"
+        else:
+            return f"triton_helpers.minimum({a}, {b})"
 
     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)"
+        else:
+            return f"triton_helpers.maximum({a}, {b})"
 
     @staticmethod
     def where(a, b, c):
@@ -1202,7 +1208,10 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        if torch.version.hip:
+            return f"tl.rsqrt({x})"
+        else:
+            return f"libdevice.rsqrt({x})"
 
     @staticmethod
     @maybe_upcast_float32()
@@ -2285,6 +2294,7 @@ def store(
         elif mode is None:
             line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})"
         elif mode == "atomic_add":
+            self.atomic_add_found = True
             line = f"tl.atomic_add({var} + ({indexing.index_str}), {value}, {indexing.mask_str}, sem='relaxed')"
         else:
             raise NotImplementedError(f"store mode={mode}")
@@ -3227,8 +3237,9 @@ def codegen_body(self):
                 loop_end = (
                     "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
                 )
+                num_stages = ", num_stages = 2" if torch.version.hip else ""
                 self.body.writeline(
-                    f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+                    f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):"
                 )
                 with self.body.indent(offset=level + 1):
                     self.iteration_ranges_codegen_header(tree, self.body)
@@ -3601,6 +3612,7 @@ def add_constexpr_arg(arg_name):
             "mutated_arg_names": mutated_args,
             "optimize_mem": optimize_mem,
             "no_x_dim": self.no_x_dim,
+            "atomic_add_found": self.atomic_add_found,
             "num_load": self.num_load,
             "num_reduction": self.num_reduction,
             **self.inductor_meta_common(),
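Note: on ROCm, the min/max overrides above emit Triton's native `tl.minimum`/`tl.maximum` with `tl.PropagateNan.ALL` instead of the `triton_helpers` wrappers, keeping NaN semantics while avoiding the helper call. A minimal standalone sketch of that branch (`emit_minimum` is a hypothetical name, not the actual `TritonOverrides` method):

```python
import torch

def emit_minimum(a: str, b: str) -> str:
    """Mirror of the codegen branch above: pick the ROCm-native lowering
    on HIP builds, otherwise keep the triton_helpers wrapper."""
    if torch.version.hip:
        return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)"
    return f"triton_helpers.minimum({a}, {b})"

# On a CUDA build this prints: triton_helpers.minimum(tmp0, tmp1)
print(emit_minimum("tmp0", "tmp1"))
```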
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 0b58772fbf0b6..13a1e731892cc 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1135,7 +1135,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
     # Raise the threshold to 16 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32 if torch.version.hip else 16
 
     # Generate code containing the newer tl.make_block_ptr() API for loads/store
     use_block_ptr = False
diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
index 3bc8df35a8389..ec981f4a786df 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
@@ -7,13 +7,14 @@
 from enum import auto, Enum
 from typing import Optional, Union
 
+import torch
 from torch.utils._triton import has_triton_package
 
 
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192 if torch.version.hip else 4096,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 4d87a8236c460..0f3d3e0762eb5 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -69,6 +69,14 @@
 )
 
 
+class InductorConfig(Config):
+    """Inductor-specific Triton config with additional control flags"""
+
+    def __init__(self, *args, dynamic_scale_rblock=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dynamic_scale_rblock = dynamic_scale_rblock
+
+
 class NoTritonConfigsError(RuntimeError):
     pass
 
@@ -597,7 +605,8 @@ def _get_args_with_constexprs(self, args, launcher):
         # so we can sort them by index.
         constexpr_args: list[tuple[int, Any]] = []
         for arg_name, arg_val in launcher.config.kwargs.items():
-            constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val))
+            if arg_name in self.fn.arg_names:
+                constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val))
 
         constexpr_args.sort()
         new_args = [*args]
@@ -615,7 +624,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
         # for some (complicated) custom Triton kernels, a register-spilling
         # config may yield the best latency.
         if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
-            "spill_threshold", 16
+            "spill_threshold", 32 if torch.version.hip else 16
         ):
             log.debug(
                 "Skip config %s because of register spilling: %d",
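Note: the new `InductorConfig` subclass carries an Inductor-only flag on the config instance rather than in `kwargs`, so the flag never becomes a kernel meta-parameter. A small usage sketch (assumes a Triton install; the example values are illustrative):

```python
from triton import Config

class InductorConfig(Config):
    """Inductor-specific Triton config with additional control flags."""

    def __init__(self, *args, dynamic_scale_rblock=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.dynamic_scale_rblock = dynamic_scale_rblock

cfg = InductorConfig(
    {"XBLOCK": 64, "R0_BLOCK": 8}, num_warps=4, dynamic_scale_rblock=False
)
assert cfg.kwargs == {"XBLOCK": 64, "R0_BLOCK": 8}  # flag stays out of kwargs
assert cfg.dynamic_scale_rblock is False
```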
@@ -1630,6 +1639,9 @@ def triton_config(
     num_stages=1,
     num_elements_per_warp=256,
     min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None,
 ) -> Config:
     """
     Construct a pointwise triton config with some adjustment heuristics
@@ -1686,9 +1698,11 @@ def triton_config(
     ):
         z *= 2
 
-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Calculate num_warps only if it was not passed in explicitly
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
     # we are going to arrive at 2 warps only if bs was too small due to
     # numel being too small. However to workaround some ptx bugs we still
     # want at least 4 warps if there's enough elements per thread
@@ -1718,7 +1732,15 @@ def triton_config(
         cfg["ZBLOCK"] = z
     check_max_block(cfg)
     check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
@@ -1766,6 +1788,8 @@ def triton_config_reduction(
     num_stages=1,
     num_warps=None,
     register_intensive=False,
+    waves_per_eu=None,
+    dynamic_scale_rblock=True,
 ) -> Config:
     """
     Construct a reduction triton config with some adjustment heuristics
@@ -1809,7 +1833,18 @@ def total_numel() -> int:
     cfg = _get_config({"x": x, **rnumels})
     check_max_block(cfg)
     check_config(cfg, xnumel=size_hints["x"])
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = InductorConfig(
+        cfg,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        dynamic_scale_rblock=dynamic_scale_rblock,
+    )
+
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_config(numels: dict[str, int]) -> dict[str, int]:
@@ -1820,7 +1855,9 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:
     return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()}
 
 
-def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1):
+def triton_config_tiled_reduction(
+    size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
+):
     """
     Construct a tile reduction triton config with some adjustment heuristics
     based on size_hints. Size_hints is a tuple of numels in
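Note: `waves_per_eu` and `matrix_instr_nonkdim` are attached to `config.kwargs` only under `torch.version.hip`, so non-ROCm configs are unchanged; since these hints are compile options rather than kernel parameters, this is also why `_get_args_with_constexprs` above now skips kwargs that are not in `self.fn.arg_names`. A sketch of the pattern in isolation (assumes Triton; `attach_rocm_hints` is a hypothetical helper name):

```python
import torch
from triton import Config

def attach_rocm_hints(config: Config, waves_per_eu=None, matrix_instr=None) -> Config:
    # Only meaningful on ROCm, where these entries act as AMD compiler
    # hints rather than kernel arguments; a no-op everywhere else.
    if torch.version.hip:
        if matrix_instr is not None:
            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
        if waves_per_eu is not None:
            config.kwargs["waves_per_eu"] = waves_per_eu
    return config

cfg = attach_rocm_hints(Config({"XBLOCK": 128}, num_warps=4), waves_per_eu=2)
```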
@@ -1853,7 +1890,11 @@ def total_numel() -> int:
     num_warps = _num_warps(total_numel() // 256, min_num_warps=1)
     check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1])
     check_max_block(cfg)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+    return config
 
 
 def pointwise(
@@ -1899,9 +1940,52 @@ def pointwise(
             ),
             *hinted_configs,
         ]
+        # Additional pointwise configs appended for ROCm builds
+        if torch.version.hip:
+            configs.extend(
+                [
+                    triton_config_with_settings(
+                        size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+                    ),
+                    triton_config_with_settings(
+                        size_hints,
+                        4096,  # wrt: better than the max block for some kernels
+                    ),
+                    triton_config_with_settings(
+                        size_hints,
+                        2048,
+                        num_warps=8,
+                        num_stages=2,
+                        waves_per_eu=1,  # 20% improvement
+                    ),
+                ]
+            )
+            configs += [
+                triton_config_with_settings(
+                    size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1
+                ),
+                # -> wrt1/t18: 2x improvement on triton_poi_fused_index_put_new_zeros_37,
+                #    triton_poi_fused_index_put_new_zeros_45,
+                #    triton_poi_fused_index_put_new_zeros_49,
+                #    triton_poi_fused_index_put_new_zeros_54
+            ]
+            if inductor_meta.get("atomic_add_found"):
+                configs.extend(
+                    [
+                        triton_config_with_settings(
+                            size_hints,
+                            64,
+                            num_warps=1,
+                            num_stages=1,  # 250% improvement
+                        )
+                    ]
+                )
     if len(size_hints) == 2:
+        # Only skip autotuning for TileHint.SQUARE on non-ROCm builds;
+        # ROCm has observed improvements from diverging here.
         if (
-            disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
+            disable_pointwise_autotuning(inductor_meta)
+            or (torch.version.hip is None and tile_hint == TileHint.SQUARE)
         ) and not (
             inductor_meta.get("max_autotune")
             or inductor_meta.get("max_autotune_pointwise")
@@ -1910,13 +1994,49 @@
         else:
             configs = [
                 triton_config_with_settings(size_hints, 32, 32),
+                triton_config_with_settings(
+                    size_hints, 64, 32
+                ),  # better for some kernels
                 triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
                 triton_config_with_settings(size_hints, 256, 16),
                 triton_config_with_settings(size_hints, 16, 256),
+                triton_config_with_settings(
+                    size_hints, 128, 16
+                ),  # +10% for some kernels
+                triton_config_with_settings(size_hints, 128, 32),  # a further ~10%
+                triton_config_with_settings(
+                    size_hints, 32, 512
+                ),  # +30% for some kernels
                 triton_config_with_settings(size_hints, bs, 1),
                 triton_config_with_settings(size_hints, 1, bs),
                 *hinted_configs,
             ]
+        if torch.version.hip:
+            if "x" in size_hints and "y" in size_hints:
+                # Add 2D tiling configs directly rather than through
+                # triton_config_with_settings, which is buggy here and may
+                # change the tiling unpredictably.
+                def add_config(xblock: int, yblock: int, num_warps: int, num_stages: int):
+                    # Only add a tiling config if the problem is at least as
+                    # large as the tile; also guard against grid overflow.
+                    xgrid = (size_hints["x"] + xblock - 1) // xblock
+                    ygrid = (size_hints["y"] + yblock - 1) // yblock
+                    if xgrid > 2147483647:
+                        return
+                    if ygrid > 65535:
+                        return
+                    if size_hints["x"] < xblock:
+                        return
+                    if size_hints["y"] < yblock:
+                        return
+                    # all good, add the config
+                    configs.append(
+                        Config(
+                            {"XBLOCK": xblock, "YBLOCK": yblock},
+                            num_warps=num_warps,
+                            num_stages=num_stages,
+                        )
+                    )
+
+                add_config(512, 8, 8, 1)  # wrt1/t21: triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
+                add_config(32, 128, 4, 1)  # wrt2: 570us: triton_poi_fused_add_transpose_view_52
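Note: the `add_config` guard above (renamed from `addConfig__`) drops tiles that would overflow the launch grid (2**31 - 1 programs in X, 65535 in Y) or exceed the problem size. A standalone restatement of that check with hand-verifiable numbers (`fits` is a hypothetical helper):

```python
def fits(size_hints: dict[str, int], xblock: int, yblock: int) -> bool:
    # Ceil-divide the problem into programs and respect the grid limits.
    xgrid = (size_hints["x"] + xblock - 1) // xblock
    ygrid = (size_hints["y"] + yblock - 1) // yblock
    if xgrid > 2**31 - 1 or ygrid > 65535:
        return False
    # A tile larger than the problem in either dimension is never useful.
    return size_hints["x"] >= xblock and size_hints["y"] >= yblock

print(fits({"x": 1 << 20, "y": 4096}, 512, 8))  # True: 2048 x 512 programs
print(fits({"x": 64, "y": 1 << 28}, 32, 4))     # False: ygrid = 2**26 > 65535
```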
+                add_config(64, 32, 8, 1)  # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
+                add_config(64, 256, 4, 1)  # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19
+                add_config(512, 64, 8, 1)  # wri0: 58us: triton_poi_fused_clone_53
     if len(size_hints) == 3:
         if disable_pointwise_autotuning(inductor_meta):
             configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
@@ -1952,13 +2072,17 @@ def _reduction_configs(
     # Convert reductions to 1D, to simplify heuristics.
     rnumel = get_total_reduction_numel(size_hints)
 
+    # Is max autotune enabled?
+    max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get(
+        "max_autotune_pointwise"
+    )
+
     register_intensive = False
     MAX_R0_BLOCK = 2048
-    if (
-        size_hints["x"] >= 1024
-        and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0)
-        >= 10
-    ):
+    loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get(
+        "num_reduction", 0
+    )
+    if size_hints["x"] >= 1024 and loads_and_red >= 10:
         # A heuristics to reduce R0_BLOCK if a kernel potentially need many registers.
         # Consider load and reduction since load need move data into registers and
         # reduction needs an accumulator.
@@ -1974,8 +2098,90 @@
             MAX_R0_BLOCK = 1024
             register_intensive = True
 
-    contiguous_config = triton_config_reduction(
-        size_hints,
+    def make_config(
+        x,
+        r,
+        num_warps=None,
+        num_stages=1,
+        register_intensive=False,
+        dynamic_scale_rblock=True,
+        waves_per_eu=None,
+    ):
+        # For the 3D case with tiling scores, create an adapted version
+        if "y" in size_hints:
+            assert "tiling_scores" in inductor_meta
+            return adapt_config_for_tiling(
+                size_hints,
+                inductor_meta["tiling_scores"],
+                x,
+                r,
+                num_warps=num_warps,
+                num_stages=num_stages,
+                register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
+            )
+        else:
+            # For all other cases, use the original function
+            return triton_config_reduction(
+                size_hints,
+                x,
+                r,
+                num_warps=num_warps,
+                num_stages=num_stages,
+                register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
+                dynamic_scale_rblock=dynamic_scale_rblock,
+            )
+
+    def outer_config_opt():
+        # Default to 64 for vectorized loads
+        max_x_block, x_block = 256, 64
+        load_factor = inductor_meta.get("num_load", 0)
+        x = size_hints["x"]
+        num_warps = None
+
+        # Try to use all SMs with small x
+        if x <= 1024:
+            x_block = max(min(x // 128, 8), 2)
+            outer_r_block = min(rnumel, 64)
+        # Lower bound x = 1024; x_block = 16 keeps the grid around the # of SMs
+        elif x // 4096 <= 8:
+            x_block = 16
+            outer_r_block = 512 // x_block
+        elif num_dynamic > 1:
+            # Lots of compute with multiple dynamic shapes per loop iteration;
+            # a larger RBLOCK minimizes the number of loop iterations
+            outer_r_block = max(min((rnumel // 64), 64), 8)
+        elif num_dynamic == 1:
+            # Dynamic shapes introduce a lot of register pressure for indexing
+            outer_r_block = (
+                1
+                if load_factor >= 3
+                else min(next_power_of_2(max(rnumel, 128) // 128), 8)
+            )
+        else:
+            x_block = max(min(max_x_block, next_power_of_2(x // 4096)), x_block)
+            if load_factor < 4 or rnumel <= 128:
+                outer_r_block = 512 // x_block
+            else:
+                # Heavier reductions carry much more overhead per loop iteration;
+                # minimize that overhead by enlarging the r block
+                if rnumel >= 2048:
+                    outer_r_block = 64
+                else:
+                    outer_r_block = 32
+                x_block = min(x_block, 32)
+                num_warps = 4
+
+        # Pass register_intensive through; the heuristic above already tries
+        # to maximize the tile sizes
+        return make_config(
+            x_block,
+            outer_r_block,
+            num_warps=num_warps,
+            register_intensive=register_intensive,
+        )
+
+    contiguous_config = make_config(
         1,
         rnumel if 256 <= rnumel < MAX_R0_BLOCK else MAX_R0_BLOCK,
         register_intensive=register_intensive,
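Note: for intuition on `outer_config_opt` above, the small-`x` branch sizes `XBLOCK` to roughly `x / 128` programs (clamped to [2, 8]) and caps the reduction block at 64. A numbers-only sketch of just that branch (`small_x_blocks` is a hypothetical name):

```python
def small_x_blocks(x: int, rnumel: int) -> tuple[int, int]:
    # Mirrors the `x <= 1024` branch: clamp XBLOCK to [2, 8], RBLOCK to <= 64.
    x_block = max(min(x // 128, 8), 2)
    outer_r_block = min(rnumel, 64)
    return x_block, outer_r_block

print(small_x_blocks(1024, 512))  # (8, 64)
print(small_x_blocks(256, 32))    # (2, 32)
```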
@@ -1989,27 +2195,146 @@
         min(rnumel, MAX_R0_BLOCK),
         register_intensive=register_intensive,
     )
-    if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"):
-        pass  # skip all these cases
+
+    outer_config = make_config(64, 8, register_intensive=register_intensive)
+    # TODO (paulzhan): Test the heuristic on AMD and run internal testing
+    # for correctness
+    if not torch.version.hip:
+        outer_config = outer_config_opt()
+
+    configs = []
+
+    if inductor_meta.get("add_persistent_rblock") and loads_and_red <= 8:
+        xnumel = max(4096 // rnumel, 1)
+        c = make_config(
+            xnumel,
+            rnumel,
+            register_intensive=register_intensive,
+            dynamic_scale_rblock=False,
+        )
+        configs.append(c)
+
+    result_configs = []
+
+    # For 3D tiling, default to more autotuning initially
+    if "y" in size_hints:
+        pass
+    elif max_autotune_enabled:
+        pass  # skip all these cases
     elif reduction_hint == ReductionHint.INNER:
-        return [contiguous_config]
+        return configs + [contiguous_config]
     elif reduction_hint == ReductionHint.OUTER:
-        return [outer_config]
+        return configs + [outer_config]
     elif reduction_hint == ReductionHint.OUTER_TINY:
-        return [tiny_config]
-    if disable_pointwise_autotuning(inductor_meta):
-        return [triton_config_reduction(size_hints, 32, 128)]
-    return [
-        contiguous_config,
-        outer_config,
-        tiny_config,
-        triton_config_reduction(size_hints, 64, 64),
-        triton_config_reduction(size_hints, 8, 512),
-        # halve the XBLOCK/Rn_BLOCK compared to outer_config
-        # TODO: this may only be beneficial when each iteration of the reduction
-        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
-        triton_config_reduction(size_hints, 64, 4, num_warps=8),
-    ]
+        return configs + [tiny_config]
+
+    # We only reach this point when either:
+    # - max_autotune_enabled is True, or
+    # - max_autotune_enabled is False and reduction_hint is none of the above
+    result_configs = configs + [
+        contiguous_config,
+        outer_config,
+        tiny_config,
+        make_config(64, 64),
+        make_config(8, 512),
+        # halve the XBLOCK/Rn_BLOCK compared to outer_config
+        # TODO: this may only be beneficial when each iteration of the reduction
+        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
+        make_config(64, 4, num_warps=8),
+    ]
+
+    if torch.version.hip:
+        result_configs.extend(
+            [
+                make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
+                make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
+                # wrt2: 3x improvement:
+                # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
+                make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1),
+                # wrt2: 2x improvement on the same kernel (v2, v3 and v4)
+                make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1),
+                # wrt3: 380us: triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
+                make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1),
+                # wrt3: 170us: triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
+                make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1),
+                # wrt3: 170us: triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
+                make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1),
+            ]
+        )
+
+    return result_configs
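Note: the control flow above reduces to: 3D-tiled kernels and max-autotune runs keep the full candidate list, while a reduction hint short-circuits to a single config (plus any persistent-RBLOCK configs). A condensed, standalone sketch of that dispatch (not the real signatures; `pick` and `Hint` are hypothetical names):

```python
from enum import Enum, auto

class Hint(Enum):
    INNER = auto()
    OUTER = auto()
    OUTER_TINY = auto()
    DEFAULT = auto()

def pick(hint: Hint, max_autotune: bool, tiled_3d: bool) -> str:
    # 3D tiling and max-autotune both defer to the full sweep.
    if tiled_3d or max_autotune:
        return "full sweep"
    single = {Hint.INNER: "contiguous", Hint.OUTER: "outer", Hint.OUTER_TINY: "tiny"}
    return single.get(hint, "full sweep")

print(pick(Hint.INNER, max_autotune=False, tiled_3d=False))  # contiguous
print(pick(Hint.INNER, max_autotune=True, tiled_3d=False))   # full sweep
```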
+
+
+def match_target_block_product(
+    size_hints, tiling_scores, target_block_product, min_block_size=1
+):
+    """
+    Distribute block sizes across dimensions according to tiling scores,
+    aiming to match a target product of block sizes.
+    """
+    total_score = sum(tiling_scores.values())
+    if total_score == 0:
+        # just assume even score with no minimum block size
+        min_block_size = 1
+        tiling_scores = dict.fromkeys(tiling_scores.keys(), target_block_product)
+
+    # First, give each coalescing dimension at least min_block_size
+    block_sizes = {}
+    relative_scores = {}
+    curr_block_product = 1
+
+    for dim, score in tiling_scores.items():
+        if score == 0:
+            block_sizes[dim] = 1
+            continue
+
+        block_sizes[dim] = min_block_size
+        curr_block_product *= min_block_size
+        relative_scores[dim] = score / total_score
+
+    # Scale up dimensions by their relative scores until we reach the target
+    while curr_block_product < target_block_product and len(relative_scores):
+        dim, score = max(relative_scores.items(), key=lambda item: item[1])
+
+        # Check if we've hit the max for this dimension
+        if (
+            block_sizes[dim] >= TRITON_MAX_BLOCK[dim.capitalize()]
+            or block_sizes[dim] >= size_hints[dim]
+        ):
+            del relative_scores[dim]
+            continue
+
+        block_sizes[dim] *= 2
+        relative_scores[dim] /= 2
+        curr_block_product *= 2
+
+    return block_sizes
+
+
+def adapt_config_for_tiling(
+    size_hints,
+    tiling_scores,
+    original_x,
+    original_r,
+    num_warps=None,
+    num_stages=1,
+    register_intensive=False,
+    persistent_reduction=False,
+    waves_per_eu=None,
+) -> Config:
+    """
+    Create an adapted configuration based on tiling scores,
+    redistributing the same total block size (x * r) according to tiling scores.
+    """
+    assert all(s in tiling_scores for s in size_hints)
+    target_block_product = original_x * original_r
+    block_sizes = match_target_block_product(
+        size_hints, tiling_scores, target_block_product
+    )
+
+    return triton_config_tiled_reduction(
+        size_hints,
+        block_sizes["x"],
+        block_sizes["y"],
+        block_sizes["r0_"],
+        num_stages=num_stages,
+        register_intensive=register_intensive,
+        waves_per_eu=waves_per_eu,
+    )
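Note: `match_target_block_product` greedily doubles the dimension with the highest remaining relative score, halving that score each time, until the block product reaches the target or every dimension is capped. A simplified, self-contained restatement (drops the `TRITON_MAX_BLOCK`/`size_hints` caps; `distribute` is a hypothetical name) with a hand-checkable example:

```python
def distribute(scores: dict[str, float], target: int) -> dict[str, int]:
    total = sum(scores.values())
    blocks = dict.fromkeys(scores, 1)
    rel = {d: s / total for d, s in scores.items() if s}
    product = 1
    while product < target and rel:
        d = max(rel, key=rel.get)  # highest remaining relative score wins
        blocks[d] *= 2
        rel[d] /= 2
        product *= 2
    return blocks

# With a 3:1 score ratio, x gets twice as many doublings as r0_:
print(distribute({"x": 3.0, "r0_": 1.0}, 64))  # {'x': 16, 'r0_': 4}
```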
+ """ + assert all(s in tiling_scores for s in size_hints) + target_block_product = original_x * original_r + block_sizes = match_target_block_product( + size_hints, tiling_scores, target_block_product + ) + + return triton_config_tiled_reduction( + size_hints, + block_sizes["x"], + block_sizes["y"], + block_sizes["r0_"], + num_stages=num_stages, + register_intensive=register_intensive, + waves_per_eu=waves_per_eu, + ) def reduction( @@ -2097,33 +2422,42 @@ def _persistent_reduction_configs( or inductor_meta.get("max_autotune_pointwise") ) + if torch.version.hip: + xblock_vals = [1, 4, 8, 16, 32, 64, 128, 256] + else: + xblock_vals = [1, 8, 32, 128] + configs = [ triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) - for xblock in (1, 8, 32, 128) + for xblock in xblock_vals if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096)) ] + tiny_configs = [ + triton_config_reduction( + size_hints, + 2 * (256 // rnumel) if rnumel <= 256 else 1, + rnumel, + ) + ] + + # defer to more autotuning, initially + if "y" in size_hints: + pass # TODO(jansel): we should be able to improve these heuristics - if not max_autotune_enabled: # Don't filter if tuning enabled + elif not max_autotune_enabled: # Don't filter if tuning enabled if reduction_hint == ReductionHint.INNER and rnumel >= 256: configs = configs[:1] elif reduction_hint == ReductionHint.OUTER: configs = configs[-1:] - - if reduction_hint == ReductionHint.OUTER_TINY: - tiny_configs = [ - triton_config_reduction( - size_hints, - 2 * (256 // rnumel) if rnumel <= 256 else 1, - rnumel, - ) - ] - if max_autotune_enabled: - for tconfig in tiny_configs: - if tconfig not in configs: - configs.append(tconfig) - else: - configs = tiny_configs + elif reduction_hint == ReductionHint.OUTER_TINY: + configs = tiny_configs + else: + if torch.version.hip: + # If autotune is enabled append tiny configs + for conf in tiny_configs: + if conf not in configs: + configs.append(conf) for c in configs: # we don't need Rn_BLOCK for persistent reduction @@ -2275,9 +2609,10 @@ def foreach(triton_meta, filename=None, inductor_meta=None): Compile a triton foreach kernel """ configs = [] + + # Naive autotuning path for num_warps if disable_pointwise_autotuning(inductor_meta) and not ( - inductor_meta.get("max_autotune") or - inductor_meta.get("max_autotune_pointwise") + inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") ): configs.append(triton.Config({}, num_stages=1, num_warps=8)) else: @@ -2293,6 +2628,7 @@ def foreach(triton_meta, filename=None, inductor_meta=None): filename=filename, ) + @dataclasses.dataclass class GridExpr: """Generate code for grid size expressions in launcher"""