Skip to content

Commit fc195b9

Browse files
carmoccaBorda
andauthored
Revert #16401 and user proper CSVLogger (#16405)
Co-authored-by: Jirka <[email protected]>
1 parent 76b3cd5 commit fc195b9

File tree

4 files changed

+14
-16
lines changed

4 files changed

+14
-16
lines changed

.github/workflows/ci-tests-pytorch.yml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,3 @@ jobs:
218218
flags: ${COVERAGE_SCOPE},cpu,pytest-full,python${{ matrix.python-version }},pytorch${{ matrix.pytorch-version }}
219219
name: CPU-coverage
220220
fail_ci_if_error: false
221-
222-
# TODO
223-
# - name: Testing legacy creation
224-
# working-directory: tests/
225-
# run: |
226-
# export PYTHONPATH=$(dirname $LEGACY_PATH);$PYTHONPATH # for `import tests_pytorch`
227-
# python legacy/simple_classif_training.py

src/pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -582,11 +582,14 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH:
582582
return self.dirpath
583583

584584
if len(trainer.loggers) > 0:
585-
logger_ = trainer.loggers[0]
586-
save_dir = getattr(logger_, "save_dir", None) or trainer.default_root_dir
587-
version = logger_.version
585+
if trainer.loggers[0].save_dir is not None:
586+
save_dir = trainer.loggers[0].save_dir
587+
else:
588+
save_dir = trainer.default_root_dir
589+
name = trainer.loggers[0].name
590+
version = trainer.loggers[0].version
588591
version = version if isinstance(version, str) else f"version_{version}"
589-
ckpt_path = os.path.join(save_dir, str(logger_.name), version, "checkpoints")
592+
ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
590593
else:
591594
# if no loggers, use default_root_dir
592595
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")

src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@
1818
from torch import Tensor
1919

2020
import pytorch_lightning as pl
21-
from lightning_fabric.loggers import CSVLogger
2221
from lightning_fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
2322
from lightning_fabric.plugins.environments import SLURMEnvironment
2423
from lightning_fabric.utilities import move_data_to_device
2524
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
26-
from pytorch_lightning.loggers import Logger, TensorBoardLogger
25+
from pytorch_lightning.loggers import CSVLogger, Logger, TensorBoardLogger
2726
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
2827

2928
warning_cache = WarningCache()
@@ -72,7 +71,7 @@ def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> Non
7271
" or `tensorboardX` packages are found."
7372
" Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default"
7473
)
75-
logger_ = CSVLogger(root_dir=self.trainer.default_root_dir) # type: ignore[assignment]
74+
logger_ = CSVLogger(save_dir=self.trainer.default_root_dir) # type: ignore[assignment]
7675
self.trainer.loggers = [logger_]
7776
elif isinstance(logger, Iterable):
7877
self.trainer.loggers = list(logger)

src/pytorch_lightning/trainer/trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
5454
from pytorch_lightning.core.datamodule import LightningDataModule
5555
from pytorch_lightning.loggers import Logger
56+
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
5657
from pytorch_lightning.loops import PredictionLoop, TrainingEpochLoop
5758
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
5859
from pytorch_lightning.loops.fit_loop import FitLoop
@@ -1806,8 +1807,10 @@ def model(self, model: torch.nn.Module) -> None:
18061807
@property
18071808
def log_dir(self) -> Optional[str]:
18081809
if len(self.loggers) > 0:
1809-
logger_ = self.loggers[0]
1810-
dirpath = getattr(logger_, "log_dir", None) or getattr(logger_, "save_dir", None)
1810+
if not isinstance(self.loggers[0], TensorBoardLogger):
1811+
dirpath = self.loggers[0].save_dir
1812+
else:
1813+
dirpath = self.loggers[0].log_dir
18111814
else:
18121815
dirpath = self.default_root_dir
18131816

0 commit comments

Comments
 (0)