-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Add best_k_metrics
parameter to the ModelCheckpoint
#20457
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
8df37c2
32db7d4
0b3322d
74cf6a7
a214154
835538c
80c4bb6
821bfef
247809f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,7 +32,6 @@ | |
from torch import optim | ||
from torch.utils.data.dataloader import DataLoader | ||
|
||
import lightning.pytorch as pl | ||
from lightning.fabric.utilities.cloud_io import _load as pl_load | ||
from lightning.pytorch import Trainer, seed_everything | ||
from lightning.pytorch.callbacks import ModelCheckpoint | ||
|
@@ -703,6 +702,7 @@ def test_model_checkpoint_save_last_none_monitor(tmp_path, caplog): | |
assert checkpoint_callback.best_model_path == str(tmp_path / "epoch=1-step=20.ckpt") | ||
assert checkpoint_callback.last_model_path == str(tmp_path / "last.ckpt") | ||
assert checkpoint_callback.best_model_score is None | ||
assert checkpoint_callback.best_model_metrics is None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we need to add tests that exercise the new code There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @gonzachiar, you currently only have a test for the scenario in which the newly added attribute is empty; we also need a test that validates your addition works as expected |
||
assert checkpoint_callback.best_k_models == {} | ||
assert checkpoint_callback.kth_best_model_path == "" | ||
|
||
|
@@ -809,6 +809,7 @@ def test_model_checkpoint_topk_zero(tmp_path): | |
assert checkpoint_callback.monitor is None | ||
assert checkpoint_callback.best_model_path == "" | ||
assert checkpoint_callback.best_model_score is None | ||
assert checkpoint_callback.best_model_metrics is None | ||
assert checkpoint_callback.best_k_models == {} | ||
assert checkpoint_callback.kth_best_model_path == "" | ||
# check that only the last ckpt was created | ||
|
@@ -1074,7 +1075,7 @@ def assert_checkpoint_log_dir(idx): | |
|
||
# load from checkpoint | ||
trainer_config["logger"] = TensorBoardLogger(tmp_path) | ||
trainer = pl.Trainer(**trainer_config) | ||
trainer = Trainer(**trainer_config) | ||
assert_trainer_init(trainer) | ||
|
||
model = ExtendedBoringModel() | ||
|
Uh oh!
There was an error while loading. Please reload this page.