Commit bb4009a

jataylo and naromero77amd authored and committed
[Inductor] Naive foreach autotune support (pytorch#162053)
Initial autotuning support for foreach kernels, giving a roughly 4x improvement for some kernels in an internal workload. More improvements can surely be made here in the future.

Removes num_warps from the foreach decorator definition to enable autotune support in the generated wrapper code.

Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |

After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |

The default stays at num_warps=8 because of https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton_combo_kernel.py#L374

Pull Request resolved: pytorch#162053
Approved by: https://github.com/mlazos, https://github.com/naromero77amd, https://github.com/jeffdaily

Co-authored-by: Nichols A. Romero <[email protected]>
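As a quick check of the "4x" figure, the per-kernel speedups can be recomputed from the Before/After numbers above. This is only a sketch: it assumes the first timing column is the one being compared (the commit message does not label the columns), and the helper name `speedups` is made up for illustration.

```python
# Timings in ms, copied from the first timing column of the commit message.
before_ms = {
    "triton_for_fused_18": 4.986,
    "triton_for_fused_6": 0.098,
    "triton_for_fused_7": 0.036,
}
after_ms = {
    "triton_for_fused_18": 1.273,
    "triton_for_fused_6": 0.044,
    "triton_for_fused_7": 0.024,
}


def speedups(before, after):
    """Return the before/after time ratio for every kernel in `before`."""
    return {name: before[name] / after[name] for name in before}


for name, ratio in speedups(before_ms, after_ms).items():
    print(f"{name}: {ratio:.2f}x")
# triton_for_fused_18: 3.92x
# triton_for_fused_6: 2.23x
# triton_for_fused_7: 1.50x
```

Under that assumption, only the largest kernel gets close to 4x; the two smaller kernels improve by about 2.2x and 1.5x.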
1 parent 9e9e8fa commit bb4009a

2 files changed: +14 -3 lines changed

torch/_inductor/codegen/triton_combo_kernel.py

Lines changed: 1 addition & 1 deletion
@@ -627,7 +627,7 @@ def jit_line(
         if heuristics == "foreach":
             heuristics_line = f"""
                 @triton_heuristics.foreach(
-                    num_warps={self.num_warps},
+                    filename=__file__,
                     triton_meta={triton_meta!r},
                     inductor_meta={inductor_meta!r},
                 )
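To make the effect on the generated wrapper code concrete, here is a minimal sketch that renders the decorator text the way the patched f-string does. The function name `render_foreach_decorator` and the example `triton_meta` / `inductor_meta` values are hypothetical; the real dictionaries produced by Inductor contain different keys.

```python
import textwrap


def render_foreach_decorator(triton_meta: dict, inductor_meta: dict) -> str:
    """Render the decorator text; `__file__` stays literal in the output."""
    return textwrap.dedent(
        f"""
        @triton_heuristics.foreach(
            filename=__file__,
            triton_meta={triton_meta!r},
            inductor_meta={inductor_meta!r},
        )
        """
    ).strip()


# Hypothetical metadata, for illustration only.
rendered = render_foreach_decorator({"device": 0}, {"max_autotune": True})
print(rendered)
```

The rendered text no longer carries a fixed num_warps argument, so the choice of num_warps is left to the heuristic at runtime.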

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 13 additions & 2 deletions
@@ -3621,13 +3621,24 @@ def user_autotune(
     )


-def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
+def foreach(triton_meta, filename=None, inductor_meta=None):
     """
     Compile a triton foreach kernel
     """
+    configs = []
+
+    # Naive autotuning path for num_warps
+    if not (
+        inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
+    ):
+        configs.append(triton.Config({}, num_stages=1, num_warps=8))
+    else:
+        for warps in [1, 2, 4, 8]:
+            configs.append(triton.Config({}, num_stages=1, num_warps=warps))
+
     return cached_autotune(
         None,
-        [triton.Config({}, num_stages=1, num_warps=num_warps)],
+        configs,
         triton_meta=triton_meta,
         inductor_meta=inductor_meta,
         heuristic_type=HeuristicType.TEMPLATE,
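The config-selection branch above can be sketched in isolation. This is not the PyTorch implementation: `Config` is a hypothetical stand-in for `triton.Config` so the example runs without Triton installed, `foreach_configs` is an invented name, and treating a missing `inductor_meta` as an empty dict is an assumption of this sketch (the diff calls `.get` on it directly).

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class Config:
    """Hypothetical stand-in for triton.Config (not the real class)."""

    num_warps: int
    num_stages: int = 1


def foreach_configs(inductor_meta: Optional[dict] = None) -> List[Config]:
    """Return the candidate configs the naive autotuning path would try."""
    # Assumption of this sketch: a missing inductor_meta behaves like {}.
    meta = inductor_meta or {}
    if meta.get("max_autotune") or meta.get("max_autotune_pointwise"):
        # Autotuning requested: sweep num_warps over the four candidates.
        return [Config(num_warps=w) for w in (1, 2, 4, 8)]
    # Otherwise keep the single default candidate with num_warps=8.
    return [Config(num_warps=8)]


print([c.num_warps for c in foreach_configs({"max_autotune": True})])
print([c.num_warps for c in foreach_configs({})])
```

With either max-autotune flag set, four candidates are benchmarked; without them there is a single candidate, so no tuning takes place and behavior matches the previous num_warps=8 default.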
