|
25 | 25 | from pytorch_lightning.loggers.base import LoggerCollection |
26 | 26 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger |
27 | 27 | from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler |
28 | | -from pytorch_lightning.profiler.pytorch import RegisterRecordFunction |
| 28 | +from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache |
29 | 29 | from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 |
30 | 30 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
31 | 31 | from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE |
@@ -523,3 +523,31 @@ def test_trainer_profiler_incorrect_str_arg(): |
523 | 523 | match=r"When passing string value for the `profiler` parameter of `Trainer`, it can only be one of.*", |
524 | 524 | ): |
525 | 525 | Trainer(profiler="unknown_profiler") |
| 526 | + |
| 527 | + |
@pytest.mark.skipif(not _KINETO_AVAILABLE, reason="Requires PyTorch Profiler Kineto")
@pytest.mark.parametrize(
    ["trainer_config", "trainer_fn"],
    [
        ({"limit_train_batches": 4, "limit_val_batches": 7}, "fit"),
        ({"limit_train_batches": 7, "limit_val_batches": 4, "num_sanity_val_steps": 0}, "fit"),
        ({"limit_train_batches": 7, "limit_val_batches": 2}, "fit"),
        ({"limit_val_batches": 4}, "validate"),
        ({"limit_test_batches": 4}, "test"),
        ({"limit_predict_batches": 4}, "predict"),
    ],
)
def test_pytorch_profiler_raises_warning_for_limited_steps(tmpdir, trainer_config, trainer_fn):
    """The PyTorch profiler warns and drops its schedule when a run has too few steps to trace."""
    boring_model = BoringModel()
    limited_trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", max_epochs=1, **trainer_config)
    # Clear the warning cache first so a previously-emitted warning cannot be deduplicated away.
    warning_cache.clear()
    with pytest.warns(UserWarning, match="not enough steps to properly record traces"):
        getattr(limited_trainer, trainer_fn)(boring_model)
    # With too few steps, the profiler is expected to fall back to running without a schedule.
    assert limited_trainer.profiler._schedule is None
    # Leave the cache clean so later tests in the module start from the same state.
    warning_cache.clear()
0 commit comments