Commit 825f605: add testing
Parent: 63ce3c7

2 files changed: +97, -0 lines

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 50 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import logging
 import math
 import os
@@ -750,3 +751,52 @@ def __init__(self):
     assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), (
         "Gradients should differ significantly in exponential mode when using proper spacing"
     )
+
+
+def test_lr_finder_checkpoint_cleanup_on_error(tmp_path):
+    """Test that temporary checkpoint files are cleaned up even when an error occurs during lr finding."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+            self.learning_rate = 1e-3
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step >= self.fail_on_step:
+                raise RuntimeError("Intentional failure for testing cleanup")
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    model = FailingModel()
+    lr_finder = LearningRateFinder(num_training_steps=5)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[lr_finder],
+    )
+
+    # Check no lr_find checkpoint files exist initially
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, "No lr_find checkpoint files should exist initially"
+
+    # Run lr finder and expect it to fail
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Check that no lr_find checkpoint files are left behind
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, (
+        f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}"
+    )
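
Note on the behavior under test: the learning-rate finder saves a hidden temporary checkpoint before it starts varying the learning rate, so it can restore the original state afterwards, and the test asserts that this file is deleted even when trainer.fit raises mid-search. The following is a minimal sketch of that contract only, not the actual Lightning internals; run_with_temp_checkpoint, search_fn, and the prefix argument are illustrative names.

    import os
    import uuid


    def run_with_temp_checkpoint(trainer, search_fn, prefix=".lr_find_"):
        # Snapshot the current state before the search mutates it.
        ckpt_path = os.path.join(trainer.default_root_dir, f"{prefix}{uuid.uuid4()}.ckpt")
        trainer.save_checkpoint(ckpt_path)
        try:
            # The search itself may raise, e.g. a RuntimeError from training_step.
            return search_fn()
        finally:
            # Always remove the temporary checkpoint, whether or not an error occurred.
            if os.path.exists(ckpt_path):
                os.remove(ckpt_path)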

tests/tests_pytorch/tuner/test_scale_batch_size.py

Lines changed: 47 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import logging
 import os
 from copy import deepcopy
@@ -486,3 +487,49 @@ def test_batch_size_finder_callback_val_batches(tmp_path):

     assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
     assert trainer.num_val_batches[0] != steps_per_trial
+
+
+def test_scale_batch_size_checkpoint_cleanup_on_error(tmp_path):
+    """Test that temporary checkpoint files are cleaned up even when an error occurs during batch size scaling."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+            self.batch_size = 2
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step >= self.fail_on_step:
+                raise RuntimeError("Intentional failure for testing cleanup")
+            return super().training_step(batch, batch_idx)
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)
+
+    model = FailingModel()
+    batch_size_finder = BatchSizeFinder(max_trials=3, steps_per_trial=2)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[batch_size_finder],
+    )
+
+    # Check no scale_batch_size checkpoint files exist initially
+    scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt"))
+    assert len(scale_checkpoints) == 0, "No scale_batch_size checkpoint files should exist initially"
+
+    # Run batch size scaler and expect it to fail
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Check that no scale_batch_size checkpoint files are left behind
+    scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt"))
+    assert len(scale_checkpoints) == 0, (
+        f"scale_batch_size checkpoint files should be cleaned up, but found: {scale_checkpoints}"
+    )
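
Both tests repeat the same glob-and-assert pattern, differing only in the filename prefix (".lr_find_*.ckpt" vs ".scale_batch_size_*.ckpt"). If that duplication were ever worth removing, a small helper along these lines would cover both; assert_no_leftover_checkpoints is a hypothetical name, not part of the existing test suite.

    import glob
    import os


    def assert_no_leftover_checkpoints(root_dir, pattern):
        # Fail if any temporary tuner checkpoint matching `pattern` survives.
        leftovers = glob.glob(os.path.join(root_dir, pattern))
        assert not leftovers, f"temporary checkpoints should be cleaned up, found: {leftovers}"


    # Usage inside either test:
    #     assert_no_leftover_checkpoints(tmp_path, ".lr_find_*.ckpt")
    #     assert_no_leftover_checkpoints(tmp_path, ".scale_batch_size_*.ckpt")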
