From 096db8479af693c17a641783a2354a684c88d9c4 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 10 Feb 2025 11:37:40 +0200 Subject: [PATCH] Fix file extension of uploaded model checkpoints --- src/lightning/pytorch/loggers/neptune.py | 8 ++------ tests/tests_pytorch/loggers/test_neptune.py | 9 ++++----- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index a363f589b29b4..bf9669c824784 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -508,8 +508,6 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: if not self._log_model_checkpoints: return - from neptune.types import File - file_names = set() checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints") @@ -517,8 +515,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path: model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback) file_names.add(model_last_name) - with open(checkpoint_callback.last_model_path, "rb") as fp: - self.run[f"{checkpoints_namespace}/{model_last_name}"] = File.from_stream(fp) + self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path) # save best k models if hasattr(checkpoint_callback, "best_k_models"): @@ -533,8 +530,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback) file_names.add(model_name) - with open(checkpoint_callback.best_model_path, "rb") as fp: - self.run[f"{checkpoints_namespace}/{model_name}"] = File.from_stream(fp) + self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path) # remove old models logged to experiment if they are not part of best k models at this point if self.run.exists(checkpoints_namespace): diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py index b5e98fbe99113..6dc3816fac858 100644 --- a/tests/tests_pytorch/loggers/test_neptune.py +++ b/tests/tests_pytorch/loggers/test_neptune.py @@ -256,11 +256,10 @@ def test_after_save_checkpoint(neptune_mock): mock_file.side_effect = mock.Mock() logger.after_save_checkpoint(cb_mock) - assert run_instance_mock.__setitem__.call_count == 3 - assert run_instance_mock.__getitem__.call_count == 2 - assert run_attr_mock.upload.call_count == 2 - - assert mock_file.from_stream.call_count == 2 + assert run_instance_mock.__setitem__.call_count == 1 # best_model_path + assert run_instance_mock.__getitem__.call_count == 4 # last_model_path, best_k_models, best_model_path + assert run_attr_mock.upload.call_count == 4 # last_model_path, best_k_models, best_model_path + assert mock_file.from_stream.call_count == 0 run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1") run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes")