Skip to content

Commit 74d7455

Browse files
committed
Fix bad merge of triton_heuristics.py
Even though I had picked all upstream changes while resolving merge conflicts, parts of the file that had no conflicts still kept local changes. As a result, this file is broken with missing symbols. I am simply copying the upstream file into this branch now.
1 parent c56fe7d commit 74d7455

File tree

1 file changed

+6
-24
lines changed

1 file changed

+6
-24
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 6 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -530,7 +530,7 @@ def _dynamic_scale_rblock(self):
530530
# = regs_per_multiprocessor / (nreg * 32 * num_warps)
531531
# < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps)
532532
# = max_threads_per_multi_processor / (32 * num_warps)
533-
# Using a tigher upper bound can reveal more optimization opportunities.
533+
# Using a tighter upper bound can reveal more optimization opportunities.
534534
max_blocks_per_sm = max(
535535
device_prop.regs_per_multiprocessor // nreg_per_block, 1
536536
)
@@ -3241,16 +3241,7 @@ def _persistent_reduction_configs(
32413241
"num_store", 0
32423242
)
32433243

3244-
max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
3245-
inductor_meta.get("max_autotune")
3246-
or inductor_meta.get("max_autotune_pointwise")
3247-
)
3248-
3249-
configs = [
3250-
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
3251-
for xblock in (1, 8, 32, 128)
3252-
if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096))
3253-
]
3244+
MAX_PERSISTENT_BLOCK_NUMEL = 4096
32543245

32553246
if triton_meta.get("native_matmul"):
32563247
if len(size_hints) == 3:
@@ -3286,7 +3277,7 @@ def _persistent_reduction_configs(
32863277
)
32873278
for xblock in xblock_vals
32883279
if xblock == 1
3289-
or (rnumel * xblock <= 4096 and xblock <= xnumel)
3280+
or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
32903281
]
32913282
else:
32923283
configs = []
@@ -3559,29 +3550,20 @@ def user_autotune(
35593550
)
35603551

35613552

3562-
def foreach(triton_meta, filename=None, inductor_meta=None):
3553+
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
35633554
"""
35643555
Compile a triton foreach kernel
35653556
"""
3566-
configs = []
3567-
if disable_pointwise_autotuning(inductor_meta) and not (
3568-
inductor_meta.get("max_autotune") or
3569-
inductor_meta.get("max_autotune_pointwise")
3570-
):
3571-
configs.append(triton.Config({}, num_stages=1, num_warps=8))
3572-
else:
3573-
for warps in [1, 2, 4, 8]:
3574-
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
3575-
35763557
return cached_autotune(
35773558
None,
3578-
configs,
3559+
[triton.Config({}, num_stages=1, num_warps=num_warps)],
35793560
triton_meta=triton_meta,
35803561
inductor_meta=inductor_meta,
35813562
heuristic_type=HeuristicType.TEMPLATE,
35823563
filename=filename,
35833564
)
35843565

3566+
35853567
@dataclasses.dataclass
35863568
class GridExpr:
35873569
"""Generate code for grid size expressions in launcher"""

0 commit comments

Comments (0)