
Commit 03dfed9

add compatibility function
1 parent 7d0d0ad commit 03dfed9

File tree

1 file changed: +31 -2 lines changed


src/lightning/fabric/utilities/throughput.py

Lines changed: 31 additions & 2 deletions
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820/composer/callbacks/speed_monitor.py
+import warnings
 from collections import deque
-from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar, Union
 
 import torch
 from typing_extensions import override
@@ -27,6 +28,34 @@
 _THROUGHPUT_METRICS = dict[str, Union[int, float]]
 
 
+def get_float32_matmul_precision_compat() -> Literal["highest", "high", "medium"]:
+    """Get the current float32 matmul precision using PyTorch 2.9+ compatible API."""
+    if not torch.cuda.is_available():
+        try:
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                return torch.get_float32_matmul_precision()
+        except Exception:
+            return "highest"
+
+    # Check if new API is available (PyTorch 2.9+)
+    if hasattr(torch.backends.cuda.matmul, "fp32_precision"):
+        precision_value = torch.backends.cuda.matmul.fp32_precision
+
+        if precision_value == "ieee":
+            return "highest"
+        if precision_value == "tf32":
+            return "medium"
+        return "highest"
+    # Fallback to old API for older PyTorch versions
+    try:
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            return torch.get_float32_matmul_precision()
+    except Exception:
+        return "highest"
+
+
 # The API design of this class follows `torchmetrics.Metric` but it doesn't need to be an actual Metric because there's
 # no need for synchronization or reduction as it doesn't use Tensors at all.
 class Throughput:
@@ -607,7 +636,7 @@ def get_available_flops(device: torch.device, dtype: Union[torch.dtype, str]) ->
     if dtype is torch.float32:
         from lightning.fabric.accelerators.cuda import _is_ampere_or_later
 
-        if _is_ampere_or_later() and torch.get_float32_matmul_precision() != "highest":
+        if _is_ampere_or_later() and get_float32_matmul_precision_compat() != "highest":
             dtype = "tfloat32"
     if dtype not in dtype_to_flops:
         # for example, T4 doesn't support bfloat16. it might also be that we are missing this dtype from the list
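
For context, a hedged usage sketch of the change above: on PyTorch 2.9+ with CUDA, the new helper reads torch.backends.cuda.matmul.fp32_precision (values such as "ieee" or "tf32") and maps it back onto the legacy "highest"/"medium" names; on older builds or CPU-only machines it defers to torch.get_float32_matmul_precision(). The snippet below is not part of this commit; the import path assumes the helper stays in lightning.fabric.utilities.throughput.

import torch

# Assumed import path for the helper introduced in this commit.
from lightning.fabric.utilities.throughput import get_float32_matmul_precision_compat

torch.set_float32_matmul_precision("high")  # may enable TF32 matmuls on Ampere+ GPUs

precision = get_float32_matmul_precision_compat()
# On PyTorch 2.9+ with CUDA this reflects torch.backends.cuda.matmul.fp32_precision
# ("ieee" -> "highest", "tf32" -> "medium"); otherwise it falls back to the legacy getter.
# Any value other than "highest" means float32 matmuls may execute as TF32, so the
# throughput estimate should look up the "tfloat32" peak-FLOPs entry instead.
use_tf32_flops = precision != "highest"
print(precision, use_tf32_flops)

The warnings.catch_warnings() blocks in the helper suppress whatever warnings the legacy getter may emit on newer PyTorch builds, so querying the precision stays silent inside training loops.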

0 commit comments
