
Commit e30133a

Disabled TF32 on Ampere+ devices to stabilize numeric accuracy (#2579)
1 parent 25ce595 commit e30133a

File tree

6 files changed: +26 −12 lines


thunder/tests/conftest.py

Lines changed: 8 additions & 0 deletions
@@ -77,3 +77,11 @@ def pytest_collection_modifyitems(items):
 
 def pytest_addoption(parser):
     parser.addoption("--gpu-mem-limit", type=float)
+
+
+@pytest.fixture
+def turn_off_tf32_and_set_seed(monkeypatch):
+    monkeypatch.setenv("NVIDIA_TF32_OVERRIDE", "0")
+    torch.manual_seed(42)
+    yield
+    torch.seed()
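
For context: NVIDIA_TF32_OVERRIDE=0 is the NVIDIA library-level switch that keeps cuBLAS and cuDNN float32 math in full FP32; since the libraries read it when they initialize, the fixture is most reliable when it runs before any CUDA work has touched them. If only PyTorch's own TF32 knobs were needed, a fixture along these lines would be an alternative. This is a sketch, not part of the commit:

import pytest
import torch


@pytest.fixture
def disable_tf32_via_backend_flags():
    # Remember the current settings so the fixture leaves no side effects.
    prev_matmul = torch.backends.cuda.matmul.allow_tf32
    prev_cudnn = torch.backends.cudnn.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False  # full-FP32 matmuls
    torch.backends.cudnn.allow_tf32 = False  # full-FP32 convolutions
    torch.manual_seed(42)  # deterministic inputs, as in the fixture above
    yield
    torch.backends.cuda.matmul.allow_tf32 = prev_matmul
    torch.backends.cudnn.allow_tf32 = prev_cudnn
    torch.seed()  # re-seed nondeterministically, mirroring the teardown above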

thunder/tests/distributed/test_tensor_parallel.py

Lines changed: 5 additions & 0 deletions
@@ -131,6 +131,10 @@ def forward(self, x):
             actual=tp_jitted_model.get_parameter("embed.weight").grad,
         )
 
+    # Note: When running with TF32 enabled on CUDA, the maximum absolute difference between outputs
+    # can be on the order of 1e-3, which exceeds the default tolerances for torch.testing.assert_close.
+    # This is expected due to the reduced precision of TF32 matrix multiplications.
+    @pytest.mark.usefixtures("turn_off_tf32_and_set_seed")
     @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="")
     @common_utils.parametrize("bias", (True, False))
     def test_both_column_and_row(self, bias):
@@ -154,6 +158,7 @@ def forward(self, x):
                 return h
 
         device = torch.device("cuda", self.rank)
+
         x = torch.randint(0, num_embeddings - 1, (16, 16), device=device)
         x_ref = x.clone().detach()
 
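The note added above is easy to reproduce in isolation. A minimal sketch, assuming an Ampere-or-newer GPU; exact magnitudes vary with shapes and seeds:

import torch

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = True
tf32_out = a @ b

torch.backends.cuda.matmul.allow_tf32 = False
fp32_out = a @ b

# On Ampere+ hardware the gap is commonly on the order of 1e-3.
print((tf32_out - fp32_out).abs().max())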
thunder/tests/test_grad.py

Lines changed: 4 additions & 1 deletion
@@ -1487,8 +1487,11 @@ def test_populate_grads_block(executor, device, dtype):
     assert_close(torch_grads, thunder_grads, atol=1e-2, rtol=1e-2)
 
 
+# Note: When running with TF32 enabled on CUDA, the maximum absolute difference between outputs
+# can be on the order of 1e-3, which exceeds the default tolerances for torch.testing.assert_close.
+# This is expected due to the reduced precision of TF32 matrix multiplications.
 @instantiate(dtypes=(thunder.float32,))
-def test_populate_grads_nanogpt(executor, device, dtype):
+def test_populate_grads_nanogpt(executor, device, dtype, turn_off_tf32_and_set_seed):
     import sys
 
     if sys.platform == "win32":
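
The commit title scopes the change to Ampere+ devices because TF32 tensor cores first appear at CUDA compute capability 8.0; earlier GPUs always run float32 matmuls in full precision. A hypothetical guard, not part of this commit, that identifies the affected devices:

import torch


def is_tf32_capable() -> bool:
    # TF32 requires compute capability >= 8.0 (Ampere and newer).
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8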

thunder/tests/test_jit_general.py

Lines changed: 8 additions & 2 deletions
@@ -649,6 +649,9 @@ def test_nanogpt():
     assert_close(result, module(*args, **kwargs))
 
 
+# Note: When running with TF32 enabled on CUDA, the maximum absolute difference between outputs
+# can be on the order of 1e-3, which exceeds the default tolerances for torch.testing.assert_close.
+# This is expected due to the reduced precision of TF32 matrix multiplications.
 @skipif_not_pytorch_2_1
 @pytest.mark.parametrize(
     "name",
@@ -668,7 +671,7 @@ def test_nanogpt():
     "device",
     ("cpu", "cuda", "meta"),
 )
-def test_litgpt_variants(name, device):
+def test_litgpt_variants(name, device, turn_off_tf32_and_set_seed):
     from thunder.tests.litgpt_model import Config
     from litgpt.model import GPT
 
@@ -704,6 +707,9 @@ def test_litgpt_variants(name, device):
         torch.testing.assert_close(param1.grad, param2.grad, rtol=1e-2, atol=1e-2)
 
 
+# Note: When running with TF32 enabled on CUDA, the maximum absolute difference between outputs
+# can be on the order of 1e-3, which exceeds the default tolerances for torch.testing.assert_close.
+# This is expected due to the reduced precision of TF32 matrix multiplications.
 @skipif_not_pytorch_2_1
 @pytest.mark.parametrize(
     "name",
@@ -724,7 +730,7 @@ def test_litgpt_variants(name, device):
     "device",
     ("cpu", "cuda"),
 )
-def test_litgpt_variants_kvcache(name, device):
+def test_litgpt_variants_kvcache(name, device, turn_off_tf32_and_set_seed):
     from thunder.tests.litgpt_model import Config
     from litgpt.model import GPT
     import torch._dynamo  # this monkeypatches torch.manual_seed
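
For scale: torch.testing.assert_close uses default tolerances of rtol=1.3e-6 and atol=1e-5 for float32, far tighter than the ~1e-3 drift the notes describe, which is why the affected tests either disable TF32 or pass explicit rtol/atol. A small illustration:

import torch

a = torch.tensor([1.0000])
b = torch.tensor([1.0010])  # a 1e-3 gap, typical of TF32 matmul error

torch.testing.assert_close(a, b, rtol=1e-2, atol=1e-2)  # passes with loosened tolerances
torch.testing.assert_close(a, b)  # raises AssertionError under float32 defaults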

thunder/tests/test_networks.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@
 
 # see https://docs.pytest.org/en/stable/how-to/capture-warnings.html#recwarn for the recwarn fixture
 @instantiate(dtypes=(thunder.float32,), executors=all_test_executors_and_dynamo)
-def test_nanogpt_complete(executor, device, dtype, recwarn):
+def test_nanogpt_complete(executor, device, dtype, recwarn, turn_off_tf32_and_set_seed):
     tdtype = ttorch.to_torch_dtype(dtype)
     make = partial(make_tensor, dtype=torch.int64, device=device)
 
thunder/tests/test_update_aliases.py

Lines changed: 0 additions & 8 deletions
@@ -82,14 +82,6 @@ def inplace_masked_fill_sample_generator(op, device, dtype, requires_grad, **kwargs):
     _inplace_opinfos.append(inplace_opinfo)
 
 
-@pytest.fixture
-def turn_off_tf32_and_set_seed(monkeypatch):
-    monkeypatch.setenv("NVIDIA_TF32_OVERRIDE", "0")
-    torch.manual_seed(42)
-    yield
-    torch.seed()
-
-
 @instantiate(
     dtypes=(thunder.float32, thunder.float64),
     devicetypes=(devices.DeviceType.CUDA,),
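
This deletion is the counterpart of the conftest.py addition above: the fixture moves unchanged, so every module under thunder/tests can use it without importing it. A sketch of the two opt-in styles this commit uses; the test names here are hypothetical:

import pytest


def test_with_parameter(turn_off_tf32_and_set_seed):
    # Style 1: request the fixture as a parameter, as in test_grad.py
    # and test_jit_general.py.
    ...


@pytest.mark.usefixtures("turn_off_tf32_and_set_seed")
def test_with_marker():
    # Style 2: the usefixtures marker, as in test_tensor_parallel.py,
    # when the test body does not need the fixture's value.
    ...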
