11import shutil
22import tempfile
33import warnings
4+ from contextlib import contextmanager
45from copy import deepcopy
56from pathlib import Path
67
7- import pytest
88import torch
99from torch .utils .data import DataLoader , Dataset
1010
@@ -63,17 +63,20 @@ def configure_optimizers(self):
6363 return torch .optim .SGD (self .parameters (), lr = 0.01 )
6464
6565
66- @pytest . fixture
67- def auto_cleanup_lightning_logs ():
68- """Fixture to clean up lightning_logs directory after each test ."""
66+ @contextmanager
67+ def cleanup_after_test ():
68+ """Context manager to ensure all test artifacts are cleaned up ."""
6969 log_dir = Path ("tests_pytorch/lightning_logs" )
70- yield
71- if log_dir .exists ():
72- shutil .rmtree (log_dir , ignore_errors = True )
70+ try :
71+ yield
72+ finally :
73+ # Clean up any remaining log files
74+ if log_dir .exists ():
75+ shutil .rmtree (log_dir , ignore_errors = True )
7376
7477
75- def test_model_checkpoint_manual_opt (auto_cleanup_lightning_logs ):
76- with tempfile .TemporaryDirectory () as tmpdir :
78+ def test_model_checkpoint_manual_opt ():
79+ with cleanup_after_test (), tempfile .TemporaryDirectory () as tmpdir :
7780 dataset = FakeDataset ()
7881 train_dataloader = DataLoader (dataset , batch_size = 1 )
7982 model = SimpleModule ()
@@ -95,9 +98,12 @@ def test_model_checkpoint_manual_opt(auto_cleanup_lightning_logs):
9598 ],
9699 log_every_n_steps = 1 ,
97100 num_sanity_val_steps = 0 ,
101+ logger = False , # Disable logging to prevent creating lightning_logs
98102 )
99- trainer .fit (model , train_dataloader )
100- trainer ._teardown () # Ensure trainer is properly closed
103+ try :
104+ trainer .fit (model , train_dataloader )
105+ finally :
106+ trainer ._teardown () # Ensure trainer is properly closed
101107
102108 # The best loss is at batch_idx=2 (loss=0.0)
103109 best_step = 2
@@ -113,7 +119,7 @@ def test_model_checkpoint_manual_opt(auto_cleanup_lightning_logs):
113119 )
114120
115121
116- def test_model_checkpoint_manual_opt_warning (auto_cleanup_lightning_logs ):
122+ def test_model_checkpoint_manual_opt_warning ():
117123 """Test that a warning is raised when using manual optimization without saving the state."""
118124
119125 class SimpleModuleNoSave (SimpleModule ):
@@ -129,7 +135,7 @@ def training_step(self, batch, batch_idx):
129135 optimizer .step ()
130136 return loss
131137
132- with tempfile .TemporaryDirectory () as tmpdir :
138+ with cleanup_after_test (), tempfile .TemporaryDirectory () as tmpdir :
133139 dataset = FakeDataset ()
134140 train_dataloader = DataLoader (dataset , batch_size = 1 , num_workers = 0 ) # Avoid num_workers warning
135141 model = SimpleModuleNoSave ()
@@ -157,9 +163,12 @@ def training_step(self, batch, batch_idx):
157163 ],
158164 log_every_n_steps = 1 ,
159165 num_sanity_val_steps = 0 ,
166+ logger = False , # Disable logging to prevent creating lightning_logs
160167 )
161- trainer .fit (model , train_dataloader )
162- trainer ._teardown () # Ensure trainer is properly closed
168+ try :
169+ trainer .fit (model , train_dataloader )
170+ finally :
171+ trainer ._teardown ()
163172
164173 # Find our warning in the list of warnings
165174 manual_opt_warnings = [
0 commit comments