Skip to content

Commit 96df8c4

Browse files
senarviSeppo Enarvi
andauthored
Fix the file extension of model checkpoints uploaded by NeptuneLogger (#20581)
Fix file extension of uploaded model checkpoints Co-authored-by: Seppo Enarvi <[email protected]>
1 parent 41315bc commit 96df8c4

File tree

2 files changed

+6
-11
lines changed

2 files changed

+6
-11
lines changed

src/lightning/pytorch/loggers/neptune.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -508,17 +508,14 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
508508
if not self._log_model_checkpoints:
509509
return
510510

511-
from neptune.types import File
512-
513511
file_names = set()
514512
checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")
515513

516514
# save last model
517515
if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path:
518516
model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
519517
file_names.add(model_last_name)
520-
with open(checkpoint_callback.last_model_path, "rb") as fp:
521-
self.run[f"{checkpoints_namespace}/{model_last_name}"] = File.from_stream(fp)
518+
self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
522519

523520
# save best k models
524521
if hasattr(checkpoint_callback, "best_k_models"):
@@ -533,8 +530,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
533530

534531
model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
535532
file_names.add(model_name)
536-
with open(checkpoint_callback.best_model_path, "rb") as fp:
537-
self.run[f"{checkpoints_namespace}/{model_name}"] = File.from_stream(fp)
533+
self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)
538534

539535
# remove old models logged to experiment if they are not part of best k models at this point
540536
if self.run.exists(checkpoints_namespace):

tests/tests_pytorch/loggers/test_neptune.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -256,11 +256,10 @@ def test_after_save_checkpoint(neptune_mock):
256256
mock_file.side_effect = mock.Mock()
257257
logger.after_save_checkpoint(cb_mock)
258258

259-
assert run_instance_mock.__setitem__.call_count == 3
260-
assert run_instance_mock.__getitem__.call_count == 2
261-
assert run_attr_mock.upload.call_count == 2
262-
263-
assert mock_file.from_stream.call_count == 2
259+
assert run_instance_mock.__setitem__.call_count == 1 # best_model_path
260+
assert run_instance_mock.__getitem__.call_count == 4 # last_model_path, best_k_models, best_model_path
261+
assert run_attr_mock.upload.call_count == 4 # last_model_path, best_k_models, best_model_path
262+
assert mock_file.from_stream.call_count == 0
264263

265264
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1")
266265
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes")

0 commit comments

Comments
 (0)