@@ -1,10 +1,10 @@
 import shutil
 import tempfile
 import warnings
+from contextlib import contextmanager
 from copy import deepcopy
 from pathlib import Path

-import pytest
 import torch
 from torch.utils.data import DataLoader, Dataset

@@ -63,17 +63,20 @@ def configure_optimizers(self):
         return torch.optim.SGD(self.parameters(), lr=0.01)


-@pytest.fixture
-def auto_cleanup_lightning_logs():
-    """Fixture to clean up lightning_logs directory after each test."""
+@contextmanager
+def cleanup_after_test():
+    """Context manager to ensure all test artifacts are cleaned up."""
     log_dir = Path("tests_pytorch/lightning_logs")
-    yield
-    if log_dir.exists():
-        shutil.rmtree(log_dir, ignore_errors=True)
+    try:
+        yield
+    finally:
+        # Clean up any remaining log files
+        if log_dir.exists():
+            shutil.rmtree(log_dir, ignore_errors=True)


-def test_model_checkpoint_manual_opt(auto_cleanup_lightning_logs):
-    with tempfile.TemporaryDirectory() as tmpdir:
+def test_model_checkpoint_manual_opt():
+    with cleanup_after_test(), tempfile.TemporaryDirectory() as tmpdir:
         dataset = FakeDataset()
         train_dataloader = DataLoader(dataset, batch_size=1)
         model = SimpleModule()
@@ -95,9 +98,12 @@ def test_model_checkpoint_manual_opt(auto_cleanup_lightning_logs):
             ],
             log_every_n_steps=1,
             num_sanity_val_steps=0,
+            logger=False,  # Disable logging to prevent creating lightning_logs
         )
-        trainer.fit(model, train_dataloader)
-        trainer._teardown()  # Ensure trainer is properly closed
+        try:
+            trainer.fit(model, train_dataloader)
+        finally:
+            trainer._teardown()  # Ensure trainer is properly closed

         # The best loss is at batch_idx=2 (loss=0.0)
         best_step = 2
@@ -113,7 +119,7 @@ def test_model_checkpoint_manual_opt(auto_cleanup_lightning_logs):
         )


-def test_model_checkpoint_manual_opt_warning(auto_cleanup_lightning_logs):
+def test_model_checkpoint_manual_opt_warning():
     """Test that a warning is raised when using manual optimization without saving the state."""

     class SimpleModuleNoSave(SimpleModule):
@@ -129,7 +135,7 @@ def training_step(self, batch, batch_idx):
             optimizer.step()
             return loss

-    with tempfile.TemporaryDirectory() as tmpdir:
+    with cleanup_after_test(), tempfile.TemporaryDirectory() as tmpdir:
         dataset = FakeDataset()
         train_dataloader = DataLoader(dataset, batch_size=1, num_workers=0)  # Avoid num_workers warning
         model = SimpleModuleNoSave()
@@ -157,9 +163,12 @@ def training_step(self, batch, batch_idx):
             ],
             log_every_n_steps=1,
             num_sanity_val_steps=0,
+            logger=False,  # Disable logging to prevent creating lightning_logs
         )
-        trainer.fit(model, train_dataloader)
-        trainer._teardown()  # Ensure trainer is properly closed
+        try:
+            trainer.fit(model, train_dataloader)
+        finally:
+            trainer._teardown()

         # Find our warning in the list of warnings
         manual_opt_warnings = [