
Commit e5e987e

jataylo authored and jithunnair-amd committed
[release/2.7] [SWDEV-543214] Reland #2416 Fix warps runtime (#2421)
Relands #2416 with caching fix.

Upstream equivalent: pytorch#159146

Co-authored-by: Jithun Nair <[email protected]>
(cherry picked from commit f0aebdc)
(cherry picked from commit 9c429dd)
1 parent 6ac6cac commit e5e987e

1 file changed (+2 -0 lines)

torch/_inductor/runtime/coordinate_descent_tuner.py

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@
 import itertools
 import logging
 from typing import Callable, Optional, TYPE_CHECKING
+from functools import lru_cache
 
 from .hints import TRITON_MAX_BLOCK
 from .runtime_utils import red_text, triton_config_to_hashable
@@ -60,6 +61,7 @@ def get_config_max(self, prefix: str) -> int:
         size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None
         return min(max_block, size_hint) if size_hint is not None else max_block
 
+    @lru_cache(maxsize=1)
     def get_warpsmax(self):
         # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
         # number of warps.
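
For context, a minimal sketch of what the change does. The decorator, the method name, and the comment come from the diff; the class name CoordescTuner and the return value below are illustrative assumptions, not copied from the repository. functools.lru_cache(maxsize=1) memoizes get_warpsmax, so the value is computed on the first call and served from the cache afterwards.

# Minimal sketch, not the actual PyTorch implementation: the class name and
# return value are assumptions; the decorator and method name match the diff.
from functools import lru_cache


class CoordescTuner:  # assumed stand-in for the tuner class in this file
    @lru_cache(maxsize=1)
    def get_warpsmax(self):
        # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
        # number of warps.
        return 1024 // 32  # assumed body; the real method may compute this differently


tuner = CoordescTuner()
print(tuner.get_warpsmax())                          # 32, computed on the first call
print(tuner.get_warpsmax())                          # 32, served from the cache
print(CoordescTuner.get_warpsmax.cache_info().hits)  # 1 after the second call

With maxsize=1 the cache holds a single entry keyed on the call arguments (just self here), which is enough because the result does not depend on per-call state; one caveat of lru_cache on a method is that the cached entry keeps a reference to the instance for as long as it stays in the cache.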
