
Commit 5e67be1

okakarpa authored and jataylo committed
[AUTOGENERATED] [release/2.8] [release/2.7] [SWDEV-543214] Reland #2416 Fix warps runtime part 2 (#2455)
Cherry-pick of #2442

Co-authored-by: Jack Taylor <[email protected]>
(cherry picked from commit 77a6760)
1 parent 4142eef commit 5e67be1


torch/_inductor/runtime/coordinate_descent_tuner.py

Lines changed: 8 additions & 3 deletions
@@ -63,9 +63,14 @@ def get_config_max(self, prefix: str) -> int:
 
     @lru_cache(maxsize=1)
     def get_warpsmax(self):
-        # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
-        # number of warps.
-        return 1024 // 32
+        # CUDA/ROCm has a maximum of 1024 threads per block
+        from torch.cuda import current_device, get_device_properties, is_available
+
+        warp_size = (
+            get_device_properties(current_device()).warp_size if is_available() else 32
+        )
+
+        return 1024 // warp_size
 
     def cache_benchmark_result(self, config, timing):
         self.cached_benchmark_results[triton_config_to_hashable(config)] = timing
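
For context: on ROCm the warp (wavefront) size is 64 rather than CUDA's 32, so the old hard-coded 1024 // 32 overstated the maximum warp count on AMD hardware; the relanded patch queries the actual device property at runtime instead. A minimal standalone sketch of the same logic (the max_warps wrapper name is illustrative, not part of the patch):

import torch

def max_warps() -> int:
    # Both CUDA and ROCm cap a thread block at 1024 threads.
    if torch.cuda.is_available():
        # warp_size is 32 on NVIDIA GPUs and 64 on most AMD GPUs.
        device = torch.cuda.current_device()
        warp_size = torch.cuda.get_device_properties(device).warp_size
    else:
        # No GPU present: fall back to the old hard-coded assumption.
        warp_size = 32
    return 1024 // warp_size

print(max_warps())  # 32 on NVIDIA, 16 on wavefront-64 AMD GPUs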

0 commit comments