Skip to content

Commit 9aae77b

Browse files
committed
remove _cuda_capability
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
1 parent e115c2b commit 9aae77b

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

thunder/tests/test_ops.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -423,10 +423,6 @@ def fn(a):
423423
assert_close(b, b_ref)
424424

425425

426-
def _cuda_capability(device: torch.device) -> tuple[int, int]:
427-
return torch.cuda.get_device_capability(device)
428-
429-
430426
def _cuda_version_tuple() -> tuple[int, int] | None:
431427
if torch.version.cuda is None:
432428
return None
@@ -450,7 +446,7 @@ def wrapper(*args, **kwargs):
450446

451447

452448
def _ensure_fp8_tensorwise(device: torch.device) -> None:
453-
major, minor = _cuda_capability(device)
449+
major, minor = torch.cuda.get_device_capability(device)
454450
if (major, minor) < (8, 9):
455451
pytest.skip("scaled_mm tensor-wise support requires SM89 or newer")
456452

@@ -467,7 +463,7 @@ def wrapper(*args, **kwargs):
467463

468464
def _require_fp8_rowwise(device: torch.device) -> None:
469465
_ensure_fp8_tensorwise(device)
470-
major, minor = _cuda_capability(device)
466+
major, minor = torch.cuda.get_device_capability(device)
471467
if (major, minor) < (9, 0):
472468
pytest.skip("row-wise scaled_mm requires SM90 or newer")
473469
cuda_version = _cuda_version_tuple()

0 commit comments

Comments
 (0)