|
18 | 18 |
|
19 | 19 | import os |
20 | 20 | from argparse import Namespace |
21 | | -from typing import Any, Optional, Union |
| 21 | +from typing import TYPE_CHECKING, Any, Optional, Union |
22 | 22 |
|
| 23 | +from lightning_utilities.core.imports import RequirementCache |
23 | 24 | from torch import Tensor |
24 | 25 | from typing_extensions import override |
25 | 26 |
|
26 | 27 | import lightning.pytorch as pl |
27 | | -from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE |
28 | 28 | from lightning.fabric.loggers.tensorboard import TensorBoardLogger as FabricTensorBoardLogger |
29 | 29 | from lightning.fabric.utilities.cloud_io import _is_dir |
30 | 30 | from lightning.fabric.utilities.logger import _convert_params |
|
35 | 35 | from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE |
36 | 36 | from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn |
37 | 37 |
|
| 38 | +_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard") |
| 39 | +if TYPE_CHECKING: |
| 40 | + # assumes at least one will be installed when type checking |
| 41 | + if _TENSORBOARD_AVAILABLE: |
| 42 | + from torch.utils.tensorboard import SummaryWriter |
| 43 | + else: |
| 44 | + from tensorboardX import SummaryWriter # type: ignore[no-redef] |
| 45 | + |
38 | 46 |
|
39 | 47 | class TensorBoardLogger(Logger, FabricTensorBoardLogger): |
40 | 48 | r"""Log to local or remote file system in `TensorBoard <https://www.tensorflow.org/tensorboard>`_ format. |
@@ -260,3 +268,22 @@ def _get_next_version(self) -> int: |
260 | 268 | return 0 |
261 | 269 |
|
262 | 270 | return max(existing_versions) + 1 |
| 271 | + |
| 272 | + @property |
| 273 | + @override |
| 274 | + @rank_zero_only |
| 275 | + def experiment(self) -> "SummaryWriter": |
| 276 | + """Returns the underlying TensorBoard summary writer object. To use TensorBoard features anywhere in your code, |
| 277 | + do the following. |
| 278 | +
|
| 279 | + Example:: |
| 280 | +
|
| 281 | + class LitModel(LightningModule): |
| 282 | + def training_step(self, batch, batch_idx): |
| 283 | + # log a image |
| 284 | + self.logger.experiment.add_image('my_image', batch['image'], self.global_step) |
| 285 | + # log a histogram |
| 286 | + self.logger.experiment.add_histogram('my_histogram', batch['data'], self.global |
| 287 | +
|
| 288 | + """ |
| 289 | + return super().experiment |
0 commit comments