Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/lightning/pytorch/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,17 +508,14 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
if not self._log_model_checkpoints:
return

from neptune.types import File

file_names = set()
checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")

# save last model
if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path:
model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
file_names.add(model_last_name)
with open(checkpoint_callback.last_model_path, "rb") as fp:
self.run[f"{checkpoints_namespace}/{model_last_name}"] = File.from_stream(fp)
self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)

# save best k models
if hasattr(checkpoint_callback, "best_k_models"):
Expand All @@ -533,8 +530,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:

model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
file_names.add(model_name)
with open(checkpoint_callback.best_model_path, "rb") as fp:
self.run[f"{checkpoints_namespace}/{model_name}"] = File.from_stream(fp)
self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)

# remove old models logged to experiment if they are not part of best k models at this point
if self.run.exists(checkpoints_namespace):
Expand Down
9 changes: 4 additions & 5 deletions tests/tests_pytorch/loggers/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,10 @@ def test_after_save_checkpoint(neptune_mock):
mock_file.side_effect = mock.Mock()
logger.after_save_checkpoint(cb_mock)

assert run_instance_mock.__setitem__.call_count == 3
assert run_instance_mock.__getitem__.call_count == 2
assert run_attr_mock.upload.call_count == 2

assert mock_file.from_stream.call_count == 2
assert run_instance_mock.__setitem__.call_count == 1 # best_model_path
assert run_instance_mock.__getitem__.call_count == 4 # last_model_path, best_k_models, best_model_path
assert run_attr_mock.upload.call_count == 4 # last_model_path, best_k_models, best_model_path
assert mock_file.from_stream.call_count == 0

run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1")
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes")
Expand Down
Loading