@@ -81,6 +81,8 @@ def _assert_layer_fsdp_instance(self) -> None:
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
         elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "32-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
 
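Aside: the branch added above determines the dtypes the test later asserts on the layer's FSDP mixed-precision config. A minimal standalone sketch of that mapping, assuming the 16-bit branch above (not fully shown in the hunk) covers both "16-true" and "16-mixed"; the helper name is made up for illustration:

```python
import torch
from torch.distributed.fsdp import MixedPrecision


def mixed_precision_for(precision: str) -> MixedPrecision:
    # Mirrors the branch logic in the hunk above: each MixedPrecision
    # field (param/reduce/buffer dtype) uses the same dtype.
    if precision in ("16-true", "16-mixed"):  # assumed from the unshown lines above
        dtype = torch.float16
    elif precision in ("bf16-true", "bf16-mixed"):
        dtype = torch.bfloat16
    elif precision == "32-true":  # branch added by this change
        dtype = torch.float32
    else:
        raise ValueError(f"Unknown precision {precision}")
    return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
```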
@@ -215,7 +217,7 @@ def test_strategy_sync_batchnorm(tmp_path):
         accelerator="gpu",
         devices=2,
         strategy="fsdp",
-        precision="16-mixed",
+        precision="32-true",
         max_epochs=1,
         sync_batchnorm=True,
     )
@@ -255,7 +257,7 @@ def training_step(self, batch, batch_idx):
 
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
-@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
+@pytest.mark.parametrize("precision", ["32-true", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 @pytest.mark.parametrize("state_dict_type", ["sharded", "full"])
 def test_strategy_checkpoint(state_dict_type, precision, tmp_path):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
@@ -347,7 +349,7 @@ def test_checkpoint_multi_gpus(tmp_path, model, strategy, strategy_cfg):
         accelerator="gpu",
         devices=2,
         strategy=strategy,
-        precision="16-mixed",
+        precision="32-true",
         max_epochs=1,
         limit_train_batches=2,
         limit_val_batches=2,