Skip to content

Commit edcb143

Browse files
jataylo authored and jithunnair-amd committed
[SWDEV-539076] Initial naive foreach autotune support (#2377)
Adds initial autotuning support for foreach kernels, as required for https://ontrack-internal.amd.com/browse/SWDEV-539076. This gives roughly a 4x improvement for some kernels.

Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |

After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |

(cherry picked from commit f07b7f7)
(cherry picked from commit ed0d0a7)
1 parent b324b36 commit edcb143

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

torch/_inductor/codegen/triton_combo_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ def jit_line(
614614
if heuristics == "foreach":
615615
heuristics_line = f"""
616616
@triton_heuristics.foreach(
617-
num_warps={self.num_warps},
617+
filename=__file__,
618618
triton_meta={triton_meta!r},
619619
inductor_meta={inductor_meta!r},
620620
)

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3102,20 +3102,29 @@ def user_autotune(
31023102
)
31033103

31043104

def foreach(triton_meta, filename=None, inductor_meta=None):
    """
    Compile a triton foreach kernel

    Builds the list of candidate ``triton.Config`` objects and hands them to
    ``cached_autotune``. A single 8-warp config is used when pointwise
    autotuning is disabled and neither ``max_autotune`` nor
    ``max_autotune_pointwise`` is set in ``inductor_meta``; otherwise one
    config per warp count in ``[1, 2, 4, 8]`` is offered.

    Args:
        triton_meta: forwarded unchanged to ``cached_autotune``.
        filename: forwarded unchanged to ``cached_autotune``.
        inductor_meta: optional mapping of inductor settings. ``None`` is
            treated as an empty mapping when the autotune flags are read
            here; the caller's original value is still what gets forwarded
            to ``cached_autotune``.

    Returns:
        The result of ``cached_autotune`` for the selected configs.
    """
    # The signature allows inductor_meta=None, so never call .get() on it
    # directly: read the flags through an empty mapping instead.
    # NOTE(review): assumes disable_pointwise_autotuning accepts an empty
    # mapping -- confirm against its definition.
    meta = {} if inductor_meta is None else inductor_meta

    if disable_pointwise_autotuning(meta) and not (
        meta.get("max_autotune") or meta.get("max_autotune_pointwise")
    ):
        # Autotuning is off: keep a single fixed configuration.
        warp_choices = [8]
    else:
        # Let the autotuner choose among several warp counts.
        warp_choices = [1, 2, 4, 8]

    configs = [
        triton.Config({}, num_stages=1, num_warps=warps) for warps in warp_choices
    ]

    return cached_autotune(
        None,
        configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,
        filename=filename,
    )
31173127

3118-
31193128
@dataclasses.dataclass
31203129
class GridExpr:
31213130
"""Generate code for grid size expressions in launcher"""

0 commit comments

Comments
 (0)