diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py
index 5b8a4c2f80bed..2746a082b3e85 100644
--- a/src/lightning/fabric/accelerators/cuda.py
+++ b/src/lightning/fabric/accelerators/cuda.py
@@ -20,6 +20,7 @@
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.utilities.rank_zero import rank_zero_info
+from lightning.fabric.utilities.throughput import get_float32_matmul_precision_compat
 
 
 class CUDAAccelerator(Accelerator):
@@ -162,7 +163,7 @@ def _check_cuda_matmul_precision(device: torch.device) -> None:
         return
     # check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and
     # `set_float32_matmul_precision`
-    if torch.get_float32_matmul_precision() == "highest":  # default
+    if get_float32_matmul_precision_compat() == "highest":  # default
         rank_zero_info(
             f"You are using a CUDA device ({torch.cuda.get_device_name(device)!r}) that has Tensor Cores. To properly"
             " utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off"
diff --git a/src/lightning/fabric/utilities/throughput.py b/src/lightning/fabric/utilities/throughput.py
index 6bc329fa1c3be..a162f08e500d3 100644
--- a/src/lightning/fabric/utilities/throughput.py
+++ b/src/lightning/fabric/utilities/throughput.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820/composer/callbacks/speed_monitor.py
+import warnings
 from collections import deque
-from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar, Union
 
 import torch
 from typing_extensions import override
@@ -27,6 +28,34 @@
 _THROUGHPUT_METRICS = dict[str, Union[int, float]]
 
 
+def get_float32_matmul_precision_compat() -> Literal["highest", "high", "medium"]:
+    """Get the current float32 matmul precision using PyTorch 2.9+ compatible API."""
+    if not torch.cuda.is_available():
+        try:
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                return torch.get_float32_matmul_precision()
+        except Exception:
+            return "highest"
+
+    # Check if new API is available (PyTorch 2.9+)
+    if hasattr(torch.backends.cuda.matmul, "fp32_precision"):
+        precision_value = torch.backends.cuda.matmul.fp32_precision
+
+        if precision_value == "ieee":
+            return "highest"
+        if precision_value == "tf32":
+            return "medium"
+        return "highest"
+    # Fallback to old API for older PyTorch versions
+    try:
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            return torch.get_float32_matmul_precision()
+    except Exception:
+        return "highest"
+
+
 # The API design of this class follows `torchmetrics.Metric` but it doesn't need to be an actual Metric because there's
 # no need for synchronization or reduction as it doesn't use Tensors at all.
 class Throughput:
@@ -607,7 +636,7 @@ def get_available_flops(device: torch.device, dtype: Union[torch.dtype, str]) ->
         if dtype is torch.float32:
            from lightning.fabric.accelerators.cuda import _is_ampere_or_later
 
-            if _is_ampere_or_later() and torch.get_float32_matmul_precision() != "highest":
+            if _is_ampere_or_later() and get_float32_matmul_precision_compat() != "highest":
                 dtype = "tfloat32"
         if dtype not in dtype_to_flops:
             # for example, T4 doesn't support bfloat16. it might also be that we are missing this dtype from the list
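
For reference, here is a minimal sketch of the mapping that `get_float32_matmul_precision_compat` assumes between the two PyTorch APIs. The `torch.backends.cuda.matmul.fp32_precision` attribute and its `"ieee"`/`"tf32"` values are taken from the diff above and treated as assumptions about PyTorch 2.9+; the `hasattr` guard keeps the snippet runnable on older builds where only `torch.get_float32_matmul_precision()` exists.

import torch

def describe_fp32_matmul_precision() -> str:
    # Hypothetical helper for illustration only (not part of the patch).
    if hasattr(torch.backends.cuda.matmul, "fp32_precision"):
        # New-style per-backend setting, assumed to report "ieee" (full FP32) or "tf32".
        value = torch.backends.cuda.matmul.fp32_precision
        return f"new API: fp32_precision={value!r}"
    # Old-style global setting: "highest", "high" or "medium".
    return f"legacy API: {torch.get_float32_matmul_precision()!r}"

print(describe_fp32_matmul_precision())
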
diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py
index a175fa97fd444..9340024a5e0e3 100644
--- a/tests/tests_fabric/utilities/test_throughput.py
+++ b/tests/tests_fabric/utilities/test_throughput.py
@@ -1,3 +1,4 @@
+import warnings
 from unittest import mock
 from unittest.mock import Mock, call
 
@@ -11,6 +12,7 @@
     ThroughputMonitor,
     _MonotonicWindow,
     get_available_flops,
+    get_float32_matmul_precision_compat,
     measure_flops,
 )
 from tests_fabric.test_fabric import BoringModel
@@ -340,3 +342,23 @@ def test_monotonic_window():
     w.append(2)
     w.clear()
     w.append(2)
+
+
+def test_get_float32_matmul_precision_compat():
+    """Test that the compatibility function works without warnings."""
+    precision = get_float32_matmul_precision_compat()
+    assert precision in ["highest", "high", "medium"]
+
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        precision = get_float32_matmul_precision_compat()
+
+    deprecation_warnings = [
+        warning
+        for warning in w
+        if "Please use the new API settings to control TF32 behavior" in str(warning.message)
+    ]
+    assert len(deprecation_warnings) == 0, (
+        f"Compatibility function triggered {len(deprecation_warnings)} deprecation warnings"
+    )
+    assert precision in ["highest", "high", "medium"]
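
As a usage sketch of the new helper outside the test suite: the import path follows the test diff above, `torch.set_float32_matmul_precision` is the existing legacy setter, and the exact value reported after opting in depends on the PyTorch version and on whether CUDA is available, so the inline expectations below are hedged rather than guaranteed.

import torch

from lightning.fabric.utilities.throughput import get_float32_matmul_precision_compat

print(get_float32_matmul_precision_compat())  # "highest" on a default configuration

torch.set_float32_matmul_precision("high")  # legacy opt-in to reduced-precision matmuls
print(get_float32_matmul_precision_compat())  # typically "high" (legacy path) or "medium" (mapped from "tf32")

torch.set_float32_matmul_precision("highest")  # restore the default
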