
Commit a8310cf

Fix integration tests
1 parent 4eaac17 commit a8310cf

File tree (3 files changed: +37 -21 lines changed):

tests/tests_fabric/conftest.py
tests/tests_fabric/strategies/test_model_parallel_integration.py
tests/tests_pytorch/strategies/test_model_parallel_integration.py

tests/tests_fabric/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ def restore_env_variables():
         "OMP_NUM_THREADS",  # set by our launchers
         # set by torchdynamo
         "TRITON_CACHE_DIR",
+        "TORCHINDUCTOR_CACHE_DIR",
     }
     leaked_vars.difference_update(allowlist)
     assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
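For context, the allowlist above sits inside an autouse fixture that snapshots os.environ before each test and asserts afterwards that nothing new leaked. A minimal sketch of that pattern, assuming the snapshot/compare structure around the lines shown in the diff (not the exact Lightning implementation):

import os

import pytest


@pytest.fixture(autouse=True)
def restore_env_variables():
    """Fail a test that leaves new environment variables behind."""
    env_before = dict(os.environ)  # snapshot before the test body runs
    yield
    leaked_vars = os.environ.keys() - env_before.keys()
    # Variables that are expected side effects of the suite and therefore tolerated.
    allowlist = {
        "OMP_NUM_THREADS",  # set by our launchers
        # set by torchdynamo
        "TRITON_CACHE_DIR",
        "TORCHINDUCTOR_CACHE_DIR",
    }
    leaked_vars.difference_update(allowlist)
    assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"

The commit adds TORCHINDUCTOR_CACHE_DIR to this allowlist because torchdynamo/inductor can set it as a side effect of compiled tests, which would otherwise trip the leak assertion.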

tests/tests_fabric/strategies/test_model_parallel_integration.py

Lines changed: 20 additions & 12 deletions
@@ -29,6 +29,13 @@
 from tests_fabric.helpers.runif import RunIf
 
 
+@pytest.fixture
+def distributed():
+    yield
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()
+
+
 class FeedForward(nn.Module):
     def __init__(self):
         super().__init__()
@@ -81,7 +88,7 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh):
 
 
 @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
-def test_setup_device_mesh():
+def test_setup_device_mesh(distributed):
     from torch.distributed.device_mesh import DeviceMesh
 
     for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)):
@@ -129,7 +136,7 @@ def fn(model, device_mesh):
     "compile",
     [True, False],
 )
-def test_tensor_parallel(compile):
+def test_tensor_parallel(distributed, compile):
     from torch.distributed._tensor import DTensor
 
     parallelize = _parallelize_feed_forward_tp
@@ -182,7 +189,7 @@ def test_tensor_parallel(compile):
     "compile",
     [True, False],
 )
-def test_fsdp2_tensor_parallel(compile):
+def test_fsdp2_tensor_parallel(distributed, compile):
     from torch.distributed._tensor import DTensor
 
     parallelize = _parallelize_feed_forward_fsdp2_tp
@@ -264,14 +271,15 @@ def _train(fabric, model=None, optimizer=None):
 
 
 @RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
+@pytest.mark.filterwarnings("ignore::UserWarning")
 @pytest.mark.parametrize(
     "precision",
     [
         pytest.param("32-true"),
         pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
     ],
 )
-def test_train_save_load(precision, tmp_path):
+def test_train_save_load(distributed, precision, tmp_path):
     """Test 2D-parallel training, saving and loading precision settings."""
     strategy = ModelParallelStrategy(
         _parallelize_feed_forward_fsdp2_tp,
@@ -329,7 +337,7 @@ def test_train_save_load(precision, tmp_path):
 
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
-def test_save_full_state_dict(tmp_path):
+def test_save_full_state_dict(distributed, tmp_path):
     """Test that ModelParallelStrategy saves the full state into a single file with
     `save_distributed_checkpoint=False`."""
     from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
@@ -430,7 +438,7 @@ def test_save_full_state_dict(tmp_path):
 
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
-def test_load_full_state_dict_into_sharded_model(tmp_path):
+def test_load_full_state_dict_into_sharded_model(distributed, tmp_path):
     """Test that the strategy can load a full-state checkpoint into a distributed model."""
     fabric = Fabric(accelerator="cuda", devices=1)
     fabric.seed_everything(0)
@@ -476,7 +484,7 @@ def test_load_full_state_dict_into_sharded_model(tmp_path):
 @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("move_to_device", [True, False])
 @mock.patch("lightning.fabric.wrappers._FabricModule")
-def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
+def test_setup_module_move_to_device(fabric_module_mock, move_to_device, distributed):
     """Test that `move_to_device` does nothing, ModelParallel decides which device parameters get moved to which device
     (sharding)."""
     from torch.distributed._tensor import DTensor
@@ -508,7 +516,7 @@ def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
         pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
     ],
 )
-def test_module_init_context(precision, expected_dtype):
+def test_module_init_context(distributed, precision, expected_dtype):
     """Test that the module under the init-context gets moved to the right device and dtype."""
     strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_fsdp2)
     fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision=precision)
@@ -531,7 +539,7 @@ def _run_setup_assertions(empty_init, expected_device):
 
 
 @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
-def test_save_filter(tmp_path):
+def test_save_filter(distributed, tmp_path):
     strategy = ModelParallelStrategy(
         parallelize_fn=_parallelize_feed_forward_fsdp2,
         save_distributed_checkpoint=False,
@@ -584,7 +592,7 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh):
         "val",
     ],
 )
-def test_clip_gradients(clip_type, precision):
+def test_clip_gradients(distributed, clip_type, precision):
     strategy = ModelParallelStrategy(_parallelize_single_linear_tp_fsdp2)
     fabric = Fabric(accelerator="auto", devices=2, precision=precision, strategy=strategy)
     fabric.launch()
@@ -626,7 +634,7 @@ def test_clip_gradients(clip_type, precision):
 
 
 @RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
-def test_save_sharded_and_consolidate_and_load(tmp_path):
+def test_save_sharded_and_consolidate_and_load(distributed, tmp_path):
     """Test the consolidation of a distributed (DTensor) checkpoint into a single file."""
     strategy = ModelParallelStrategy(
         _parallelize_feed_forward_fsdp2_tp,
@@ -683,7 +691,7 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):
 
 
 @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
-def test_load_raw_module_state():
+def test_load_raw_module_state(distributed):
     from torch.distributed.device_mesh import init_device_mesh
     from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
 

tests/tests_pytorch/strategies/test_model_parallel_integration.py

Lines changed: 16 additions & 9 deletions
@@ -86,6 +86,13 @@ def fn(model, device_mesh):
     return fn
 
 
+@pytest.fixture
+def distributed():
+    yield
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()
+
+
 class TemplateModel(LightningModule):
     def __init__(self, compile=False):
         super().__init__()
@@ -130,7 +137,7 @@ def configure_model(self):
 
 
 @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
-def test_setup_device_mesh():
+def test_setup_device_mesh(distributed):
     from torch.distributed.device_mesh import DeviceMesh
 
     for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)):
@@ -191,7 +198,7 @@ def configure_model(self):
     "compile",
     [True, False],
 )
-def test_tensor_parallel(compile):
+def test_tensor_parallel(distributed, compile):
     from torch.distributed._tensor import DTensor
 
     class Model(TensorParallelModel):
@@ -236,7 +243,7 @@ def training_step(self, batch):
     "compile",
     [True, False],
 )
-def test_fsdp2_tensor_parallel(compile):
+def test_fsdp2_tensor_parallel(distributed, compile):
     from torch.distributed._tensor import DTensor
 
     class Model(FSDP2TensorParallelModel):
@@ -293,7 +300,7 @@ def training_step(self, batch):
 
 
 @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
-def test_modules_without_parameters(tmp_path):
+def test_modules_without_parameters(distributed, tmp_path):
     """Test that TorchMetrics get moved to the device despite not having any parameters."""
 
     class MetricsModel(TensorParallelModel):
@@ -336,7 +343,7 @@ def training_step(self, batch):
     "compile",
     [True, False],
 )
-def test_module_init_context(compile, precision, expected_dtype, tmp_path):
+def test_module_init_context(distributed, compile, precision, expected_dtype, tmp_path):
     """Test that the module under the init-context gets moved to the right device and dtype."""
 
     class Model(FSDP2Model):
@@ -375,7 +382,7 @@ def _run_setup_assertions(empty_init, expected_device):
 
 @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("save_distributed_checkpoint", [True, False])
-def test_strategy_state_dict(tmp_path, save_distributed_checkpoint):
+def test_strategy_state_dict(distributed, tmp_path, save_distributed_checkpoint):
     """Test that the strategy returns the correct state dict of the LightningModule."""
     model = FSDP2Model()
     correct_state_dict = model.state_dict()  # State dict before wrapping
@@ -408,7 +415,7 @@ def test_strategy_state_dict(tmp_path, save_distributed_checkpoint):
 
 
 @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
-def test_load_full_state_checkpoint_into_regular_model(tmp_path):
+def test_load_full_state_checkpoint_into_regular_model(distributed, tmp_path):
     """Test that a full-state checkpoint saved from a distributed model can be loaded back into a regular model."""
 
     # Save a regular full-state checkpoint from a distributed model
@@ -450,7 +457,7 @@ def test_load_full_state_checkpoint_into_regular_model(tmp_path):
 
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
-def test_load_standard_checkpoint_into_distributed_model(tmp_path):
+def test_load_standard_checkpoint_into_distributed_model(distributed, tmp_path):
     """Test that a regular checkpoint (weights and optimizer states) can be loaded into a distributed model."""
 
     # Save a regular DDP checkpoint
@@ -491,7 +498,7 @@ def test_load_standard_checkpoint_into_distributed_model(tmp_path):
 
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
-def test_save_load_sharded_state_dict(tmp_path):
+def test_save_load_sharded_state_dict(distributed, tmp_path):
     """Test saving and loading with the distributed state dict format."""
 
     class CheckpointModel(FSDP2Model):
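The recurring change across both test modules is the new distributed fixture: each integration test that initializes a default process group now requests it, so the group is destroyed during teardown instead of leaking into the next test. A self-contained sketch of the mechanism (the test function, backend choice, and port below are hypothetical, for illustration only):

import pytest
import torch
import torch.distributed


@pytest.fixture
def distributed():
    # Teardown-only fixture: run the test first, then destroy whatever
    # default process group it may have created.
    yield
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


def test_uses_process_group(distributed):
    # Single-process "gloo" group, purely to demonstrate the cleanup behavior.
    torch.distributed.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", world_size=1, rank=0
    )
    assert torch.distributed.is_initialized()
    # No explicit destroy_process_group() here; the fixture handles it on teardown.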

0 commit comments
