|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +from unittest.mock import patch |
| 15 | + |
14 | 16 | import pytest
|
15 | 17 |
|
16 | 18 | from pytorch_lightning.loops import TrainingEpochLoop
|
| 19 | +from pytorch_lightning.trainer.trainer import Trainer |
| 20 | +from tests.helpers.boring_model import BoringModel |
17 | 21 |
|
18 | 22 | _out00 = {"loss": 0.0}
|
19 | 23 | _out01 = {"loss": 0.1}
|
@@ -141,3 +145,28 @@ def test_prepare_outputs_training_batch_end_manual(batch_end_outputs, expected):
|
141 | 145 | num_optimizers=-1, # does not matter for manual optimization
|
142 | 146 | )
|
143 | 147 | assert prepared == expected
|
| 148 | + |
| 149 | + |
def test_no_val_on_train_epoch_loop_restart(tmpdir):
    """Test that training validation loop doesn't get triggered at the beginning of a restart.

    Fit for one epoch, save a checkpoint, then resume from it with ``max_epochs=2``
    while spying on the validation loop's ``advance``: validation must run exactly
    once (for the newly-trained epoch), not again for the restored epoch.
    """
    common_kwargs = dict(
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        enable_checkpointing=False,
    )
    model = BoringModel()
    trainer = Trainer(**common_kwargs)
    trainer.fit(model)

    checkpoint_path = str(tmpdir / "last.ckpt")
    trainer.save_checkpoint(checkpoint_path)

    # Restart for one additional epoch from the saved state.
    common_kwargs["max_epochs"] = 2
    trainer = Trainer(**common_kwargs)

    val_loop = trainer.fit_loop.epoch_loop.val_loop
    with patch.object(val_loop, "advance", wraps=val_loop.advance) as advance_spy:
        trainer.fit(model, ckpt_path=checkpoint_path)
        assert advance_spy.call_count == 1
0 commit comments