
Commit 2c38813

feat: Update NeptuneScaleLogger to log model checkpoint paths instead of uploading checkpoints
1 parent af773d6 commit 2c38813

2 files changed: 47 additions & 54 deletions

src/lightning/pytorch/loggers/neptune.py

Lines changed: 10 additions & 10 deletions
@@ -669,18 +669,18 @@ def any_lightning_module_function_or_hook(self):
     neptune_logger.run.log_configs(data={"your/metadata/structure": metadata})
     neptune_logger.run.add_tags(["tag1", "tag2"])

-**Log model checkpoints**
+**Log model checkpoint paths**

 If you have :class:`~lightning.pytorch.callbacks.ModelCheckpoint` configured,
-the Neptune logger automatically logs model checkpoints.
-Model weights will be uploaded to the "model/checkpoints" namespace in the Neptune run.
+the Neptune logger can log model checkpoint paths.
+Paths will be logged to the "model/checkpoints" namespace in the Neptune run.
 You can disable this option with:

 .. code-block:: python

     neptune_logger = NeptuneScaleLogger(log_model_checkpoints=False)

-Note: All model checkpoints will be uploaded. ``save_last`` and ``save_top_k`` are currently not supported.
+Note: All model checkpoint paths will be logged. ``save_last`` and ``save_top_k`` are currently not supported.

 **Pass additional parameters to the Neptune run**
@@ -743,9 +743,9 @@ def any_lightning_module_function_or_hook(self):
 run: Optional. Default is ``None``. A Neptune ``Run`` object.
     If specified, this existing run will be used for logging, instead of a new run being created.
 prefix: Optional. Default is ``"training"``. Root namespace for all metadata logging.
-log_model_checkpoints: Optional. Default is ``True``. Log model checkpoints to Neptune.
+log_model_checkpoints: Optional. Default is ``True``. Log model checkpoint paths to Neptune.
     Works only if ``ModelCheckpoint`` is passed to the ``Trainer``.
-    NOTE: All model checkpoints will be uploaded.
+    NOTE: All model checkpoint paths will be logged.
     ``save_last`` and ``save_top_k`` are currently not supported.
 neptune_run_kwargs: Additional arguments like ``creation_time``, ``log_directory``,
     ``fork_run_id``, ``fork_step``, ``*_callback``, etc. used when a run is created.
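For orientation, here is a rough usage sketch of the logger documented above (a minimal sketch, not part of this commit; the import path matches this file, while connection details such as the Neptune project and API token are omitted and would normally come from logger arguments or the usual Neptune environment variables):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.neptune import NeptuneScaleLogger

# log_model_checkpoints=True is the default: checkpoint *paths* are logged
# under "<prefix>/model/checkpoints"; set it to False to skip path logging.
neptune_logger = NeptuneScaleLogger(log_model_checkpoints=True)
checkpoint_cb = ModelCheckpoint(dirpath="checkpoints/")  # placeholder directory
trainer = Trainer(logger=neptune_logger, callbacks=[checkpoint_cb], max_epochs=1)
# trainer.fit(model)  # after each save, the callback's paths appear in the Neptune run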
@@ -1050,7 +1050,7 @@ def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) ->
 @override
 @rank_zero_only
 def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
-    """Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
+    """Automatically log checkpointed model's path. Called after model checkpoint callback saves a new checkpoint.

     Args:
         checkpoint_callback: the model checkpoint callback instance
@@ -1066,7 +1066,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
 if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path:
     model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
     file_names.add(model_last_name)
-    self.run.assign_files({
+    self.run.log_configs({
         f"{checkpoints_namespace}/{model_last_name}": checkpoint_callback.last_model_path,
     })

@@ -1075,7 +1075,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
 for key in checkpoint_callback.best_k_models:
     model_name = self._get_full_model_name(key, checkpoint_callback)
     file_names.add(model_name)
-    self.run.assign_files({
+    self.run.log_configs({
         f"{checkpoints_namespace}/{model_name}": key,
     })

@@ -1087,7 +1087,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:

 model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
 file_names.add(model_name)
-self.run.assign_files({
+self.run.log_configs({
     f"{checkpoints_namespace}/{model_name}": checkpoint_callback.best_model_path,
 })
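Taken together, the hunks above replace Neptune's file-upload call with plain metadata logging inside after_save_checkpoint. A minimal, self-contained illustration of the difference, using a MagicMock in place of the real neptune_scale Run (the namespace, file name, and path are hypothetical stand-ins; the call names mirror the code above):

from unittest.mock import MagicMock

run = MagicMock()  # stands in for the neptune_scale Run held by the logger
checkpoints_namespace = "training/model/checkpoints"  # hypothetical prefix + namespace
model_name = "epoch=0-step=100.ckpt"                  # hypothetical checkpoint name
checkpoint_path = f"/tmp/checkpoints/{model_name}"    # hypothetical path on disk

# Before this commit: the checkpoint file itself was uploaded to the run.
# run.assign_files({f"{checkpoints_namespace}/{model_name}": checkpoint_path})

# After this commit: only the filesystem path is recorded as run metadata.
run.log_configs({f"{checkpoints_namespace}/{model_name}": checkpoint_path})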

tests/tests_pytorch/loggers/test_neptune.py

Lines changed: 37 additions & 44 deletions
@@ -461,7 +461,8 @@ def test_neptune_scale_logger_finalize(neptune_scale_logger):
     """Test finalize method sets status correctly."""
     logger, mock_run = neptune_scale_logger
     logger.finalize("success")
-    assert mock_run._status == "success"
+    expected_key = logger._construct_path_with_prefix("status")
+    mock_run.log_configs.assert_any_call({expected_key: "success"})


 @pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
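Per the updated assertion, finalize("success") is now expected to record the trainer status as run metadata rather than set a private attribute. Assuming the default "training" prefix, _construct_path_with_prefix("status") should resolve to "training/status", so the expected call is roughly:

run.log_configs({"training/status": "success"})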
@@ -472,24 +473,20 @@ def test_neptune_scale_logger_invalid_run():


 @pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
-def test_neptune_scale_logger_log_model_summary(neptune_scale_logger):
+def test_neptune_scale_logger_log_model_summary(neptune_scale_logger, monkeypatch):
     from neptune_scale.types import File

     model = BoringModel()
-    test_variants = [
-        ({}, "training/model/summary"),
-        ({"prefix": "custom_prefix"}, "custom_prefix/model/summary"),
-        ({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/model/summary"),
-    ]
-
-    for prefix, model_summary_key in test_variants:
-        logger, run_instance_mock, _ = _get_logger_with_mocks(api_key="test", project="project", **prefix)
-
-        logger.log_model_summary(model)
-
-        assert run_instance_mock.__setitem__.call_count == 1
-        assert run_instance_mock.__getitem__.call_count == 0
-        run_instance_mock.__setitem__.assert_called_once_with(model_summary_key, File)
+    logger, mock_run = neptune_scale_logger
+    # Patch assign_files to track calls
+    assign_files_mock = mock.MagicMock()
+    monkeypatch.setattr(mock_run, "assign_files", assign_files_mock)
+    logger.log_model_summary(model)
+    # Check that assign_files was called with the correct key and a File instance
+    called_args = assign_files_mock.call_args[0][0]
+    assert list(called_args.keys())[0].endswith("model/summary")
+    file_val = list(called_args.values())[0]
+    assert isinstance(file_val, File)


 @pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
@@ -511,32 +508,28 @@ def test_neptune_scale_logger_with_prefix(neptune_scale_logger):

 @pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
 def test_neptune_scale_logger_after_save_checkpoint(neptune_scale_logger):
-    test_variants = [
-        ({}, "training/model"),
-        ({"prefix": "custom_prefix"}, "custom_prefix/model"),
-        ({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/model"),
+    logger, mock_run = neptune_scale_logger
+    models_root_dir = os.path.join("path", "to", "models")
+    cb_mock = MagicMock(
+        dirpath=models_root_dir,
+        last_model_path=os.path.join(models_root_dir, "last"),
+        best_k_models={
+            f"{os.path.join(models_root_dir, 'model1')}": None,
+            f"{os.path.join(models_root_dir, 'model2/with/slashes')}": None,
+        },
+        best_model_path=os.path.join(models_root_dir, "best_model"),
+        best_model_score=None,
+    )
+    logger.after_save_checkpoint(cb_mock)
+    prefix = logger._prefix
+    model_key_prefix = f"{prefix}/model" if prefix else "model"
+    expected_calls = [
+        call.log_configs({f"{model_key_prefix}/checkpoints/model1": os.path.join(models_root_dir, "model1")}),
+        call.log_configs({
+            f"{model_key_prefix}/checkpoints/model2/with/slashes": os.path.join(models_root_dir, "model2/with/slashes")
+        }),
+        call.log_configs({f"{model_key_prefix}/checkpoints/last": os.path.join(models_root_dir, "last")}),
+        call.log_configs({f"{model_key_prefix}/checkpoints/best_model": os.path.join(models_root_dir, "best_model")}),
+        call.log_configs({f"{model_key_prefix}/best_model_path": os.path.join(models_root_dir, "best_model")}),
     ]
-
-    for prefix, model_key_prefix in test_variants:
-        logger, run_instance_mock, run_attr_mock = _get_logger_with_mocks(api_key="test", project="project", **prefix)
-        models_root_dir = os.path.join("path", "to", "models")
-        cb_mock = MagicMock(
-            dirpath=models_root_dir,
-            last_model_path=os.path.join(models_root_dir, "last"),
-            best_k_models={
-                f"{os.path.join(models_root_dir, 'model1')}": None,
-                f"{os.path.join(models_root_dir, 'model2/with/slashes')}": None,
-            },
-            best_model_path=os.path.join(models_root_dir, "best_model"),
-            best_model_score=None,
-        )
-
-        logger.after_save_checkpoint(cb_mock)
-
-        run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1")
-        run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes")
-
-        run_attr_mock.upload.assert_has_calls([
-            call(os.path.join(models_root_dir, "model1")),
-            call(os.path.join(models_root_dir, "model2/with/slashes")),
-        ])
+    mock_run.log_configs.assert_has_calls(expected_calls, any_order=True)
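To exercise the updated tests locally, one option is to invoke pytest programmatically along these lines (illustrative only; the -k filter simply matches the test names above, and the tests are skipped unless neptune-scale is installed):

import pytest

# Run only the NeptuneScale logger tests from this file.
pytest.main(["tests/tests_pytorch/loggers/test_neptune.py", "-k", "neptune_scale_logger", "-v"])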
