Skip to content

Commit ba67e72

Browse files
Merge commit 'fa0c2bdfa4b907700624d7dd6ffbba2f9f8e10e4'
2 parents 1bd6d43 + fa0c2bd commit ba67e72

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

python/triton/_internal_testing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,13 @@ def to_numpy(x):
142142

143143

144144
def supports_tma(byval_only=False):
145+
if not is_cuda():
146+
return False
145147
_, cuda_version = _path_to_binary("ptxas")
146148
min_cuda_version = (12, 0) if byval_only else (12, 3)
147149
cuda_version_tuple = tuple(map(int, cuda_version.split(".")))
148150
assert len(cuda_version_tuple) == 2, cuda_version_tuple
149-
return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version
151+
return torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version
150152

151153

152154
def tma_skip_msg(byval_only=False):

0 commit comments

Comments
 (0)