diff --git a/requirements/pytorch/docs.txt b/requirements/pytorch/docs.txt index 7ee6c8bb309cb..c199106635209 100644 --- a/requirements/pytorch/docs.txt +++ b/requirements/pytorch/docs.txt @@ -4,6 +4,7 @@ nbformat # used for generate empty notebook ipython[notebook] <9.6.0 setuptools<81.0 # workaround for `error in ipython setup command: use_2to3 is invalid.` -onnxscript >= 0.2.2, < 0.5.0 +onnxscript >= 0.2.2, <0.5.0 +tensorboard #-r ../../_notebooks/.actions/requires.txt diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index 208244dc38cd3..5ada282b675de 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -173,11 +173,19 @@ def sub_dir(self) -> Optional[str]: @property @rank_zero_experiment def experiment(self) -> "SummaryWriter": - """Actual tensorboard object. To use TensorBoard features anywhere in your code, do the following. + """Returns the underlying TensorBoard summary writer object. 
Allows you to use TensorBoard logging features + directly in your :class:`~lightning.pytorch.core.LightningModule` or anywhere else in your code with: + + `logger.experiment.some_tensorboard_function()` Example:: - logger.experiment.some_tensorboard_function() + class LitModel(LightningModule): + def training_step(self, batch, batch_idx): + # log an image + self.logger.experiment.add_image('my_image', batch['image'], self.global_step) + # log a histogram + self.logger.experiment.add_histogram('my_histogram', batch['data'], self.global_step) """ if self._experiment is not None: diff --git a/src/lightning/pytorch/loggers/tensorboard.py b/src/lightning/pytorch/loggers/tensorboard.py index f9cc41c67045c..55ead4e694b18 100644 --- a/src/lightning/pytorch/loggers/tensorboard.py +++ b/src/lightning/pytorch/loggers/tensorboard.py @@ -18,13 +18,14 @@ import os from argparse import Namespace -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union +from lightning_utilities.core.imports import RequirementCache from torch import Tensor from typing_extensions import override import lightning.pytorch as pl -from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE +from lightning.fabric.loggers.logger import rank_zero_experiment from lightning.fabric.loggers.tensorboard import TensorBoardLogger as FabricTensorBoardLogger from lightning.fabric.utilities.cloud_io import _is_dir from lightning.fabric.utilities.logger import _convert_params @@ -35,6 +36,14 @@ from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn +_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard") +if TYPE_CHECKING: + # assumes at least one will be installed when type checking + if _TENSORBOARD_AVAILABLE: + from torch.utils.tensorboard import SummaryWriter + else: + from tensorboardX import SummaryWriter # type: ignore[no-redef] + class TensorBoardLogger(Logger, 
FabricTensorBoardLogger): r"""Log to local or remote file system in `TensorBoard `_ format. @@ -260,3 +269,26 @@ def _get_next_version(self) -> int: return 0 return max(existing_versions) + 1 + + @property + @override + @rank_zero_experiment + def experiment(self) -> "SummaryWriter": + """Returns the underlying TensorBoard summary writer object. + + Allows you to use TensorBoard logging features directly in your + :class:`~lightning.pytorch.core.LightningModule` or anywhere else in your code with: + + `logger.experiment.some_tensorboard_function()` + + Example:: + + class LitModel(LightningModule): + def training_step(self, batch, batch_idx): + # log an image + self.logger.experiment.add_image('my_image', batch['image'], self.global_step) + # log a histogram + self.logger.experiment.add_histogram('my_histogram', batch['data'], self.global_step) + + """ + return super().experiment