Skip to content

Commit 8524d03

Browse files
rohitgr7 authored and carmocca committed
Fix val_loop run on restart (#11552)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 0f99fcf commit 8524d03

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99
### Fixed
1010

1111
- Fixed the format of the configuration saved automatically by the CLI's `SaveConfigCallback` ([#11532](https://github.com/PyTorchLightning/pytorch-lightning/pull/11532))
12-
13-
14-
-
12+
- Fixed an issue where the validation loop would run on restart ([#11552](https://github.com/PyTorchLightning/pytorch-lightning/pull/11552))
1513

1614

1715
## [1.5.9] - 2022-01-18

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,11 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
529529

530530
# TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
531531
is_val_check_batch = is_last_batch
532+
533+
# when restarting without fault-tolerant training, batch_progress.current.ready is -1
534+
if batch_idx == -1:
535+
return False
536+
532537
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
533538
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
534539
elif self.trainer.val_check_batch != float("inf"):

tests/loops/epoch/test_training_epoch_loop.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from unittest.mock import patch
15+
1416
import pytest
1517

1618
from pytorch_lightning.loops import TrainingEpochLoop
19+
from pytorch_lightning.trainer.trainer import Trainer
20+
from tests.helpers.boring_model import BoringModel
1721

1822
_out00 = {"loss": 0.0}
1923
_out01 = {"loss": 0.1}
@@ -141,3 +145,28 @@ def test_prepare_outputs_training_batch_end_manual(batch_end_outputs, expected):
141145
num_optimizers=-1, # does not matter for manual optimization
142146
)
143147
assert prepared == expected
148+
149+
150+
def test_no_val_on_train_epoch_loop_restart(tmpdir):
151+
"""Test that training validation loop doesn't get triggered at the beginning of a restart."""
152+
trainer_kwargs = {
153+
"max_epochs": 1,
154+
"limit_train_batches": 1,
155+
"limit_val_batches": 1,
156+
"num_sanity_val_steps": 0,
157+
"enable_checkpointing": False,
158+
}
159+
trainer = Trainer(**trainer_kwargs)
160+
model = BoringModel()
161+
trainer.fit(model)
162+
ckpt_path = str(tmpdir / "last.ckpt")
163+
trainer.save_checkpoint(ckpt_path)
164+
165+
trainer_kwargs["max_epochs"] = 2
166+
trainer = Trainer(**trainer_kwargs)
167+
168+
with patch.object(
169+
trainer.fit_loop.epoch_loop.val_loop, "advance", wraps=trainer.fit_loop.epoch_loop.val_loop.advance
170+
) as advance_mocked:
171+
trainer.fit(model, ckpt_path=ckpt_path)
172+
assert advance_mocked.call_count == 1

0 commit comments

Comments
 (0)