Commit 6349cb6 (1 parent: 069d15a)

Commit message: test_fsdp_precision_forward_context_bf16

1 file changed: +6 additions, −2 deletions

tests/tests_pytorch/plugins/precision/test_fsdp.py

@@ -74,8 +74,8 @@ def test_fsdp_precision_scaler_with_bf16():
 
 
 @RunIf(min_cuda_gpus=1)
-def test_fsdp_precision_forward_context():
-    """Test to ensure that the context manager correctly is set to bfloat16."""
+def test_fsdp_precision_forward_context_f16():
+    """Test to ensure that the context manager correctly is set to float16."""
     from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
     precision = FSDPPrecision(precision="16-mixed")
@@ -94,6 +94,10 @@ def test_fsdp_precision_forward_context():
     assert isinstance(precision.forward_context(), _DtypeContextManager)
     assert precision.forward_context()._new_dtype == torch.float16
 
 
+@RunIf(min_cuda_gpus=1, bf16_cuda=True)
+def test_fsdp_precision_forward_context_bf16():
+    """Test to ensure that the context manager correctly is set to bfloat16."""
     precision = FSDPPrecision(precision="bf16-mixed")
     assert precision.scaler is None
     with precision.forward_context():
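
For reference, a minimal sketch of how the complete new test plausibly reads after this commit. The hunk above cuts off after the `with` statement, so the context-manager body, the trailing assertions, the "bf16-true" check, and the import paths (`RunIf`, `FSDPPrecision`, `_DtypeContextManager`) are assumptions mirroring the float16 checks visible in this same file, not the verbatim commit contents:

import torch

# assumed import paths, based on the Lightning repo layout
from lightning.fabric.plugins.precision.utils import _DtypeContextManager
from lightning.pytorch.plugins.precision import FSDPPrecision
from tests_pytorch.helpers.runif import RunIf


@RunIf(min_cuda_gpus=1, bf16_cuda=True)
def test_fsdp_precision_forward_context_bf16():
    """Test to ensure that the context manager correctly is set to bfloat16."""
    precision = FSDPPrecision(precision="bf16-mixed")
    assert precision.scaler is None  # bf16 mixed precision needs no grad scaler
    with precision.forward_context():
        # assumed body: check the autocast dtype inside the context,
        # mirroring the float16 assertions earlier in the file
        assert torch.get_autocast_gpu_dtype() == torch.bfloat16

    # assumed: the "true" variant returns a dtype-switching context manager,
    # as the float16 test asserts for "16-true"
    precision = FSDPPrecision(precision="bf16-true")
    assert isinstance(precision.forward_context(), _DtypeContextManager)
    assert precision.forward_context()._new_dtype == torch.bfloat16

Splitting the original test in two lets the bfloat16 half carry the stricter `bf16_cuda=True` requirement while the float16 half still runs on any single-GPU machine.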
