
Commit 93c3e69

Create test_ddp_sigterm_handling.py
1 parent: 2761ad8

File tree

1 file changed

Lines changed: 69 additions & 0 deletions
import os
import signal
import time

import pytest
import torch
import torch.distributed
import torch.multiprocessing as mp

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.utilities.exceptions import SIGTERMException


class DummyModel(BoringModel):
    def training_step(self, batch, batch_idx):
        # Simulate an external SIGTERM arriving at rank 0 on batch 2.
        if self.trainer.global_rank == 0 and batch_idx == 2:
            time.sleep(3)  # let the other ranks proceed to the next batch first
            os.kill(os.getpid(), signal.SIGTERM)
        return super().training_step(batch, batch_idx)


def run_ddp_sigterm(rank, world_size, tmpdir):
    # Minimal environment for the torch.distributed rendezvous.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    seed_everything(42)

    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    model = DummyModel()
    datamodule = BoringDataModule()

    trainer = Trainer(
        accelerator="cuda" if torch.cuda.is_available() else "cpu",
        strategy=DDPStrategy(find_unused_parameters=False),
        devices=world_size,
        num_nodes=1,
        max_epochs=3,
        default_root_dir=tmpdir,
        enable_checkpointing=False,
        enable_progress_bar=False,
        enable_model_summary=False,
        logger=False,
    )

    try:
        trainer.fit(model, datamodule=datamodule)
    except SIGTERMException:
        # Expected: Lightning surfaces the SIGTERM as a SIGTERMException.
        print(f"[Rank {rank}] Caught SIGTERMException successfully.")
    except Exception as e:
        pytest.fail(f"[Rank {rank}] Unexpected exception: {e}")


@pytest.mark.skipif(
    not torch.distributed.is_available(),
    reason="Requires torch.distributed",
)
@pytest.mark.skipif(
    torch.cuda.is_available() and torch.cuda.device_count() < 2,
    reason="Requires >=2 CUDA devices when running on GPU",
)
def test_ddp_sigterm_handling(tmp_path):
    world_size = 2
    mp.spawn(run_ddp_sigterm, args=(world_size, tmp_path), nprocs=world_size, join=True)
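For context: the behavior under test is the common pattern of recording a POSIX signal in a handler and re-raising it as an exception at a safe point in the loop. The sketch below illustrates that pattern in isolation. It is a hypothetical stand-in, not Lightning's actual signal connector; the SigtermFlag class and the toy loop exist only for illustration.

import os
import signal


class SigtermFlag:
    """Hypothetical sketch of the flag-then-raise pattern (not Lightning's code)."""

    def __init__(self):
        self.received = False

    def install(self):
        # The handler runs asynchronously, so it only records the signal;
        # nothing is raised from inside the handler itself.
        signal.signal(signal.SIGTERM, lambda signum, frame: setattr(self, "received", True))

    def raise_if_received(self):
        # Called at a safe boundary (e.g. between batches) to surface the signal.
        if self.received:
            raise RuntimeError("SIGTERM received")  # Lightning raises SIGTERMException here


flag = SigtermFlag()
flag.install()
try:
    for batch_idx in range(5):
        if batch_idx == 2:
            os.kill(os.getpid(), signal.SIGTERM)  # simulate an external kill
        flag.raise_if_received()  # raises once the handler has run
except RuntimeError as exc:
    print(f"stopped cleanly at a batch boundary: {exc}")

To run the committed test directly: pytest test_ddp_sigterm_handling.py. The CUDA skip applies only when GPUs are visible; on a CPU-only machine the test spawns two CPU processes.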
