2929from tests_fabric .helpers .runif import RunIf
3030
3131
32+ @pytest .fixture
33+ def distributed ():
34+ yield
35+ if torch .distributed .is_initialized ():
36+ torch .distributed .destroy_process_group ()
37+
38+
3239class FeedForward (nn .Module ):
3340 def __init__ (self ):
3441 super ().__init__ ()
@@ -81,7 +88,7 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh):
8188
8289
8390@RunIf (min_torch = "2.4" , standalone = True , min_cuda_gpus = 4 )
84- def test_setup_device_mesh ():
91+ def test_setup_device_mesh (distributed ):
8592 from torch .distributed .device_mesh import DeviceMesh
8693
8794 for dp_size , tp_size in ((1 , 4 ), (4 , 1 ), (2 , 2 )):
@@ -129,7 +136,7 @@ def fn(model, device_mesh):
129136 "compile" ,
130137 [True , False ],
131138)
132- def test_tensor_parallel (compile ):
139+ def test_tensor_parallel (distributed , compile ):
133140 from torch .distributed ._tensor import DTensor
134141
135142 parallelize = _parallelize_feed_forward_tp
@@ -182,7 +189,7 @@ def test_tensor_parallel(compile):
182189 "compile" ,
183190 [True , False ],
184191)
185- def test_fsdp2_tensor_parallel (compile ):
192+ def test_fsdp2_tensor_parallel (distributed , compile ):
186193 from torch .distributed ._tensor import DTensor
187194
188195 parallelize = _parallelize_feed_forward_fsdp2_tp
@@ -264,14 +271,15 @@ def _train(fabric, model=None, optimizer=None):
264271
265272
266273@RunIf (min_torch = "2.4" , min_cuda_gpus = 4 , standalone = True )
274+ @pytest .mark .filterwarnings ("ignore::UserWarning" )
267275@pytest .mark .parametrize (
268276 "precision" ,
269277 [
270278 pytest .param ("32-true" ),
271279 pytest .param ("bf16-mixed" , marks = RunIf (bf16_cuda = True )),
272280 ],
273281)
274- def test_train_save_load (precision , tmp_path ):
282+ def test_train_save_load (distributed , precision , tmp_path ):
275283 """Test 2D-parallel training, saving and loading precision settings."""
276284 strategy = ModelParallelStrategy (
277285 _parallelize_feed_forward_fsdp2_tp ,
@@ -329,7 +337,7 @@ def test_train_save_load(precision, tmp_path):
329337
330338@pytest .mark .filterwarnings ("ignore::FutureWarning" )
331339@RunIf (min_torch = "2.4" , min_cuda_gpus = 2 , standalone = True )
332- def test_save_full_state_dict (tmp_path ):
340+ def test_save_full_state_dict (distributed , tmp_path ):
333341 """Test that ModelParallelStrategy saves the full state into a single file with
334342 `save_distributed_checkpoint=False`."""
335343 from torch .distributed .checkpoint .state_dict import get_optimizer_state_dict
@@ -430,7 +438,7 @@ def test_save_full_state_dict(tmp_path):
430438
431439@pytest .mark .filterwarnings ("ignore::FutureWarning" )
432440@RunIf (min_torch = "2.4" , min_cuda_gpus = 2 , standalone = True )
433- def test_load_full_state_dict_into_sharded_model (tmp_path ):
441+ def test_load_full_state_dict_into_sharded_model (distributed , tmp_path ):
434442 """Test that the strategy can load a full-state checkpoint into a distributed model."""
435443 fabric = Fabric (accelerator = "cuda" , devices = 1 )
436444 fabric .seed_everything (0 )
@@ -476,7 +484,7 @@ def test_load_full_state_dict_into_sharded_model(tmp_path):
476484@RunIf (min_torch = "2.4" , min_cuda_gpus = 2 , skip_windows = True , standalone = True )
477485@pytest .mark .parametrize ("move_to_device" , [True , False ])
478486@mock .patch ("lightning.fabric.wrappers._FabricModule" )
479- def test_setup_module_move_to_device (fabric_module_mock , move_to_device ):
487+ def test_setup_module_move_to_device (fabric_module_mock , move_to_device , distributed ):
480488 """Test that `move_to_device` does nothing, ModelParallel decides which device parameters get moved to which device
481489 (sharding)."""
482490 from torch .distributed ._tensor import DTensor
@@ -508,7 +516,7 @@ def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
508516 pytest .param ("bf16-true" , torch .bfloat16 , marks = RunIf (bf16_cuda = True )),
509517 ],
510518)
511- def test_module_init_context (precision , expected_dtype ):
519+ def test_module_init_context (distributed , precision , expected_dtype ):
512520 """Test that the module under the init-context gets moved to the right device and dtype."""
513521 strategy = ModelParallelStrategy (parallelize_fn = _parallelize_feed_forward_fsdp2 )
514522 fabric = Fabric (accelerator = "cuda" , devices = 2 , strategy = strategy , precision = precision )
@@ -531,7 +539,7 @@ def _run_setup_assertions(empty_init, expected_device):
531539
532540
533541@RunIf (min_torch = "2.4" , min_cuda_gpus = 2 , standalone = True )
534- def test_save_filter (tmp_path ):
542+ def test_save_filter (distributed , tmp_path ):
535543 strategy = ModelParallelStrategy (
536544 parallelize_fn = _parallelize_feed_forward_fsdp2 ,
537545 save_distributed_checkpoint = False ,
@@ -584,7 +592,7 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh):
584592 "val" ,
585593 ],
586594)
587- def test_clip_gradients (clip_type , precision ):
595+ def test_clip_gradients (distributed , clip_type , precision ):
588596 strategy = ModelParallelStrategy (_parallelize_single_linear_tp_fsdp2 )
589597 fabric = Fabric (accelerator = "auto" , devices = 2 , precision = precision , strategy = strategy )
590598 fabric .launch ()
@@ -626,7 +634,7 @@ def test_clip_gradients(clip_type, precision):
626634
627635
628636@RunIf (min_torch = "2.4" , min_cuda_gpus = 4 , standalone = True )
629- def test_save_sharded_and_consolidate_and_load (tmp_path ):
637+ def test_save_sharded_and_consolidate_and_load (distributed , tmp_path ):
630638 """Test the consolidation of a distributed (DTensor) checkpoint into a single file."""
631639 strategy = ModelParallelStrategy (
632640 _parallelize_feed_forward_fsdp2_tp ,
@@ -683,7 +691,7 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):
683691
684692
685693@RunIf (min_torch = "2.4" , min_cuda_gpus = 2 , standalone = True )
686- def test_load_raw_module_state ():
694+ def test_load_raw_module_state (distributed ):
687695 from torch .distributed .device_mesh import init_device_mesh
688696 from torch .distributed .tensor .parallel import ColwiseParallel , parallelize_module
689697
0 commit comments