File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed
tests/tests_pytorch/plugins/precision Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -74,8 +74,8 @@ def test_fsdp_precision_scaler_with_bf16():
74
74
75
75
76
76
@RunIf (min_cuda_gpus = 1 )
77
- def test_fsdp_precision_forward_context ():
78
- """Test to ensure that the context manager correctly is set to bfloat16 ."""
77
+ def test_fsdp_precision_forward_context_f16 ():
78
+ """Test to ensure that the context manager correctly is set to float16 ."""
79
79
from torch .distributed .fsdp .sharded_grad_scaler import ShardedGradScaler
80
80
81
81
precision = FSDPPrecision (precision = "16-mixed" )
@@ -94,6 +94,10 @@ def test_fsdp_precision_forward_context():
94
94
assert isinstance (precision .forward_context (), _DtypeContextManager )
95
95
assert precision .forward_context ()._new_dtype == torch .float16
96
96
97
+
98
+ @RunIf (min_cuda_gpus = 1 , bf16_cuda = True )
99
+ def test_fsdp_precision_forward_context_bf16 ():
100
+ """Test to ensure that the context manager correctly is set to bfloat16."""
97
101
precision = FSDPPrecision (precision = "bf16-mixed" )
98
102
assert precision .scaler is None
99
103
with precision .forward_context ():
You can’t perform that action at this time.
0 commit comments