@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import logging
 import math
 import os
@@ -750,3 +751,52 @@ def __init__(self):
     assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), (
         "Gradients should differ significantly in exponential mode when using proper spacing"
     )
+
+
+def test_lr_finder_checkpoint_cleanup_on_error(tmp_path):
+    """Test that temporary checkpoint files are cleaned up even when an error occurs during lr finding."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+            self.learning_rate = 1e-3
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step >= self.fail_on_step:
+                raise RuntimeError("Intentional failure for testing cleanup")
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    model = FailingModel()
+    lr_finder = LearningRateFinder(num_training_steps=5)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[lr_finder],
+    )
+
+    # Check no lr_find checkpoint files exist initially
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, "No lr_find checkpoint files should exist initially"
+
+    # Run lr finder and expect it to fail
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Check that no lr_find checkpoint files are left behind
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, (
+        f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}"
+    )
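
For reference, the behaviour this test exercises is the usual save/restore pattern for tuner-style runs: snapshot the model to a temporary `.lr_find_*.ckpt` file before probing learning rates, then restore and delete it in a `finally` block so the file disappears even when `training_step` raises. The sketch below only illustrates that pattern and is not Lightning's actual implementation; `search_with_checkpoint_cleanup`, `save_fn`, `restore_fn`, and `search_fn` are hypothetical names.

```python
import os
import uuid


def search_with_checkpoint_cleanup(root_dir, save_fn, restore_fn, search_fn):
    """Illustrative cleanup pattern: the temporary checkpoint is removed in
    ``finally`` so it never survives an exception raised during the search."""
    ckpt_path = os.path.join(root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
    save_fn(ckpt_path)  # snapshot model/optimizer state before probing learning rates
    try:
        return search_fn()  # may raise, e.g. from a failing training_step
    finally:
        restore_fn(ckpt_path)  # put the original state back
        if os.path.exists(ckpt_path):
            os.remove(ckpt_path)  # cleanup happens on success *and* on error
```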