Skip to content

Commit e1632fc

Browse files
committed
applied pr2377
1 parent 7bcbafe commit e1632fc

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

torch/_inductor/codegen/triton_combo_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ def jit_line(
614614
if heuristics == "foreach":
615615
heuristics_line = f"""
616616
@triton_heuristics.foreach(
617-
num_warps={self.num_warps},
617+
filename=__file__,
618618
triton_meta={triton_meta!r},
619619
inductor_meta={inductor_meta!r},
620620
)

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2922,13 +2922,23 @@ def user_autotune(
29222922
)
29232923

29242924

2925-
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
2925+
def foreach(triton_meta, filename=None, inductor_meta=None):
29262926
"""
29272927
Compile a triton foreach kernel
29282928
"""
2929+
configs = []
2930+
if disable_pointwise_autotuning(inductor_meta) and not (
2931+
inductor_meta.get("max_autotune") or
2932+
inductor_meta.get("max_autotune_pointwise")
2933+
):
2934+
configs.append(triton.Config({}, num_stages=1, num_warps=8))
2935+
else:
2936+
for warps in [1, 2, 4, 8]:
2937+
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
2938+
29292939
return cached_autotune(
29302940
None,
2931-
[triton.Config({}, num_stages=1, num_warps=num_warps)],
2941+
configs,
29322942
triton_meta=triton_meta,
29332943
inductor_meta=inductor_meta,
29342944
heuristic_type=HeuristicType.TEMPLATE,

0 commit comments

Comments
 (0)