@@ -423,10 +423,6 @@ def fn(a):
     assert_close(b, b_ref)


-def _cuda_capability(device: torch.device) -> tuple[int, int]:
-    return torch.cuda.get_device_capability(device)
-
-
 def _cuda_version_tuple() -> tuple[int, int] | None:
     if torch.version.cuda is None:
         return None
@@ -450,7 +446,7 @@ def wrapper(*args, **kwargs):


 def _ensure_fp8_tensorwise(device: torch.device) -> None:
-    major, minor = _cuda_capability(device)
+    major, minor = torch.cuda.get_device_capability(device)
     if (major, minor) < (8, 9):
         pytest.skip("scaled_mm tensor-wise support requires SM89 or newer")

@@ -467,7 +463,7 @@ def wrapper(*args, **kwargs):


 def _require_fp8_rowwise(device: torch.device) -> None:
     _ensure_fp8_tensorwise(device)
-    major, minor = _cuda_capability(device)
+    major, minor = torch.cuda.get_device_capability(device)
     if (major, minor) < (9, 0):
         pytest.skip("row-wise scaled_mm requires SM90 or newer")
     cuda_version = _cuda_version_tuple()
0 commit comments