Skip to content

Commit 73b416d

Browse files
awaelchli, Raalsky
authored and committed
Proper support for Remote Stop and Remote Abort with NeptuneLogger (#19130)
Co-authored-by: Rafał Jankowski <[email protected]>
1 parent d90fb53 commit 73b416d

File tree

4 files changed

+37
-8
lines changed

4 files changed

+37
-8
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111

1212
- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually ([#19446](https://github.com/Lightning-AI/lightning/pull/19446))
1313
- Fixed the divisibility check for `Trainer.accumulate_grad_batches` and `Trainer.log_every_n_steps` in ThroughputMonitor ([#19470](https://github.com/Lightning-AI/lightning/pull/19470))
14+
- Fixed support for Remote Stop and Remote Abort with NeptuneLogger ([#19130](https://github.com/Lightning-AI/pytorch-lightning/pull/19130))
1415

1516

1617
## [2.2.0] - 2024-02-08

src/lightning/pytorch/loggers/neptune.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
import logging
2020
import os
2121
from argparse import Namespace
22-
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Set, Union
22+
from functools import wraps
23+
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Union
2324

2425
from lightning_utilities.core.imports import RequirementCache
2526
from torch import Tensor
@@ -44,6 +45,19 @@
4445
_INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning"
4546

4647

48+
# Neptune client throws `InactiveRunException` when trying to log to an inactive run.
49+
# This may happen when the run was stopped through the UI and the logger is still trying to log to it.
50+
def _catch_inactive(func: Callable) -> Callable:
51+
@wraps(func)
52+
def wrapper(*args: Any, **kwargs: Any) -> Any:
53+
from neptune.exceptions import InactiveRunException
54+
55+
with contextlib.suppress(InactiveRunException):
56+
return func(*args, **kwargs)
57+
58+
return wrapper
59+
60+
4761
class NeptuneLogger(Logger):
4862
r"""Log using `Neptune <https://neptune.ai>`_.
4963
@@ -240,10 +254,7 @@ def __init__(
240254
if self._run_instance is not None:
241255
self._retrieve_run_data()
242256

243-
if _NEPTUNE_AVAILABLE:
244-
from neptune.handler import Handler
245-
else:
246-
from neptune.new.handler import Handler
257+
from neptune.handler import Handler
247258

248259
# make sure that we've logged the integration version for outside `Run` instances
249260
root_obj = self._run_instance
@@ -390,6 +401,7 @@ def run(self) -> "Run":
390401

391402
@override
392403
@rank_zero_only
404+
@_catch_inactive
393405
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override]
394406
r"""Log hyperparameters to the run.
395407
@@ -440,9 +452,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: #
440452

441453
@override
442454
@rank_zero_only
443-
def log_metrics( # type: ignore[override]
444-
self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None
445-
) -> None:
455+
@_catch_inactive
456+
def log_metrics(self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
446457
"""Log metrics (numeric values) in Neptune runs.
447458
448459
Args:
@@ -460,6 +471,7 @@ def log_metrics( # type: ignore[override]
460471

461472
@override
462473
@rank_zero_only
474+
@_catch_inactive
463475
def finalize(self, status: str) -> None:
464476
if not self._run_instance:
465477
# When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been
@@ -483,6 +495,7 @@ def save_dir(self) -> Optional[str]:
483495
return os.path.join(os.getcwd(), ".neptune")
484496

485497
@rank_zero_only
498+
@_catch_inactive
486499
def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None:
487500
if _NEPTUNE_AVAILABLE:
488501
from neptune.types import File
@@ -496,6 +509,7 @@ def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) ->
496509

497510
@override
498511
@rank_zero_only
512+
@_catch_inactive
499513
def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
500514
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
501515

tests/tests_pytorch/loggers/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ def __setitem__(self, key, value):
141141
neptune_utils.stringify_unsupported = Mock()
142142
monkeypatch.setitem(sys.modules, "neptune.utils", neptune_utils)
143143

144+
neptune_exceptions = ModuleType("exceptions")
145+
neptune_exceptions.InactiveRunException = Exception
146+
monkeypatch.setitem(sys.modules, "neptune.exceptions", neptune_exceptions)
147+
144148
neptune.handler = neptune_handler
145149
neptune.types = neptune_types
146150
neptune.utils = neptune_utils

tests/tests_pytorch/loggers/test_neptune.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,3 +305,13 @@ def test_get_full_model_names_from_exp_structure():
305305
}
306306
expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"}
307307
assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys
308+
309+
310+
def test_inactive_run(neptune_mock, tmp_path):
311+
from neptune.exceptions import InactiveRunException
312+
313+
logger, run_instance_mock, _ = _get_logger_with_mocks(api_key="test", project="project")
314+
run_instance_mock.__setitem__.side_effect = InactiveRunException
315+
316+
# this should work without any exceptions
317+
_fit_and_test(logger=logger, model=BoringModel(), tmp_path=tmp_path)

0 commit comments

Comments
 (0)