Skip to content

Commit 059dad7

Browse files
committed
add testing
1 parent 7989c11 commit 059dad7

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

tests/tests_fabric/utilities/test_throughput.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from unittest import mock
23
from unittest.mock import Mock, call
34

@@ -11,6 +12,7 @@
1112
ThroughputMonitor,
1213
_MonotonicWindow,
1314
get_available_flops,
15+
get_float32_matmul_precision_compat,
1416
measure_flops,
1517
)
1618
from tests_fabric.test_fabric import BoringModel
@@ -340,3 +342,23 @@ def test_monotonic_window():
340342
w.append(2)
341343
w.clear()
342344
w.append(2)
345+
346+
347+
def test_get_float32_matmul_precision_compat():
348+
"""Test that the compatibility function works without warnings."""
349+
precision = get_float32_matmul_precision_compat()
350+
assert precision in ["highest", "high", "medium"]
351+
352+
with warnings.catch_warnings(record=True) as w:
353+
warnings.simplefilter("always")
354+
precision = get_float32_matmul_precision_compat()
355+
356+
deprecation_warnings = [
357+
warning
358+
for warning in w
359+
if "Please use the new API settings to control TF32 behavior" in str(warning.message)
360+
]
361+
assert len(deprecation_warnings) == 0, (
362+
f"Compatibility function triggered {len(deprecation_warnings)} deprecation warnings"
363+
)
364+
assert precision in ["highest", "high", "medium"]

0 commit comments

Comments
 (0)