File tree — 1 file changed: +6 −2 lines changed
tests/tests_pytorch/plugins/precision — 1 file changed: +6 −2 lines changed

@@ -74,8 +74,8 @@ def test_fsdp_precision_scaler_with_bf16():
 74
 75
 76  @RunIf(min_cuda_gpus=1)
-77  def test_fsdp_precision_forward_context():
-78      """Test to ensure that the context manager correctly is set to bfloat16."""
+77  def test_fsdp_precision_forward_context_f16():
+78      """Test to ensure that the context manager correctly is set to float16."""
 79      from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 80
 81      precision = FSDPPrecision(precision="16-mixed")
@@ -94,6 +94,10 @@ def test_fsdp_precision_forward_context():
  94      assert isinstance(precision.forward_context(), _DtypeContextManager)
  95      assert precision.forward_context()._new_dtype == torch.float16
  96
+ 97
+ 98  @RunIf(min_cuda_gpus=1, bf16_cuda=True)
+ 99  def test_fsdp_precision_forward_context_bf16():
+100      """Test to ensure that the context manager correctly is set to bfloat16."""
 101      precision = FSDPPrecision(precision="bf16-mixed")
 102      assert precision.scaler is None
 103      with precision.forward_context():
You can’t perform that action at this time.
0 commit comments