Commit 28e14f4

cleanup log after test
1 parent da5ab53 commit 28e14f4

File tree: 1 file changed (+24, −15 lines changed)


tests/tests_pytorch/callbacks/test_model_checkpoint_manual_opt.py

Lines changed: 24 additions & 15 deletions
@@ -1,10 +1,10 @@
 import shutil
 import tempfile
 import warnings
+from contextlib import contextmanager
 from copy import deepcopy
 from pathlib import Path

-import pytest
 import torch
 from torch.utils.data import DataLoader, Dataset

@@ -63,17 +63,20 @@ def configure_optimizers(self):
         return torch.optim.SGD(self.parameters(), lr=0.01)


-@pytest.fixture
-def auto_cleanup_lightning_logs():
-    """Fixture to clean up lightning_logs directory after each test."""
+@contextmanager
+def cleanup_after_test():
+    """Context manager to ensure all test artifacts are cleaned up."""
     log_dir = Path("tests_pytorch/lightning_logs")
-    yield
-    if log_dir.exists():
-        shutil.rmtree(log_dir, ignore_errors=True)
+    try:
+        yield
+    finally:
+        # Clean up any remaining log files
+        if log_dir.exists():
+            shutil.rmtree(log_dir, ignore_errors=True)


-def test_model_checkpoint_manual_opt(auto_cleanup_lightning_logs):
-    with tempfile.TemporaryDirectory() as tmpdir:
+def test_model_checkpoint_manual_opt():
+    with cleanup_after_test(), tempfile.TemporaryDirectory() as tmpdir:
         dataset = FakeDataset()
         train_dataloader = DataLoader(dataset, batch_size=1)
         model = SimpleModule()
@@ -95,9 +98,12 @@ def test_model_checkpoint_manual_opt(auto_cleanup_lightning_logs):
             ],
             log_every_n_steps=1,
             num_sanity_val_steps=0,
+            logger=False,  # Disable logging to prevent creating lightning_logs
         )
-        trainer.fit(model, train_dataloader)
-        trainer._teardown()  # Ensure trainer is properly closed
+        try:
+            trainer.fit(model, train_dataloader)
+        finally:
+            trainer._teardown()  # Ensure trainer is properly closed

         # The best loss is at batch_idx=2 (loss=0.0)
         best_step = 2
@@ -113,7 +119,7 @@ def test_model_checkpoint_manual_opt(auto_cleanup_lightning_logs):
         )


-def test_model_checkpoint_manual_opt_warning(auto_cleanup_lightning_logs):
+def test_model_checkpoint_manual_opt_warning():
     """Test that a warning is raised when using manual optimization without saving the state."""

     class SimpleModuleNoSave(SimpleModule):
@@ -129,7 +135,7 @@ def training_step(self, batch, batch_idx):
             optimizer.step()
         return loss

-    with tempfile.TemporaryDirectory() as tmpdir:
+    with cleanup_after_test(), tempfile.TemporaryDirectory() as tmpdir:
         dataset = FakeDataset()
         train_dataloader = DataLoader(dataset, batch_size=1, num_workers=0)  # Avoid num_workers warning
         model = SimpleModuleNoSave()
@@ -157,9 +163,12 @@ def training_step(self, batch, batch_idx):
             ],
             log_every_n_steps=1,
             num_sanity_val_steps=0,
+            logger=False,  # Disable logging to prevent creating lightning_logs
        )
-        trainer.fit(model, train_dataloader)
-        trainer._teardown()  # Ensure trainer is properly closed
+        try:
+            trainer.fit(model, train_dataloader)
+        finally:
+            trainer._teardown()

         # Find our warning in the list of warnings
         manual_opt_warnings = [

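For reference, the cleanup helper introduced here reads as follows once the hunks above are applied. This is a minimal reassembly from the diff alone: only lines visible in the hunks are shown, the exact indentation of unchanged context is assumed, and the `...` stands in for the elided test body.

import shutil
import tempfile
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def cleanup_after_test():
    """Context manager to ensure all test artifacts are cleaned up."""
    log_dir = Path("tests_pytorch/lightning_logs")
    try:
        # Run the body of the `with` block (i.e. the test).
        yield
    finally:
        # Clean up any remaining log files
        if log_dir.exists():
            shutil.rmtree(log_dir, ignore_errors=True)


def test_model_checkpoint_manual_opt():
    # Each test enters the cleanup scope alongside its temporary directory.
    with cleanup_after_test(), tempfile.TemporaryDirectory() as tmpdir:
        ...  # test body as in the diff above

The try/finally form means the removal of `lightning_logs` runs even if the body of the `with` block raises, and the explicit context manager replaces the removed `auto_cleanup_lightning_logs` fixture argument in the test signatures.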