Skip to content

Commit 7989c11

Browse files
committed
replace
1 parent 03dfed9 commit 7989c11

File tree

1 file changed

+2
-1
lines changed
  • src/lightning/fabric/accelerators

1 file changed

+2
-1
lines changed

src/lightning/fabric/accelerators/cuda.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from lightning.fabric.accelerators.accelerator import Accelerator
2121
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
2222
from lightning.fabric.utilities.rank_zero import rank_zero_info
23+
from lightning.fabric.utilities.throughput import get_float32_matmul_precision_compat
2324

2425

2526
class CUDAAccelerator(Accelerator):
@@ -162,7 +163,7 @@ def _check_cuda_matmul_precision(device: torch.device) -> None:
162163
return
163164
# check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and
164165
# `set_float32_matmul_precision`
165-
if torch.get_float32_matmul_precision() == "highest": # default
166+
if get_float32_matmul_precision_compat() == "highest": # default
166167
rank_zero_info(
167168
f"You are using a CUDA device ({torch.cuda.get_device_name(device)!r}) that has Tensor Cores. To properly"
168169
" utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off"

0 commit comments

Comments
 (0)