diff --git a/README.md b/README.md
index aa58c6d8a585a..f8cfb0797b1a4 100644
--- a/README.md
+++ b/README.md
@@ -279,6 +279,9 @@ trainer = Trainer(logger=loggers.MLFlowLogger())
# neptune
trainer = Trainer(logger=loggers.NeptuneLogger())
+# neptune scale
+trainer = Trainer(logger=loggers.NeptuneScaleLogger())
+
# ... and dozens more
```
diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py
index 90400b1df491d..821a0afd71983 100644
--- a/docs/source-pytorch/conf.py
+++ b/docs/source-pytorch/conf.py
@@ -367,6 +367,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
# TODO: these are missing objects.inv
# "comet_ml": ("https://www.comet.com/docs/v2/", None),
# "neptune": ("https://docs.neptune.ai/", None),
+ # "neptune scale": ("https://docs-beta.neptune.ai/", None),
# "wandb": ("https://docs.wandb.ai//", None),
}
nitpicky = True
@@ -477,6 +478,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
("py:meth", "move_data_to_device"),
("py:class", "neptune.Run"),
("py:class", "neptune.handler.Handler"),
+ ("py:class", "neptune_scale.Run"),
("py:meth", "on_after_batch_transfer"),
("py:meth", "on_before_batch_transfer"),
("py:meth", "on_save_checkpoint"),
@@ -621,7 +623,7 @@ def package_list_from_file(file):
from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
-from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
+from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE, _NEPTUNE_SCALE_AVAILABLE
from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE
from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
diff --git a/docs/source-pytorch/extensions/logging.rst b/docs/source-pytorch/extensions/logging.rst
index f0c12464e6db2..6f3daa1fa7d43 100644
--- a/docs/source-pytorch/extensions/logging.rst
+++ b/docs/source-pytorch/extensions/logging.rst
@@ -31,6 +31,7 @@ The following are loggers we support:
CSVLogger
MLFlowLogger
NeptuneLogger
+ NeptuneScaleLogger
TensorBoardLogger
WandbLogger
diff --git a/docs/source-pytorch/visualize/supported_exp_managers.rst b/docs/source-pytorch/visualize/supported_exp_managers.rst
index e26514e9747c4..26b4b039461ff 100644
--- a/docs/source-pytorch/visualize/supported_exp_managers.rst
+++ b/docs/source-pytorch/visualize/supported_exp_managers.rst
@@ -60,9 +60,9 @@ Here's the full documentation for the :class:`~lightning.pytorch.loggers.MLFlowL
----
-Neptune.ai
-==========
+Neptune 2.x
+===========
-To use `Neptune.ai <https://neptune.ai/>`_ first install the neptune package:
+To use `Neptune 2.x <https://neptune.ai/>`_ first install the neptune package:
.. code-block:: bash
@@ -101,6 +101,43 @@ Here's the full documentation for the :class:`~lightning.pytorch.loggers.Neptune
----
+Neptune 3.x (Neptune Scale)
+===========================
+To use `Neptune 3.x <https://scale.neptune.ai/>`_ first install the neptune-scale package:
+
+.. code-block:: bash
+
+ pip install neptune-scale
+
+
+Configure the logger and pass it to the :class:`~lightning.pytorch.trainer.trainer.Trainer`:
+
+.. testcode::
+ :skipif: not _NEPTUNE_SCALE_AVAILABLE
+
+ from neptune_scale import Run
+ from lightning.pytorch.loggers import NeptuneScaleLogger
+
+ neptune_scale_logger = NeptuneScaleLogger(
+        api_token="YOUR_API_TOKEN",  # replace with your own
+        project="workspace-name/project-name",  # replace with your own
+ )
+ trainer = Trainer(logger=neptune_scale_logger)
+
+Access the Neptune Scale logger from any function (except the LightningModule *init*) to use its API for tracking advanced artifacts:
+
+.. code-block:: python
+
+ class LitModel(LightningModule):
+ def any_lightning_module_function_or_hook(self):
+ neptune_scale_logger = self.logger.experiment
+ neptune_scale_logger.log_metrics(data={"path/to/metadata": metadata}, step=step)
+ neptune_scale_logger.log_configs(data={"path/to/config": config})
+
+Here's the full documentation for the :class:`~lightning.pytorch.loggers.NeptuneScaleLogger`.
+
+----
+
Tensorboard
===========
`TensorBoard `_ can be installed with:
diff --git a/requirements/pytorch/loggers.info b/requirements/pytorch/loggers.info
index 94ff89ff6c62f..ca0c8369935ab 100644
--- a/requirements/pytorch/loggers.info
+++ b/requirements/pytorch/loggers.info
@@ -1,6 +1,7 @@
# all supported loggers. this list is here as a reference, but they are not installed in CI
neptune >=1.0.0
+neptune-scale >=0.12.0
comet-ml >=3.31.0
mlflow >=1.0.0
wandb >=0.12.10
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 0095367e9187a..326f6068ee8d7 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -27,6 +27,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed logger_connector has edge case where step can be a float ([#20692](https://github.com/Lightning-AI/pytorch-lightning/issues/20692))
+---
+
+## [unreleased] - 2025-03-31
+
+### Added
+
+- Added support for the Neptune Scale logger ([#20686](https://github.com/Lightning-AI/pytorch-lightning/pull/20686))
+
+
---
## [2.5.1] - 2025-03-18
diff --git a/src/lightning/pytorch/loggers/__init__.py b/src/lightning/pytorch/loggers/__init__.py
index 3d7d9e4c20139..379434da30403 100644
--- a/src/lightning/pytorch/loggers/__init__.py
+++ b/src/lightning/pytorch/loggers/__init__.py
@@ -15,8 +15,17 @@
from lightning.pytorch.loggers.csv_logs import CSVLogger
from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.loggers.mlflow import MLFlowLogger
-from lightning.pytorch.loggers.neptune import NeptuneLogger
+from lightning.pytorch.loggers.neptune import NeptuneLogger, NeptuneScaleLogger
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.loggers.wandb import WandbLogger
-__all__ = ["CometLogger", "CSVLogger", "Logger", "MLFlowLogger", "TensorBoardLogger", "WandbLogger", "NeptuneLogger"]
+__all__ = [
+    "CometLogger",
+    "CSVLogger",
+    "Logger",
+    "MLFlowLogger",
+    "NeptuneLogger",
+    "NeptuneScaleLogger",
+    "TensorBoardLogger",
+    "WandbLogger",
+]
diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py
index bf9669c824784..9151a9d602276 100644
--- a/src/lightning/pytorch/loggers/neptune.py
+++ b/src/lightning/pytorch/loggers/neptune.py
@@ -21,6 +21,7 @@
import os
from argparse import Namespace
from collections.abc import Generator
+from datetime import datetime
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -29,24 +30,28 @@
from typing_extensions import override
import lightning.pytorch as pl
-from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
+from lightning.fabric.utilities.logger import (
+ _add_prefix,
+ _convert_params,
+ _sanitize_callable_params,
+)
from lightning.pytorch.callbacks import Checkpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities.model_summary import ModelSummary
from lightning.pytorch.utilities.rank_zero import rank_zero_only
-if TYPE_CHECKING:
- from neptune import Run
- from neptune.handler import Handler
-
log = logging.getLogger(__name__)
-# Neptune is available with two names on PyPI : `neptune` and `neptune-client`
-# `neptune` was introduced as a name transition of neptune-client and the long-term target is to get
-# rid of Neptune-client package completely someday. It was introduced as a part of breaking-changes with a release
-# of neptune-client==1.0. neptune-client>=1.0 is just an alias of neptune package and have some breaking-changes
-# in compare to neptune-client<1.0.0.
+_NEPTUNE_SCALE_AVAILABLE = RequirementCache("neptune-scale")
_NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0")
+
+if TYPE_CHECKING:
+ if _NEPTUNE_AVAILABLE:
+ from neptune import Run
+ from neptune.handler import Handler
+ elif _NEPTUNE_SCALE_AVAILABLE:
+ from neptune_scale import Run
+
_INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning"
@@ -64,7 +69,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
class NeptuneLogger(Logger):
- r"""Log using `Neptune `_.
+ r"""Log using `Neptune `_.
Install it with pip:
@@ -124,7 +129,7 @@ def any_lightning_module_function_or_hook(self):
Note that the syntax ``self.logger.experiment["your/metadata/structure"].append(metadata)`` is specific to
Neptune and extends the logger capabilities. It lets you log various types of metadata, such as
scores, files, images, interactive visuals, and CSVs.
- Refer to the `Neptune docs `_
+ Refer to the `Neptune docs `_
for details.
You can also use the regular logger methods ``log_metrics()``, and ``log_hyperparams()`` with NeptuneLogger.
@@ -179,7 +184,7 @@ def any_lightning_module_function_or_hook(self):
)
trainer = Trainer(max_epochs=3, logger=neptune_logger)
- Check `run documentation `_
+ Check `run documentation `_
for more info about additional run parameters.
**Details about Neptune run structure**
@@ -191,18 +196,18 @@ def any_lightning_module_function_or_hook(self):
See also:
- Read about
- `what objects you can log to Neptune `_.
+ `what objects you can log to Neptune `_.
- Check out an `example run `_
with multiple types of metadata logged.
- For more detailed examples, see the
- `user guide `_.
+ `user guide `_.
Args:
api_key: Optional.
Neptune API token, found on https://www.neptune.ai upon registration.
You should save your token to the `NEPTUNE_API_TOKEN`
environment variable and leave the api_key argument out of your code.
- Instructions: `Setting your API token `_.
+ Instructions: `Setting your API token `_.
project: Optional.
Name of a project in the form "workspace-name/project-name", for example "tom/mask-rcnn".
If ``None``, the value of `NEPTUNE_PROJECT` environment variable is used.
@@ -372,7 +377,7 @@ def training_step(self, batch, batch_idx):
is specific to Neptune and extends the logger capabilities.
It lets you log various types of metadata, such as scores, files,
images, interactive visuals, and CSVs. Refer to the
- `Neptune docs `_
+ `Neptune docs `_
for more detailed explanations.
You can also use the regular logger methods ``log_metrics()``, and ``log_hyperparams()``
with NeptuneLogger.
@@ -592,3 +597,558 @@ def version(self) -> Optional[str]:
"""
return self._run_short_id
+
+
+class NeptuneScaleLogger(Logger):
+ r"""Log using `Neptune Scale `_.
+
+ Install it with pip:
+
+ .. code-block:: bash
+
+ pip install neptune-scale
+
+ **Quickstart**
+
+ Pass a NeptuneScaleLogger instance to the Trainer to log metadata with Neptune Scale:
+
+ .. code-block:: python
+
+
+ from lightning.pytorch import Trainer
+ from lightning.pytorch.loggers import NeptuneScaleLogger
+
+ neptune_scale_logger = NeptuneScaleLogger(
+ api_token="",
+ project="",
+ )
+ trainer = Trainer(max_epochs=10, logger=neptune_scale_logger)
+
+ **How to use NeptuneScaleLogger?**
+
+ Use the logger anywhere in your :class:`~lightning.pytorch.core.LightningModule` as follows:
+
+ .. code-block:: python
+
+ from lightning.pytorch import LightningModule
+
+
+ class LitModel(LightningModule):
+ def training_step(self, batch, batch_idx):
+ # log metrics
+ loss = ...
+ self.append("train/loss", loss)
+
+ def any_lightning_module_function_or_hook(self):
+ # generic recipe
+                metadata, step = ..., ...
+                self.logger.run.log_metrics(data={"your/metadata/structure": metadata}, step=step)
+
+ Note that the syntax ``self.logger.run.log_metrics(data={"your/metadata/structure": metadata}, step=step)``
+ is specific to Neptune Scale.
+    Refer to the `Neptune Scale docs <https://docs-beta.neptune.ai/>`_ for details.
+    You can also use the regular logger methods ``log_metrics()`` and ``log_hyperparams()`` with NeptuneScaleLogger.
+
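+    For example, a minimal sketch using the ``neptune_scale_logger`` instance from the quickstart above:
+
+    .. code-block:: python
+
+        neptune_scale_logger.log_metrics({"train/loss": 0.25}, step=1)
+        neptune_scale_logger.log_hyperparams({"batch_size": 64, "lr": 0.07})
+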
+ **Log after fitting or testing is finished**
+
+ You can log objects after the fitting or testing methods are finished:
+
+ .. code-block:: python
+
+ neptune_scale_logger = NeptuneScaleLogger()
+
+ trainer = pl.Trainer(logger=neptune_scale_logger)
+ model = ...
+ datamodule = ...
+ trainer.fit(model, datamodule=datamodule)
+ trainer.test(model, datamodule=datamodule)
+
+ # Log objects after `fit` or `test` methods
+ # generic recipe
+ metadata = ...
+        neptune_scale_logger.run.log_configs(data={"your/metadata/structure": metadata})
+        neptune_scale_logger.run.add_tags(["tag1", "tag2"])
+
+ **Log model checkpoints**
+
+ If you have :class:`~lightning.pytorch.callbacks.ModelCheckpoint` configured,
+ the Neptune logger automatically logs model checkpoints.
+ Model weights will be uploaded to the "model/checkpoints" namespace in the Neptune run.
+ You can disable this option with:
+
+ .. code-block:: python
+
+        neptune_scale_logger = NeptuneScaleLogger(log_model_checkpoints=False)
+
+ Note: All model checkpoints will be uploaded. ``save_last`` and ``save_top_k`` are currently not supported.
+
+ **Pass additional parameters to the Neptune run**
+
+ You can also pass ``neptune_run_kwargs`` to add details to the run, like ``creation_time``,
+ ``log_directory``, ``fork_run_id``, ``fork_step`` or ``*_callback``:
+
+ .. code-block:: python
+
+ from lightning.pytorch import Trainer
+ from lightning.pytorch.loggers import NeptuneScaleLogger
+
+ neptune_scale_logger = NeptuneScaleLogger(
+ log_directory="logs",
+ fork_run_id="fast-lightning-1",
+ fork_step=420,
+ )
+ trainer = Trainer(max_epochs=3, logger=neptune_scale_logger)
+
+ Check `run documentation `_ for more info about additional run
+ parameters.
+
+ **Details about Neptune run structure**
+
+ Runs can be viewed as nested dictionary-like structures that you can define in your code.
+ Thanks to this you can easily organize your metadata in a way that is most convenient for you.
+
+ The hierarchical structure that you apply to your metadata is reflected in the Neptune web app.
+
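+    For example (a sketch, reusing the ``neptune_scale_logger`` instance from the quickstart above),
+    nested keys appear as nested namespaces in the app:
+
+    .. code-block:: python
+
+        neptune_scale_logger.run.log_configs(data={"model/encoder/num_layers": 12})
+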
+ Args:
+ run_id: Optional.
+ Identifier of the run. Max length: 128 bytes.
+ The custom run ID provided to the run_id argument must be unique within the project.
+ It can't contain the / character.
+ If not provided, a random, human-readable ID is generated.
+ project: Optional.
+ Name of a project in the form "workspace-name/project-name", for example "tom/mask-rcnn".
+ If ``None``, the value of `NEPTUNE_PROJECT` environment variable is used.
+ You need to create the project on https://scale.neptune.ai first.
+ api_token: Optional.
+ Neptune API token, found on https://scale.neptune.ai upon registration.
+ You should save your token to the `NEPTUNE_API_TOKEN` environment variable and leave
+ the api_token argument out of your code.
+ Instructions: `Setting your API token `_.
+ resume: Optional.
+ If `False`, creates a new run.
+ To continue an existing run, set to `True` and pass the ID of an existing run to the `run_id` argument.
+ In this case, omit the `experiment_name` parameter.
+ To fork a run, use `fork_run_id` and `fork_step` instead.
+ mode: Optional.
+ `Mode `_ of operation.
+ If "disabled", the run doesn't log any metadata.
+ If "offline", the run is only stored locally. For details, see `Offline logging `_.
+ If this parameter and the
+ `NEPTUNE_MODE `_
+ environment variable are not set, the default is "async".
+ experiment_name: Optional.
+ Name of the experiment to associate the run with.
+ Can't be used together with the `resume` parameter.
+ To make the name easy to read in the app, ensure that it's at most 190 characters long.
+        run: Optional. Default is ``None``. An existing ``neptune_scale.Run`` object.
+ If specified, this existing run will be used for logging, instead of a new run being created.
+ prefix: Optional. Default is ``"training"``. Root namespace for all metadata logging.
+ log_model_checkpoints: Optional. Default is ``True``. Log model checkpoints to Neptune.
+ Works only if ``ModelCheckpoint`` is passed to the ``Trainer``.
+ NOTE: All model checkpoints will be uploaded.
+ ``save_last`` and ``save_top_k`` are currently not supported.
+ neptune_run_kwargs: Additional arguments like ``creation_time``, ``log_directory``,
+ ``fork_run_id``, ``fork_step``, ``*_callback``, etc. used when a run is created.
+
+ Raises:
+ ModuleNotFoundError:
+ If the required Neptune package is not installed.
+ ValueError:
+ If an argument passed to the logger's constructor is incorrect.
+
+ """
+
+ LOGGER_JOIN_CHAR = "/"
+ PARAMETERS_KEY = "hyperparams"
+ DEFAULT_SAVE_DIR = ".neptune"
+    ALLOWED_DATATYPES = [int, float, str, datetime, bool, list, set]  # config values logged as-is; other types are cast to str
+
+ def __init__(
+ self,
+        *,  # all `NeptuneScaleLogger` arguments are keyword-only
+ run_id: Optional[str] = None,
+ project: Optional[str] = None,
+ api_token: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ run: Optional["Run"] = None,
+ prefix: str = "training",
+ log_model_checkpoints: Optional[bool] = True,
+ **neptune_run_kwargs: Any,
+ ):
+ if not _NEPTUNE_SCALE_AVAILABLE:
+ raise ModuleNotFoundError(str(_NEPTUNE_SCALE_AVAILABLE))
+
+ # verify if user passed proper init arguments
+ self._verify_input_arguments(
+ api_token,
+ project,
+ run,
+ run_id,
+ experiment_name,
+ neptune_run_kwargs,
+ )
+ super().__init__()
+ self._api_token = api_token
+ self._project = project
+ self._run_instance = run
+ self._run_id = run_id
+ self._experiment_name = experiment_name
+ self._prefix = prefix
+ self._log_model_checkpoints = log_model_checkpoints
+ self._neptune_run_kwargs = neptune_run_kwargs
+ self._description = self._neptune_run_kwargs.pop("description", None)
+ self._tags = self._neptune_run_kwargs.pop("tags", None)
+ self._group_tags = self._neptune_run_kwargs.pop("group_tags", None)
+
+ if self._run_instance is not None:
+ self._retrieve_run_data()
+
+ else:
+ from neptune_scale import Run
+
+ self._run_instance = Run(**self._neptune_init_args)
+
+ root_obj = self._run_instance
+
+ root_obj.log_configs(data={_INTEGRATION_VERSION_KEY: pl.__version__})
+
+ def _retrieve_run_data(self) -> None:
+ assert self._run_instance is not None
+ root_obj = self._run_instance
+ root_obj.wait_for_submission()
+
+ self._run_id = root_obj._run_id
+ self._experiment_name = root_obj._experiment_name
+
+ @property
+ def _neptune_init_args(self) -> dict:
+        # copy so that repeated calls don't mutate the stored kwargs
+        args: dict = dict(self._neptune_run_kwargs)
+
+ if self._project is not None:
+ args["project"] = self._project
+
+ if self._api_token is not None:
+ args["api_token"] = self._api_token
+
+ if self._run_id is not None:
+ args["run_id"] = self._run_id
+
+ if self._experiment_name is not None:
+ args["experiment_name"] = self._experiment_name
+
+ return args
+
+ def _construct_path_with_prefix(self, *keys: str) -> str:
+ """Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined."""
+ if self._prefix:
+ return self.LOGGER_JOIN_CHAR.join([self._prefix, *keys])
+ return self.LOGGER_JOIN_CHAR.join(keys)
+
+ @staticmethod
+ def _verify_input_arguments(
+ api_token: Optional[str],
+ project: Optional[str],
+ run: Optional["Run"],
+ run_id: Optional[str],
+ experiment_name: Optional[str],
+ neptune_run_kwargs: dict,
+ ) -> None:
+ from neptune_scale import Run
+
+ # check if user passed the client `Run` object
+ if run is not None and not isinstance(run, Run):
+ raise ValueError("Run parameter expected to be of type `neptune_scale.Run`.")
+
+        # check if the user passed redundant `Run()` initialization arguments together with `run`
+ any_neptune_init_arg_passed = (
+ any(arg is not None for arg in [api_token, project, run_id, experiment_name]) or neptune_run_kwargs
+ )
+ if run is not None and any_neptune_init_arg_passed:
+ raise ValueError(
+ "When an already initialized run object is provided, you can't provide other `Run()` "
+ "initialization parameters."
+ )
+
+ def __getstate__(self) -> dict[str, Any]:
+ state = self.__dict__.copy()
+ # Run instance can't be pickled
+ state["_run_instance"] = None
+ return state
+
+ def __setstate__(self, state: dict[str, Any]) -> None:
+ from neptune_scale import Run
+
+ self.__dict__ = state
+ self._run_instance = Run(**self._neptune_init_args)
+
+ @property
+ @rank_zero_experiment
+ def experiment(self) -> "Run":
+ r"""Actual Neptune run object. Allows you to use neptune logging features in your
+ :class:`~lightning.pytorch.core.LightningModule`.
+
+ Example::
+
+ class LitModel(LightningModule):
+ def training_step(self, batch, batch_idx):
+ # log metrics
+ acc = ...
+ self.logger.run.log_metrics(data={"train/acc": acc}, step=step)
+
+ Note that the syntax ``self.logger.run.log_metrics(data={"your/metadata/structure": metadata}, step=step)``
+ is specific to Neptune Scale. Refer to the
+        `Neptune Scale docs <https://docs-beta.neptune.ai/>`_
+ for more detailed explanations.
+        You can also use the regular logger methods ``log_metrics()`` and ``log_hyperparams()``
+        with NeptuneScaleLogger.
+
+ """
+ return self.run
+
+ @property
+ @rank_zero_experiment
+ def run(self) -> "Run":
+ from neptune_scale import Run
+
+ if not self._run_instance:
+ self._run_instance = Run(**self._neptune_init_args)
+ self._retrieve_run_data()
+            # make sure the integration version is logged for a newly created run
+ self._run_instance.log_configs({_INTEGRATION_VERSION_KEY: pl.__version__})
+
+ return self._run_instance
+
+ @override
+ @rank_zero_only
+ def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
+ r"""Log hyperparameters to the run.
+
+        Hyperparameters are logged under the ``<prefix>/hyperparams`` namespace ("training/hyperparams" by default).
+
+ Note:
+
+ You can also log parameters by directly using the logger instance:
+
+        ``neptune_scale_logger.run.log_configs(
+ data={
+ "data/batch_size": 64,
+ "model/optimizer/name": "adam",
+ "model/optimizer/lr": 0.07,
+ "model/optimizer/decay_factor": 0.97,
+ "model/tokenizer/name": "bert-base-uncased",
+ },
+ )``.
+
+        This way, you keep the hierarchical structure of the parameters.
+
+ Args:
+            params: Parameters to log, as a ``dict`` or an argparse ``Namespace``.
+
+ Example::
+
+ from lightning.pytorch.loggers import NeptuneScaleLogger
+
+ PARAMS = {
+ "batch_size": 64,
+ "lr": 0.07,
+ "decay_factor": 0.97,
+ }
+
+ neptune_scale_logger = NeptuneScaleLogger()
+
+ neptune_scale_logger.log_hyperparams(PARAMS)
+
+ """
+ params = _convert_params(params)
+ params = _sanitize_callable_params(params)
+
+ parameters_key = self.PARAMETERS_KEY
+ parameters_key = self._construct_path_with_prefix(parameters_key)
+
+        allowed_datatypes = self.ALLOWED_DATATYPES
+
+ def flatten(d: dict, prefix: str = "") -> dict[str, Any]:
+ """Flatten a nested dictionary by concatenating keys with '/'."""
+ flattened = {}
+ for key, value in d.items():
+ new_key = f"{prefix}/{key}" if prefix else key
+ if isinstance(value, dict):
+ flattened.update(flatten(value, new_key))
+ elif type(value) in allowed_datatypes:
+ flattened[new_key] = value
+ else:
+ flattened[new_key] = str(value)
+ return flattened
+
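+        # e.g. flatten({"optimizer": {"lr": 0.07}}) -> {"optimizer/lr": 0.07}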
+ flattened = flatten(params)
+
+ for key, value in flattened.items():
+ self.run.log_configs({f"{parameters_key}/{key}": value})
+
+ @override
+ @rank_zero_only
+ def log_metrics(self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
+ """Log metrics (numeric values) in Neptune runs.
+
+ Args:
+ metrics: Dictionary with metric names as keys and measured quantities as values.
+ step: Step number at which the metrics should be recorded. Defaults to `trainer.global_step`.
+
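+        Example (assuming ``neptune_scale_logger`` is an instance of this logger)::
+
+            neptune_scale_logger.log_metrics({"train/loss": 0.25}, step=10)
+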
+ """
+ if rank_zero_only.rank != 0:
+ raise ValueError("run tried to log from global_rank != 0")
+
+ metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
+
+ for key, val in metrics.items():
+ self.run.log_metrics({key: val}, step=step)
+
+ @override
+ @rank_zero_only
+ def finalize(self, status: str) -> None:
+ if not self._run_instance:
+ # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been
+ # initialized there
+ return
+ if status:
+ self.run.log_configs({self._construct_path_with_prefix("status"): status})
+
+ super().finalize(status)
+
+ @property
+ @override
+ def save_dir(self) -> Optional[str]:
+ """Gets the save directory of the run.
+
+ Returns:
+ the directory where experiment logs get saved
+
+ """
+        return self._neptune_run_kwargs.get("log_directory", os.path.join(os.getcwd(), self.DEFAULT_SAVE_DIR))
+
+ @rank_zero_only
+ def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None:
+ """Logs a summary of all layers in the model to Neptune as a text file."""
+ from neptune_scale.types import File
+
+ model_str = str(ModelSummary(model=model, max_depth=max_depth))
+ self.run.assign_files({
+ self._construct_path_with_prefix("model/summary"): File(
+ source=model_str.encode("utf-8"), mime_type="text/plain"
+ )
+ })
+
+ @override
+ @rank_zero_only
+ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
+ """Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
+
+ Args:
+ checkpoint_callback: the model checkpoint callback instance
+
+ """
+ if not self._log_model_checkpoints:
+ return
+
+ file_names = set()
+ checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")
+
+ # save last model
+ if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path:
+ model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
+ file_names.add(model_last_name)
+ self.run.assign_files({
+ f"{checkpoints_namespace}/{model_last_name}": checkpoint_callback.last_model_path,
+ })
+
+ # save best k models
+ if hasattr(checkpoint_callback, "best_k_models"):
+ for key in checkpoint_callback.best_k_models:
+ model_name = self._get_full_model_name(key, checkpoint_callback)
+ file_names.add(model_name)
+ self.run.assign_files({
+ f"{checkpoints_namespace}/{model_name}": key,
+ })
+
+ # log best model path and checkpoint
+ if hasattr(checkpoint_callback, "best_model_path") and checkpoint_callback.best_model_path:
+ self.run.log_configs({
+ self._construct_path_with_prefix("model/best_model_path"): checkpoint_callback.best_model_path,
+ })
+
+ model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
+ file_names.add(model_name)
+ self.run.assign_files({
+ f"{checkpoints_namespace}/{model_name}": checkpoint_callback.best_model_path,
+ })
+
+ # remove old models logged to experiment if they are not part of best k models at this point
+ # TODO: Implement after Neptune Scale supports `del`
+ # if self.run.exists(checkpoints_namespace):
+ # exp_structure = self.run.get_structure()
+ # uploaded_model_names = self._get_full_model_names_from_exp_structure(
+ # exp_structure, checkpoints_namespace
+ # )
+
+ # for file_to_drop in list(uploaded_model_names - file_names):
+ # del self.run[f"{checkpoints_namespace}/{file_to_drop}"]
+
+ # log best model score
+ if hasattr(checkpoint_callback, "best_model_score") and checkpoint_callback.best_model_score:
+ self.run.log_configs({
+ self._construct_path_with_prefix("model/best_model_score"): float(
+ checkpoint_callback.best_model_score.cpu().detach().numpy()
+ ),
+ })
+
+ @staticmethod
+    def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> str:
+        """Return the model name: ``model_path`` relative to ``checkpoint_callback.dirpath``, without the extension."""
+ if hasattr(checkpoint_callback, "dirpath"):
+ model_path = os.path.normpath(model_path)
+ expected_model_path = os.path.normpath(checkpoint_callback.dirpath)
+ if not model_path.startswith(expected_model_path):
+ raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
+ # Remove extension from filepath
+ filepath, _ = os.path.splitext(model_path[len(expected_model_path) + 1 :])
+ return filepath.replace(os.sep, "/")
+ return model_path.replace(os.sep, "/")
+
+ @classmethod
+    def _get_full_model_names_from_exp_structure(cls, exp_structure: dict[str, Any], namespace: str) -> set[str]:
+        """Return all paths to properties which were already logged in ``namespace``."""
+ structure_keys: list[str] = namespace.split(cls.LOGGER_JOIN_CHAR)
+ for key in structure_keys:
+ exp_structure = exp_structure[key]
+ uploaded_models_dict = exp_structure
+ return set(cls._dict_paths(uploaded_models_dict))
+
+ @classmethod
+ def _dict_paths(cls, d: dict[str, Any], path_in_build: Optional[str] = None) -> Generator:
+ for k, v in d.items():
+ path = f"{path_in_build}/{k}" if path_in_build is not None else k
+ if not isinstance(v, dict):
+ yield path
+ else:
+ yield from cls._dict_paths(v, path)
+
+ @property
+ @override
+ def name(self) -> Optional[str]:
+ """Return the experiment name."""
+ return self._experiment_name
+
+ @property
+ @override
+ def version(self) -> Optional[str]:
+ """Return the Neptune custom run ID."""
+ return self._run_id
diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md
index f3fb8cb2fd2b3..8e2bca207820c 100644
--- a/src/pytorch_lightning/README.md
+++ b/src/pytorch_lightning/README.md
@@ -252,9 +252,12 @@ trainer = Trainer(logger=loggers.CometLogger())
# mlflow
trainer = Trainer(logger=loggers.MLFlowLogger())
-# neptune
+# neptune 2.x
trainer = Trainer(logger=loggers.NeptuneLogger())
+# neptune 3.x
+trainer = Trainer(logger=loggers.NeptuneScaleLogger())
+
# ... and dozens more
```
diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py
index 6dc3816fac858..004db3700a738 100644
--- a/tests/tests_pytorch/loggers/test_neptune.py
+++ b/tests/tests_pytorch/loggers/test_neptune.py
@@ -23,7 +23,8 @@
import lightning.pytorch as pl
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.loggers import NeptuneLogger
+from lightning.pytorch.loggers import NeptuneLogger, NeptuneScaleLogger
+from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE, _NEPTUNE_SCALE_AVAILABLE
def _fetchable_paths(value):
@@ -53,6 +54,7 @@ def _get_logger_with_mocks(**kwargs):
return logger, run_instance_mock, run_attr_mock
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_neptune_online(neptune_mock):
neptune_mock.init_run.return_value.exists.return_value = True
neptune_mock.init_run.return_value.__getitem__.side_effect = _fetchable_paths
@@ -71,6 +73,7 @@ def test_neptune_online(neptune_mock):
created_run_mock.__setitem__.assert_called_once_with("source_code/integrations/pytorch-lightning", pl.__version__)
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_neptune_offline(neptune_mock):
neptune_mock.init_run.return_value.exists.return_value = False
@@ -83,6 +86,7 @@ def test_neptune_offline(neptune_mock):
assert logger._run_name == "offline-name"
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_online_with_custom_run(neptune_mock):
neptune_mock.init_run.return_value.exists.return_value = True
neptune_mock.init_run.return_value.__getitem__.side_effect = _fetchable_paths
@@ -97,6 +101,7 @@ def test_online_with_custom_run(neptune_mock):
assert neptune_mock.init_run.call_count == 0
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_neptune_pickling(neptune_mock):
neptune_mock.init_run.return_value.exists.return_value = True
neptune_mock.init_run.return_value.__getitem__.side_effect = _fetchable_paths
@@ -116,6 +121,7 @@ def test_neptune_pickling(neptune_mock):
assert unpickled.experiment is not None
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_online_with_wrong_kwargs(neptune_mock):
"""Tests combinations of kwargs together with `run` kwarg which makes some of other parameters unavailable in
init."""
@@ -142,6 +148,7 @@ def test_online_with_wrong_kwargs(neptune_mock):
NeptuneLogger(foo="bar")
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_neptune_additional_methods(neptune_mock):
logger, run_instance_mock, _ = _get_logger_with_mocks(api_key="test", project="project")
@@ -150,6 +157,7 @@ def test_neptune_additional_methods(neptune_mock):
run_instance_mock.__getitem__().log.assert_called_once_with(torch.ones(1))
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_neptune_leave_open_experiment_after_fit(neptune_mock, tmp_path, monkeypatch):
"""Verify that neptune experiment was NOT closed after training."""
monkeypatch.chdir(tmp_path)
@@ -158,6 +166,7 @@ def test_neptune_leave_open_experiment_after_fit(neptune_mock, tmp_path, monkeyp
assert run_instance_mock.stop.call_count == 0
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_neptune_log_metrics_on_trained_model(neptune_mock, tmp_path, monkeypatch):
"""Verify that trained models do log data."""
monkeypatch.chdir(tmp_path)
@@ -172,6 +181,7 @@ def on_validation_epoch_end(self):
run_instance_mock.__getitem__.return_value.append.assert_has_calls([call(42, step=2)])
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_log_hyperparams(neptune_mock):
neptune_mock.utils.stringify_unsupported = lambda x: x
@@ -189,6 +199,7 @@ def test_log_hyperparams(neptune_mock):
run_instance_mock.__setitem__.assert_called_once_with(hyperparams_key, params)
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_log_metrics(neptune_mock):
metrics = {
"foo": 42,
@@ -197,7 +208,10 @@ def test_log_metrics(neptune_mock):
test_variants = [
({}, ("training/foo", "training/bar")),
({"prefix": "custom_prefix"}, ("custom_prefix/foo", "custom_prefix/bar")),
- ({"prefix": "custom/nested/prefix"}, ("custom/nested/prefix/foo", "custom/nested/prefix/bar")),
+ (
+ {"prefix": "custom/nested/prefix"},
+ ("custom/nested/prefix/foo", "custom/nested/prefix/bar"),
+ ),
]
for prefix, (metrics_foo_key, metrics_bar_key) in test_variants:
@@ -210,6 +224,7 @@ def test_log_metrics(neptune_mock):
run_attr_mock.append.assert_has_calls([call(42, step=None), call(555, step=None)])
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_log_model_summary(neptune_mock):
model = BoringModel()
test_variants = [
@@ -229,7 +244,7 @@ def test_log_model_summary(neptune_mock):
run_instance_mock.__setitem__.assert_called_once_with(model_summary_key, file_from_content_mock)
-@mock.patch("builtins.open", mock.mock_open(read_data="test"))
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_after_save_checkpoint(neptune_mock):
test_variants = [
({}, "training/model"),
@@ -270,33 +285,53 @@ def test_after_save_checkpoint(neptune_mock):
])
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_save_dir(neptune_mock):
logger = NeptuneLogger(api_key="test", project="project")
assert logger.save_dir == os.path.join(os.getcwd(), ".neptune")
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_get_full_model_name():
SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
test_input_data = [
- ("key", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
+ (
+ "key",
+ os.path.join("foo", "bar", "key.ext"),
+ SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
+ ),
(
"key/in/parts",
os.path.join("foo", "bar", "key/in/parts.ext"),
SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
),
- ("key", os.path.join("../foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("../foo", "bar"))),
- ("key", os.path.join("foo", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("./foo", "bar/../"))),
+ (
+ "key",
+ os.path.join("../foo", "bar", "key.ext"),
+ SimpleCheckpoint(dirpath=os.path.join("../foo", "bar")),
+ ),
+ (
+ "key",
+ os.path.join("foo", "key.ext"),
+ SimpleCheckpoint(dirpath=os.path.join("./foo", "bar/../")),
+ ),
]
for expected_model_name, model_path, checkpoint in test_input_data:
assert NeptuneLogger._get_full_model_name(model_path, checkpoint) == expected_model_name
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_get_full_model_names_from_exp_structure():
input_dict = {
"foo": {
"bar": {
- "lvl1_1": {"lvl2": {"lvl3_1": "some non important value", "lvl3_2": "some non important value"}},
+ "lvl1_1": {
+ "lvl2": {
+ "lvl3_1": "some non important value",
+ "lvl3_2": "some non important value",
+ }
+ },
"lvl1_2": "some non important value",
},
"other_non_important": {"val100": 100},
@@ -307,6 +342,7 @@ def test_get_full_model_names_from_exp_structure():
assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="Neptune is required for this test.")
def test_inactive_run(neptune_mock, tmp_path, monkeypatch):
from neptune.exceptions import InactiveRunException
@@ -316,3 +352,194 @@ def test_inactive_run(neptune_mock, tmp_path, monkeypatch):
# this should work without any exceptions
_fit_and_test(logger=logger, model=BoringModel(), tmp_path=tmp_path)
+
+
+# Fixtures for NeptuneScaleLogger tests
+@pytest.fixture
+def neptune_scale_logger(monkeypatch):
+ """Fixture that provides a NeptuneScaleLogger instance and handles cleanup."""
+ from neptune_scale import Run
+
+ mock_run = MagicMock(spec=Run)
+ monkeypatch.setattr("neptune_scale.Run", MagicMock(return_value=mock_run))
+
+ logger = NeptuneScaleLogger()
+ yield logger, mock_run
+ logger.finalize("success")
+
+
+@pytest.fixture
+def neptune_scale_run():
+ """Fixture that provides a mocked Neptune Scale Run instance."""
+ from neptune_scale import Run
+
+ mock_run = MagicMock(spec=Run)
+ mock_run._run_id = "existing-run-id"
+ mock_run._experiment_name = "test-experiment"
+ return mock_run
+
+
+# NeptuneScaleLogger tests
+@pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
+def test_neptune_scale_logger_init_defaults(neptune_scale_logger):
+    """Test that ``api_token`` and ``project`` default to ``None`` and are read from environment variables."""
+ logger, _ = neptune_scale_logger
+
+ assert logger._api_token is None # default value
+ assert logger._project is None # default value
+
+
+@pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
+def test_neptune_scale_logger_init_with_run(neptune_scale_run):
+ """Test that the logger can be initialized with an existing run."""
+ logger = NeptuneScaleLogger(run=neptune_scale_run)
+ try:
+ assert logger.experiment == neptune_scale_run
+ assert logger.version == "existing-run-id"
+ assert logger.name == "test-experiment"
+ finally:
+ logger.finalize("success")
+
+
+@pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
+def test_neptune_scale_logger_log_metrics(neptune_scale_logger):
+ """Test logging metrics."""
+ logger, mock_run = neptune_scale_logger
+
+ metrics = {"loss": 1.23, "accuracy": 0.89}
+ logger.log_metrics(metrics, step=5)
+
+ expected_calls = [
+ call.log_metrics({"training/loss": 1.23}, step=5),
+ call.log_metrics({"training/accuracy": 0.89}, step=5),
+ ]
+ mock_run.assert_has_calls(expected_calls, any_order=True)
+
+
+@pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
+def test_neptune_scale_logger_log_hyperparams(neptune_scale_logger):
+ """Test logging hyperparameters."""
+ logger, mock_run = neptune_scale_logger
+
+ params = {"batch_size": 32, "learning_rate": 0.001}
+ logger.log_hyperparams(params)
+
+ expected_calls = [
+ call.log_configs({"training/hyperparams/batch_size": 32}),
+ call.log_configs({"training/hyperparams/learning_rate": 0.001}),
+ ]
+ mock_run.assert_has_calls(expected_calls, any_order=True)
+
+
+@pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
+def test_neptune_scale_logger_save_dir(neptune_scale_logger):
+ """Test that save_dir returns the correct directory."""
+ logger, _ = neptune_scale_logger
+ assert logger.save_dir.endswith(".neptune")
+
+
+@pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
+def test_neptune_scale_logger_with_tags(monkeypatch):
+ """Test initialization with tags and group tags."""
+ from neptune_scale import Run
+
+ mock_run = MagicMock(spec=Run)
+ monkeypatch.setattr("neptune_scale.Run", MagicMock(return_value=mock_run))
+
+ tags = ["test-tag-1", "test-tag-2"]
+ group_tags = ["group-1", "group-2"]
+ logger = NeptuneScaleLogger(tags=tags, group_tags=group_tags)
+ try:
+ assert logger._tags == tags
+ assert logger._group_tags == group_tags
+ finally:
+ logger.finalize("success")
+
+
+@pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
+def test_neptune_scale_logger_finalize(neptune_scale_logger):
+ """Test finalize method sets status correctly."""
+ logger, mock_run = neptune_scale_logger
+ logger.finalize("success")
+ assert mock_run._status == "success"
+
+
+@pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
+def test_neptune_scale_logger_invalid_run():
+ """Test that initialization with invalid run object raises ValueError."""
+ with pytest.raises(ValueError, match="Run parameter expected to be of type"):
+ NeptuneScaleLogger(run="invalid-run")
+
+
+@pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
+def test_neptune_scale_logger_log_model_summary(neptune_scale_logger):
+    """Test that the model summary is uploaded as a file under the prefixed namespace."""
+    logger, mock_run = neptune_scale_logger
+    model = BoringModel()
+
+    logger.log_model_summary(model)
+
+    assert mock_run.assign_files.call_count == 1
+    (files,), _ = mock_run.assign_files.call_args
+    assert list(files) == ["training/model/summary"]
+
+
+@pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
+def test_neptune_scale_logger_experiment_getter(neptune_scale_logger):
+ """Test that experiment property returns the run instance."""
+ logger, mock_run = neptune_scale_logger
+ assert logger.experiment == mock_run
+ assert logger.run == mock_run
+
+
+@pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
+def test_neptune_scale_logger_with_prefix(neptune_scale_logger):
+ """Test that logger uses custom prefix correctly."""
+ logger, mock_run = neptune_scale_logger
+ metrics = {"loss": 1.23}
+ logger.log_metrics(metrics, step=5)
+ mock_run.log_metrics.assert_called_once_with({"training/loss": 1.23}, step=5)
+
+
+@pytest.mark.skipif(not _NEPTUNE_SCALE_AVAILABLE, reason="Neptune-Scale is required for this test.")
+def test_neptune_scale_logger_after_save_checkpoint(neptune_scale_logger):
+    """Test that tracked checkpoints are uploaded and the best model path is logged."""
+    logger, mock_run = neptune_scale_logger
+
+    models_root_dir = os.path.join("path", "to", "models")
+    cb_mock = MagicMock(
+        dirpath=models_root_dir,
+        last_model_path=os.path.join(models_root_dir, "last"),
+        best_k_models={
+            os.path.join(models_root_dir, "model1"): None,
+            os.path.join(models_root_dir, "model2/with/slashes"): None,
+        },
+        best_model_path=os.path.join(models_root_dir, "best_model"),
+        best_model_score=None,
+    )
+
+    logger.after_save_checkpoint(cb_mock)
+
+    mock_run.assign_files.assert_any_call({
+        "training/model/checkpoints/last": os.path.join(models_root_dir, "last"),
+    })
+    mock_run.assign_files.assert_any_call({
+        "training/model/checkpoints/model1": os.path.join(models_root_dir, "model1"),
+    })
+    mock_run.assign_files.assert_any_call({
+        "training/model/checkpoints/model2/with/slashes": os.path.join(models_root_dir, "model2/with/slashes"),
+    })
+    mock_run.log_configs.assert_any_call({
+        "training/model/best_model_path": os.path.join(models_root_dir, "best_model"),
+    })