@@ -210,15 +210,15 @@ def test_invalid_on_cpu(tmp_path, cuda_count_0):
210
210
trainer .strategy .setup_environment ()
211
211
212
212
213
- def test_fsdp_custom_mixed_precision ():
213
+ def test_custom_mixed_precision ():
214
214
"""Test that a custom `MixedPrecision` config passed to `FSDPStrategy` is exposed as `mixed_precision_config`."""
215
215
config = MixedPrecision ()
216
216
strategy = FSDPStrategy (mixed_precision = config )
217
217
assert strategy .mixed_precision_config == config
218
218
219
219
220
220
@RunIf (min_cuda_gpus = 2 , skip_windows = True , standalone = True )
221
- def test_fsdp_strategy_sync_batchnorm (tmp_path ):
221
+ def test_strategy_sync_batchnorm (tmp_path ):
222
222
"""Test to ensure that sync_batchnorm works when using FSDP and GPU, and all stages can be run."""
223
223
model = TestFSDPModel ()
224
224
trainer = Trainer (
@@ -234,7 +234,7 @@ def test_fsdp_strategy_sync_batchnorm(tmp_path):
234
234
235
235
236
236
@RunIf (min_cuda_gpus = 1 , skip_windows = True )
237
- def test_fsdp_modules_without_parameters (tmp_path ):
237
+ def test_modules_without_parameters (tmp_path ):
238
238
"""Test that TorchMetrics get moved to the device despite not having any parameters."""
239
239
240
240
class MetricsModel (BoringModel ):
@@ -266,7 +266,7 @@ def training_step(self, batch, batch_idx):
266
266
@RunIf (min_cuda_gpus = 2 , skip_windows = True , standalone = True )
267
267
@pytest .mark .parametrize ("precision" , ["16-mixed" , pytest .param ("bf16-mixed" , marks = RunIf (bf16_cuda = True ))])
268
268
@pytest .mark .parametrize ("state_dict_type" , ["sharded" , "full" ])
269
- def test_fsdp_strategy_checkpoint (state_dict_type , precision , tmp_path ):
269
+ def test_strategy_checkpoint (state_dict_type , precision , tmp_path ):
270
270
"""Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
271
271
model = TestFSDPModel ()
272
272
strategy = FSDPStrategy (state_dict_type = state_dict_type )
@@ -286,7 +286,7 @@ def custom_auto_wrap_policy(
286
286
287
287
@RunIf (min_cuda_gpus = 2 , skip_windows = True , standalone = True )
288
288
@pytest .mark .parametrize ("wrap_min_params" , [2 , 1024 , 100000000 ])
289
- def test_fsdp_strategy_full_state_dict (tmp_path , wrap_min_params ):
289
+ def test_strategy_full_state_dict (tmp_path , wrap_min_params ):
290
290
"""Test to ensure that the full state dict is extracted when using FSDP strategy.
291
291
292
292
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all.
@@ -342,7 +342,7 @@ def test_fsdp_strategy_full_state_dict(tmp_path, wrap_min_params):
342
342
),
343
343
],
344
344
)
345
- def test_fsdp_checkpoint_multi_gpus (tmp_path , model , strategy , strategy_cfg ):
345
+ def test_checkpoint_multi_gpus (tmp_path , model , strategy , strategy_cfg ):
346
346
"""Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""
347
347
ck = ModelCheckpoint (save_last = True )
348
348
@@ -410,7 +410,7 @@ def configure_optimizers(self):
410
410
trainer .fit (model )
411
411
412
412
413
- def test_fsdp_forbidden_precision_raises ():
413
+ def test_forbidden_precision_raises ():
414
414
with pytest .raises (TypeError , match = "can only work with the `FSDPPrecision" ):
415
415
FSDPStrategy (precision_plugin = HalfPrecision ())
416
416
@@ -419,7 +419,7 @@ def test_fsdp_forbidden_precision_raises():
419
419
strategy .precision_plugin = HalfPrecision ()
420
420
421
421
422
- def test_fsdp_activation_checkpointing ():
422
+ def test_activation_checkpointing ():
423
423
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""
424
424
425
425
class Block1 (nn .Linear ):
@@ -469,7 +469,7 @@ def __init__(self):
469
469
apply_mock .assert_called_with (wrapped , checkpoint_wrapper_fn = ANY , ** strategy ._activation_checkpointing_kwargs )
470
470
471
471
472
- def test_fsdp_strategy_cpu_offload ():
472
+ def test_strategy_cpu_offload ():
473
473
"""Test the different ways cpu offloading can be enabled."""
474
474
# bool
475
475
strategy = FSDPStrategy (cpu_offload = True )
@@ -481,7 +481,7 @@ def test_fsdp_strategy_cpu_offload():
481
481
assert strategy .cpu_offload == config
482
482
483
483
484
- def test_fsdp_sharding_strategy ():
484
+ def test_sharding_strategy ():
485
485
"""Test the different ways the sharding strategy can be set."""
486
486
from torch .distributed .fsdp import ShardingStrategy
487
487
@@ -501,7 +501,7 @@ def test_fsdp_sharding_strategy():
501
501
502
502
503
503
@pytest .mark .parametrize ("sharding_strategy" , ["HYBRID_SHARD" , "_HYBRID_SHARD_ZERO2" ])
504
- def test_fsdp_hybrid_sharding_strategy (sharding_strategy ):
504
+ def test_hybrid_sharding_strategy (sharding_strategy ):
505
505
"""Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified pg."""
506
506
with pytest .raises (RuntimeError , match = "The hybrid sharding strategy requires you to pass at least one of" ):
507
507
FSDPStrategy (sharding_strategy = sharding_strategy )
@@ -523,7 +523,7 @@ def test_fsdp_hybrid_sharding_strategy(sharding_strategy):
523
523
FSDPStrategy (sharding_strategy = sharding_strategy , process_group = process_group , device_mesh = device_mesh )
524
524
525
525
526
- def test_fsdp_use_orig_params ():
526
+ def test_use_orig_params ():
527
527
"""Test that Lightning enables `use_orig_params` automatically."""
528
528
strategy = FSDPStrategy ()
529
529
assert strategy .kwargs ["use_orig_params" ]
@@ -548,7 +548,7 @@ def test_set_timeout(init_process_group_mock):
548
548
549
549
550
550
@mock .patch ("lightning.pytorch.strategies.fsdp._load_raw_module_state" )
551
- def test_fsdp_strategy_load_optimizer_states_multiple (_ , tmp_path ):
551
+ def test_strategy_load_optimizer_states_multiple (_ , tmp_path ):
552
552
strategy = FSDPStrategy (parallel_devices = [torch .device ("cpu" )], state_dict_type = "full" )
553
553
trainer = Trainer ()
554
554
trainer .state .fn = TrainerFn .FITTING
@@ -572,7 +572,7 @@ def test_fsdp_strategy_load_optimizer_states_multiple(_, tmp_path):
572
572
573
573
@RunIf (min_cuda_gpus = 2 , skip_windows = True , standalone = True )
574
574
@pytest .mark .parametrize ("wrap_min_params" , [2 , 1024 , 100000000 ])
575
- def test_fsdp_strategy_save_optimizer_states (tmp_path , wrap_min_params ):
575
+ def test_strategy_save_optimizer_states (tmp_path , wrap_min_params ):
576
576
"""Test to ensure that the full state dict and optimizer states are saved when using FSDP strategy.
577
577
578
578
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the model can
@@ -630,7 +630,7 @@ def test_fsdp_strategy_save_optimizer_states(tmp_path, wrap_min_params):
630
630
631
631
@RunIf (min_cuda_gpus = 2 , skip_windows = True , standalone = True )
632
632
@pytest .mark .parametrize ("wrap_min_params" , [2 , 1024 , 100000000 ])
633
- def test_fsdp_strategy_load_optimizer_states (wrap_min_params , tmp_path ):
633
+ def test_strategy_load_optimizer_states (wrap_min_params , tmp_path ):
634
634
"""Test to ensure that the full state dict and optimizer states can be loaded when using FSDP strategy.
635
635
636
636
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the DDP model
@@ -741,7 +741,7 @@ def test_save_checkpoint_storage_options(tmp_path):
741
741
@mock .patch ("lightning.pytorch.strategies.fsdp._get_sharded_state_dict_context" )
742
742
@mock .patch ("lightning.fabric.plugins.io.torch_io._atomic_save" )
743
743
@mock .patch ("lightning.pytorch.strategies.fsdp.shutil" )
744
- def test_fsdp_save_checkpoint_path_exists (shutil_mock , torch_save_mock , __ , ___ , tmp_path ):
744
+ def test_save_checkpoint_path_exists (shutil_mock , torch_save_mock , __ , ___ , tmp_path ):
745
745
strategy = FSDPStrategy (state_dict_type = "full" )
746
746
747
747
# state_dict_type='full', path exists, path is not a sharded checkpoint: error
@@ -757,16 +757,12 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___,
757
757
path .mkdir ()
758
758
(path / "meta.pt" ).touch ()
759
759
assert _is_sharded_checkpoint (path )
760
- model = Mock (spec = FullyShardedDataParallel )
761
- model .modules .return_value = [model ]
762
760
strategy .save_checkpoint (Mock (), filepath = path )
763
761
shutil_mock .rmtree .assert_called_once_with (path )
764
762
765
763
# state_dict_type='full', path exists, path is a file: no error (overwrite)
766
764
path = tmp_path / "file.pt"
767
765
path .touch ()
768
- model = Mock (spec = FullyShardedDataParallel )
769
- model .modules .return_value = [model ]
770
766
torch_save_mock .reset_mock ()
771
767
strategy .save_checkpoint (Mock (), filepath = path )
772
768
torch_save_mock .assert_called_once ()
@@ -783,30 +779,26 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___,
783
779
path = tmp_path / "not-empty-2"
784
780
path .mkdir ()
785
781
(path / "file" ).touch ()
786
- model = Mock (spec = FullyShardedDataParallel )
787
- model .modules .return_value = [model ]
788
782
with save_mock :
789
783
strategy .save_checkpoint ({"state_dict" : {}, "optimizer_states" : {"" : {}}}, filepath = path )
790
784
assert (path / "file" ).exists ()
791
785
792
786
# state_dict_type='sharded', path exists, path is a file: no error (overwrite)
793
787
path = tmp_path / "file-2.pt"
794
788
path .touch ()
795
- model = Mock (spec = FullyShardedDataParallel )
796
- model .modules .return_value = [model ]
797
789
with save_mock :
798
790
strategy .save_checkpoint ({"state_dict" : {}, "optimizer_states" : {"" : {}}}, filepath = path )
799
791
assert path .is_dir ()
800
792
801
793
802
794
# `FSDPStrategy.broadcast` is replaced by a function that returns its argument unchanged
# (presumably so the test needs no initialized process group -- TODO confirm).
@mock .patch ("lightning.pytorch.strategies.fsdp.FSDPStrategy.broadcast" , lambda _ , x : x )
803
- def test_fsdp_save_checkpoint_unknown_state_dict_type (tmp_path ):
795
+ def test_save_checkpoint_unknown_state_dict_type (tmp_path ):
"""Test that `save_checkpoint` raises a `ValueError` when the strategy has an unrecognized `state_dict_type`."""
804
796
strategy = FSDPStrategy (state_dict_type = "invalid" )
805
797
with pytest .raises (ValueError , match = "Unknown state_dict_type" ):
806
798
strategy .save_checkpoint (checkpoint = Mock (), filepath = tmp_path )
807
799
808
800
809
- def test_fsdp_load_unknown_checkpoint_type (tmp_path ):
801
+ def test_load_unknown_checkpoint_type (tmp_path ):
810
802
"""Test that the strategy validates the contents at the checkpoint path."""
811
803
strategy = FSDPStrategy ()
812
804
strategy .model = Mock ()
@@ -874,7 +866,7 @@ def test_save_load_sharded_state_dict(tmp_path):
874
866
@mock .patch ("lightning.pytorch.strategies.fsdp.torch.load" )
875
867
@mock .patch ("lightning.pytorch.strategies.fsdp._lazy_load" )
876
868
@mock .patch ("lightning.pytorch.strategies.fsdp._load_raw_module_state" )
877
- def test_fsdp_lazy_load_full_state_dict (_ , lazy_load_mock , torch_load_mock , tmp_path ):
869
+ def test_lazy_load_full_state_dict (_ , lazy_load_mock , torch_load_mock , tmp_path ):
878
870
"""Test that loading a single file (full state) is lazy to reduce peak CPU memory usage."""
879
871
model = BoringModel ()
880
872
checkpoint = {"state_dict" : model .state_dict ()}
0 commit comments