Commit 15ba5b5

okakarpa and jataylo authored
[AUTOGENERATED] [release/2.6] [release/2.7] [SWDEV-543214] Reland #2416 Fix warps runtime part 2 (#2451)
Cherry-pick of #2442
Co-authored-by: Jack Taylor <[email protected]>
1 parent fb0db08 commit 15ba5b5

1 file changed: 10 additions, 3 deletions

torch/_inductor/runtime/coordinate_descent_tuner.py

Lines changed: 10 additions & 3 deletions
@@ -3,6 +3,7 @@
 import itertools
 import logging
 from typing import Callable, Optional
+from functools import lru_cache
 
 from .hints import TRITON_MAX_BLOCK
 from .runtime_utils import red_text, triton_config_to_hashable
@@ -59,10 +60,16 @@ def get_config_max(self, prefix: str) -> int:
         size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None
         return min(max_block, size_hint) if size_hint is not None else max_block
 
+    @lru_cache(maxsize=1)
     def get_warpsmax(self):
-        # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
-        # number of warps.
-        return 1024 // 32
+        # CUDA/ROCm has a maximum of 1024 threads per block
+        from torch.cuda import current_device, get_device_properties, is_available
+
+        warp_size = (
+            get_device_properties(current_device()).warp_size if is_available() else 32
+        )
+
+        return 1024 // warp_size
 
     def cache_benchmark_result(self, config, timing):
         self.cached_benchmark_results[triton_config_to_hashable(config)] = timing
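For context, here is a minimal standalone sketch (not part of the commit) of the same warp-cap computation. It queries the device's warp size via torch.cuda.get_device_properties and falls back to 32 when no GPU is available, so on hardware with a 64-wide wavefront (typical for AMD CDNA GPUs) the cap becomes 1024 // 64 = 16 instead of the previously hard-coded 32. The function name max_warps is illustrative only.

import torch

def max_warps() -> int:
    # CUDA/ROCm allow at most 1024 threads per block.
    if torch.cuda.is_available():
        # warp_size is 32 on NVIDIA GPUs and typically 64 on AMD CDNA GPUs.
        warp_size = torch.cuda.get_device_properties(torch.cuda.current_device()).warp_size
    else:
        # Fall back to the NVIDIA default when no GPU is present.
        warp_size = 32
    return 1024 // warp_size

print(max_warps())  # 32 on NVIDIA, 16 on wavefront-64 AMD hardware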
