3 changes: 2 additions & 1 deletion src/lightning/fabric/accelerators/cuda.py
@@ -20,6 +20,7 @@
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.utilities.rank_zero import rank_zero_info
+from lightning.fabric.utilities.throughput import get_float32_matmul_precision_compat


 class CUDAAccelerator(Accelerator):
@@ -162,7 +163,7 @@ def _check_cuda_matmul_precision(device: torch.device) -> None:
         return
     # check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and
     # `set_float32_matmul_precision`
-    if torch.get_float32_matmul_precision() == "highest":  # default
+    if get_float32_matmul_precision_compat() == "highest":  # default
         rank_zero_info(
             f"You are using a CUDA device ({torch.cuda.get_device_name(device)!r}) that has Tensor Cores. To properly"
             " utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off"
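Reviewer note: the knob this info message points users at is the long-standing global setting; a minimal sketch of it follows (the new-style attribute in the comment is an assumption based on the API this PR reads from, not something this diff changes):

import torch

# Legacy, user-facing knob referenced in the info message above:
# "high" / "medium" allow TF32 on Ampere-or-later Tensor Cores, "highest" keeps full FP32.
torch.set_float32_matmul_precision("high")

# On PyTorch 2.9+, the new-style equivalent is assumed to be:
# torch.backends.cuda.matmul.fp32_precision = "tf32"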
33 changes: 31 additions & 2 deletions src/lightning/fabric/utilities/throughput.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820/composer/callbacks/speed_monitor.py
+import warnings
 from collections import deque
-from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar, Union

 import torch
 from typing_extensions import override
@@ -27,6 +28,34 @@
 _THROUGHPUT_METRICS = dict[str, Union[int, float]]


+def get_float32_matmul_precision_compat() -> Literal["highest", "high", "medium"]:
+    """Get the current float32 matmul precision using PyTorch 2.9+ compatible API."""
+    if not torch.cuda.is_available():
+        try:
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                return torch.get_float32_matmul_precision()
+        except Exception:
+            return "highest"
+
+    # Check if new API is available (PyTorch 2.9+)
+    if hasattr(torch.backends.cuda.matmul, "fp32_precision"):
+        precision_value = torch.backends.cuda.matmul.fp32_precision
+
+        if precision_value == "ieee":
+            return "highest"
+        if precision_value == "tf32":
+            return "medium"
+        return "highest"
+    # Fallback to old API for older PyTorch versions
+    try:
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            return torch.get_float32_matmul_precision()
+    except Exception:
+        return "highest"
+
+
 # The API design of this class follows `torchmetrics.Metric` but it doesn't need to be an actual Metric because there's
 # no need for synchronization or reduction as it doesn't use Tensors at all.
 class Throughput:
@@ -607,7 +636,7 @@ def get_available_flops(device: torch.device, dtype: Union[torch.dtype, str]) ->
         if dtype is torch.float32:
             from lightning.fabric.accelerators.cuda import _is_ampere_or_later

-            if _is_ampere_or_later() and torch.get_float32_matmul_precision() != "highest":
+            if _is_ampere_or_later() and get_float32_matmul_precision_compat() != "highest":
                 dtype = "tfloat32"
         if dtype not in dtype_to_flops:
             # for example, T4 doesn't support bfloat16. it might also be that we are missing this dtype from the list
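A quick usage sketch of the new helper (a hypothetical call site, assuming this branch is installed); it collapses the new fp32_precision strings onto the legacy three-level scale so existing "highest" checks keep working:

from lightning.fabric.utilities.throughput import get_float32_matmul_precision_compat

precision = get_float32_matmul_precision_compat()
# Mapping used by the helper above: "ieee" -> "highest", "tf32" -> "medium".
if precision == "highest":
    print("float32 matmuls run in full precision (TF32 disabled)")
else:
    print(f"TF32 may be used for float32 matmuls (precision={precision!r})")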
22 changes: 22 additions & 0 deletions tests/tests_fabric/utilities/test_throughput.py
@@ -1,3 +1,4 @@
+import warnings
 from unittest import mock
 from unittest.mock import Mock, call

@@ -11,6 +12,7 @@
     ThroughputMonitor,
     _MonotonicWindow,
     get_available_flops,
+    get_float32_matmul_precision_compat,
     measure_flops,
 )
 from tests_fabric.test_fabric import BoringModel
@@ -340,3 +342,23 @@ def test_monotonic_window():
     w.append(2)
     w.clear()
     w.append(2)
+
+
+def test_get_float32_matmul_precision_compat():
+    """Test that the compatibility function works without warnings."""
+    precision = get_float32_matmul_precision_compat()
+    assert precision in ["highest", "high", "medium"]
+
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        precision = get_float32_matmul_precision_compat()
+
+    deprecation_warnings = [
+        warning
+        for warning in w
+        if "Please use the new API settings to control TF32 behavior" in str(warning.message)
+    ]
+    assert len(deprecation_warnings) == 0, (
+        f"Compatibility function triggered {len(deprecation_warnings)} deprecation warnings"
+    )
+    assert precision in ["highest", "high", "medium"]