Skip to content

Commit f2f3ef5

Browse files
authored
Proper support for Remote Stop and Remote Abort with NeptuneLogger (#19130)
1 parent 0235543 commit f2f3ef5

File tree

4 files changed

+39
-8
lines changed

4 files changed

+39
-8
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4848
- Fixed the divisibility check for `Trainer.accumulate_grad_batches` and `Trainer.log_every_n_steps` in ThroughputMonitor ([#19470](https://github.com/Lightning-AI/lightning/pull/19470))
4949

5050

51+
- Fixed support for Remote Stop and Remote Abort with NeptuneLogger ([#19130](https://github.com/Lightning-AI/pytorch-lightning/pull/19130))
52+
53+
5154
-
5255

5356

src/lightning/pytorch/loggers/neptune.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
import logging
2121
import os
2222
from argparse import Namespace
23-
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Set, Union
23+
from functools import wraps
24+
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Union
2425

2526
from lightning_utilities.core.imports import RequirementCache
2627
from torch import Tensor
@@ -48,6 +49,19 @@
4849
_INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning"
4950

5051

52+
# Neptune client throws `InactiveRunException` when trying to log to an inactive run.
53+
# This may happen when the run was stopped through the UI and the logger is still trying to log to it.
54+
def _catch_inactive(func: Callable) -> Callable:
    """Wrap ``func`` so that Neptune's ``InactiveRunException`` is swallowed instead of propagating.

    Args:
        func: The callable to guard.

    Returns:
        A wrapper that forwards all positional and keyword arguments to ``func`` and returns its result.
        If ``func`` raises ``InactiveRunException``, the wrapper returns ``None``; any other exception
        propagates unchanged.

    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Resolved at call time rather than at module import time
        # (presumably so this module stays importable without `neptune` installed -- TODO confirm).
        from neptune.exceptions import InactiveRunException

        try:
            outcome = func(*args, **kwargs)
        except InactiveRunException:
            # The run is no longer active: drop this call silently.
            return None
        return outcome

    return wrapper
63+
64+
5165
class NeptuneLogger(Logger):
5266
r"""Log using `Neptune <https://neptune.ai>`_.
5367
@@ -245,10 +259,7 @@ def __init__(
245259
if self._run_instance is not None:
246260
self._retrieve_run_data()
247261

248-
if _NEPTUNE_AVAILABLE:
249-
from neptune.handler import Handler
250-
else:
251-
from neptune.new.handler import Handler
262+
from neptune.handler import Handler
252263

253264
# make sure that we've logged the integration version for `Run` instances created outside the logger
254265
root_obj = self._run_instance
@@ -383,6 +394,7 @@ def run(self) -> "Run":
383394

384395
@override
385396
@rank_zero_only
397+
@_catch_inactive
386398
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
387399
r"""Log hyperparameters to the run.
388400
@@ -430,9 +442,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
430442

431443
@override
432444
@rank_zero_only
433-
def log_metrics( # type: ignore[override]
434-
self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None
435-
) -> None:
445+
@_catch_inactive
446+
def log_metrics(self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
436447
"""Log metrics (numeric values) in Neptune runs.
437448
438449
Args:
@@ -450,6 +461,7 @@ def log_metrics( # type: ignore[override]
450461

451462
@override
452463
@rank_zero_only
464+
@_catch_inactive
453465
def finalize(self, status: str) -> None:
454466
if not self._run_instance:
455467
# When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been
@@ -473,6 +485,7 @@ def save_dir(self) -> Optional[str]:
473485
return os.path.join(os.getcwd(), ".neptune")
474486

475487
@rank_zero_only
488+
@_catch_inactive
476489
def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None:
477490
from neptune.types import File
478491

@@ -483,6 +496,7 @@ def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) ->
483496

484497
@override
485498
@rank_zero_only
499+
@_catch_inactive
486500
def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
487501
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
488502

tests/tests_pytorch/loggers/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ def __setitem__(self, key, value):
141141
neptune_utils.stringify_unsupported = Mock()
142142
monkeypatch.setitem(sys.modules, "neptune.utils", neptune_utils)
143143

144+
neptune_exceptions = ModuleType("exceptions")
145+
neptune_exceptions.InactiveRunException = Exception
146+
monkeypatch.setitem(sys.modules, "neptune.exceptions", neptune_exceptions)
147+
144148
neptune.handler = neptune_handler
145149
neptune.types = neptune_types
146150
neptune.utils = neptune_utils

tests/tests_pytorch/loggers/test_neptune.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,13 @@ def test_get_full_model_names_from_exp_structure():
303303
}
304304
expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"}
305305
assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys
306+
307+
308+
def test_inactive_run(neptune_mock, tmp_path):
    """Fitting and testing must not raise when item assignment on the run raises ``InactiveRunException``."""
    from neptune.exceptions import InactiveRunException

    neptune_logger, mocked_run, _ = _get_logger_with_mocks(api_key="test", project="project")
    # Make every `run[key] = value` assignment on the mocked run raise `InactiveRunException`.
    mocked_run.__setitem__.side_effect = InactiveRunException

    # The full fit/test cycle is expected to complete without any exception escaping.
    _fit_and_test(logger=neptune_logger, model=BoringModel(), tmp_path=tmp_path)

0 commit comments

Comments
 (0)