Skip to content

Commit 15d4361

Browse files
awaelchli authored and lantiga committed
Fix parsing of version in TensorBoardLogger and CSVLogger (#18897)
(cherry picked from commit 98685c3)
1 parent 1bda28a commit 15d4361

File tree

8 files changed: +147 -132 lines changed

src/lightning/fabric/loggers/csv_logs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,9 @@ def _get_next_version(self) -> int:
168168
full_path = d["name"]
169169
name = os.path.basename(full_path)
170170
if _is_dir(self._fs, full_path) and name.startswith("version_"):
171-
existing_versions.append(int(name.split("_")[1]))
171+
dir_ver = name.split("_")[1]
172+
if dir_ver.isdigit():
173+
existing_versions.append(int(dir_ver))
172174

173175
if len(existing_versions) == 0:
174176
return 0

src/lightning/fabric/loggers/tensorboard.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,8 @@ def _get_next_version(self) -> int:
304304
bn = os.path.basename(d)
305305
if _is_dir(self._fs, d) and bn.startswith("version_"):
306306
dir_ver = bn.split("_")[1].replace("/", "")
307-
existing_versions.append(int(dir_ver))
307+
if dir_ver.isdigit():
308+
existing_versions.append(int(dir_ver))
308309
if len(existing_versions) == 0:
309310
return 0
310311

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4141
- Refined the FSDP saving logic and error messaging when path exists ([#18884](https://github.com/Lightning-AI/lightning/pull/18884))
4242

4343

44+
- Fixed an issue parsing the version from folders that don't include a version number in `TensorBoardLogger` and `CSVLogger` ([#18897](https://github.com/Lightning-AI/lightning/issues/18897))
45+
46+
4447
## [2.1.0] - 2023-10-11
4548

4649
### Added

src/lightning/pytorch/loggers/tensorboard.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,8 @@ def _get_next_version(self) -> int:
243243
bn = os.path.basename(d)
244244
if _is_dir(self._fs, d) and bn.startswith("version_"):
245245
dir_ver = bn.split("_")[1].replace("/", "")
246-
existing_versions.append(int(dir_ver))
246+
if dir_ver.isdigit():
247+
existing_versions.append(int(dir_ver))
247248
if len(existing_versions) == 0:
248249
return 0
249250

tests/tests_fabric/loggers/test_csv.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ def test_automatic_versioning(tmp_path):
2424
"""Verify that automatic versioning works."""
2525
(tmp_path / "exp" / "version_0").mkdir(parents=True)
2626
(tmp_path / "exp" / "version_1").mkdir()
27+
(tmp_path / "exp" / "version_nonumber").mkdir()
28+
(tmp_path / "exp" / "other").mkdir()
29+
2730
logger = CSVLogger(root_dir=tmp_path, name="exp")
2831
assert logger.version == 2
2932

@@ -37,43 +40,43 @@ def test_automatic_versioning_relative_root_dir(tmp_path, monkeypatch):
3740
assert logger.version == 2
3841

3942

40-
def test_manual_versioning(tmpdir):
43+
def test_manual_versioning(tmp_path):
4144
"""Verify that manual versioning works."""
42-
root_dir = tmpdir.mkdir("exp")
43-
root_dir.mkdir("version_0")
44-
root_dir.mkdir("version_1")
45-
root_dir.mkdir("version_2")
45+
root_dir = tmp_path / "exp"
46+
(root_dir / "version_0").mkdir(parents=True)
47+
(root_dir / "version_1").mkdir()
48+
(root_dir / "version_2").mkdir()
4649
logger = CSVLogger(root_dir=root_dir, name="exp", version=1)
4750
assert logger.version == 1
4851

4952

50-
def test_named_version(tmpdir):
53+
def test_named_version(tmp_path):
5154
"""Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'."""
5255
exp_name = "exp"
53-
tmpdir.mkdir(exp_name)
56+
(tmp_path / exp_name).mkdir()
5457
expected_version = "2020-02-05-162402"
5558

56-
logger = CSVLogger(root_dir=tmpdir, name=exp_name, version=expected_version)
59+
logger = CSVLogger(root_dir=tmp_path, name=exp_name, version=expected_version)
5760
logger.log_metrics({"a": 1, "b": 2})
5861
logger.save()
5962
assert logger.version == expected_version
60-
assert os.listdir(tmpdir / exp_name) == [expected_version]
61-
assert os.listdir(tmpdir / exp_name / expected_version)
63+
assert os.listdir(tmp_path / exp_name) == [expected_version]
64+
assert os.listdir(tmp_path / exp_name / expected_version)
6265

6366

6467
@pytest.mark.parametrize("name", ["", None])
65-
def test_no_name(tmpdir, name):
68+
def test_no_name(tmp_path, name):
6669
"""Verify that None or empty name works."""
67-
logger = CSVLogger(root_dir=tmpdir, name=name)
70+
logger = CSVLogger(root_dir=tmp_path, name=name)
6871
logger.log_metrics({"a": 1})
6972
logger.save()
70-
assert os.path.normpath(logger._root_dir) == tmpdir # use os.path.normpath to handle trailing /
71-
assert os.listdir(tmpdir / "version_0")
73+
assert os.path.normpath(logger._root_dir) == str(tmp_path) # use os.path.normpath to handle trailing /
74+
assert os.listdir(tmp_path / "version_0")
7275

7376

7477
@pytest.mark.parametrize("step_idx", [10, None])
75-
def test_log_metrics(tmpdir, step_idx):
76-
logger = CSVLogger(tmpdir)
78+
def test_log_metrics(tmp_path, step_idx):
79+
logger = CSVLogger(tmp_path)
7780
metrics = {"float": 0.3, "int": 1, "FloatTensor": torch.tensor(0.1), "IntTensor": torch.tensor(1)}
7881
logger.log_metrics(metrics, step_idx)
7982
logger.save()
@@ -85,14 +88,14 @@ def test_log_metrics(tmpdir, step_idx):
8588
assert all(n in lines[0] for n in metrics)
8689

8790

88-
def test_log_hyperparams(tmpdir):
89-
logger = CSVLogger(tmpdir)
91+
def test_log_hyperparams(tmp_path):
92+
logger = CSVLogger(tmp_path)
9093
with pytest.raises(NotImplementedError):
9194
logger.log_hyperparams({})
9295

9396

94-
def test_flush_n_steps(tmpdir):
95-
logger = CSVLogger(tmpdir, flush_logs_every_n_steps=2)
97+
def test_flush_n_steps(tmp_path):
98+
logger = CSVLogger(tmp_path, flush_logs_every_n_steps=2)
9699
metrics = {"float": 0.3, "int": 1, "FloatTensor": torch.tensor(0.1), "IntTensor": torch.tensor(1)}
97100
logger.save = MagicMock()
98101
logger.log_metrics(metrics, step=0)

tests/tests_fabric/loggers/test_tensorboard.py

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,55 +27,57 @@
2727
from tests_fabric.test_fabric import BoringModel
2828

2929

30-
def test_tensorboard_automatic_versioning(tmpdir):
30+
def test_tensorboard_automatic_versioning(tmp_path):
3131
"""Verify that automatic versioning works."""
32-
root_dir = tmpdir / "tb_versioning"
32+
root_dir = tmp_path / "tb_versioning"
3333
root_dir.mkdir()
3434
(root_dir / "version_0").mkdir()
3535
(root_dir / "version_1").mkdir()
36+
(root_dir / "version_nonumber").mkdir()
37+
(root_dir / "other").mkdir()
3638

37-
logger = TensorBoardLogger(root_dir=tmpdir, name="tb_versioning")
39+
logger = TensorBoardLogger(root_dir=tmp_path, name="tb_versioning")
3840
assert logger.version == 2
3941

4042

41-
def test_tensorboard_manual_versioning(tmpdir):
43+
def test_tensorboard_manual_versioning(tmp_path):
4244
"""Verify that manual versioning works."""
43-
root_dir = tmpdir / "tb_versioning"
45+
root_dir = tmp_path / "tb_versioning"
4446
root_dir.mkdir()
4547
(root_dir / "version_0").mkdir()
4648
(root_dir / "version_1").mkdir()
4749
(root_dir / "version_2").mkdir()
4850

49-
logger = TensorBoardLogger(root_dir=tmpdir, name="tb_versioning", version=1)
51+
logger = TensorBoardLogger(root_dir=tmp_path, name="tb_versioning", version=1)
5052
assert logger.version == 1
5153

5254

53-
def test_tensorboard_named_version(tmpdir):
55+
def test_tensorboard_named_version(tmp_path):
5456
"""Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'."""
5557
name = "tb_versioning"
56-
(tmpdir / name).mkdir()
58+
(tmp_path / name).mkdir()
5759
expected_version = "2020-02-05-162402"
5860

59-
logger = TensorBoardLogger(root_dir=tmpdir, name=name, version=expected_version)
61+
logger = TensorBoardLogger(root_dir=tmp_path, name=name, version=expected_version)
6062
logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written
6163

6264
assert logger.version == expected_version
63-
assert os.listdir(tmpdir / name) == [expected_version]
64-
assert os.listdir(tmpdir / name / expected_version)
65+
assert os.listdir(tmp_path / name) == [expected_version]
66+
assert os.listdir(tmp_path / name / expected_version)
6567

6668

6769
@pytest.mark.parametrize("name", ["", None])
68-
def test_tensorboard_no_name(tmpdir, name):
70+
def test_tensorboard_no_name(tmp_path, name):
6971
"""Verify that None or empty name works."""
70-
logger = TensorBoardLogger(root_dir=tmpdir, name=name)
72+
logger = TensorBoardLogger(root_dir=tmp_path, name=name)
7173
logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written
72-
assert os.path.normpath(logger.root_dir) == tmpdir # use os.path.normpath to handle trailing /
73-
assert os.listdir(tmpdir / "version_0")
74+
assert os.path.normpath(logger.root_dir) == str(tmp_path) # use os.path.normpath to handle trailing /
75+
assert os.listdir(tmp_path / "version_0")
7476

7577

76-
def test_tensorboard_log_sub_dir(tmpdir):
78+
def test_tensorboard_log_sub_dir(tmp_path):
7779
# no sub_dir specified
78-
root_dir = tmpdir / "logs"
80+
root_dir = tmp_path / "logs"
7981
logger = TensorBoardLogger(root_dir, name="name", version="version")
8082
assert logger.log_dir == os.path.join(root_dir, "name", "version")
8183

@@ -104,14 +106,14 @@ def test_tensorboard_expand_env_vars():
104106

105107

106108
@pytest.mark.parametrize("step_idx", [10, None])
107-
def test_tensorboard_log_metrics(tmpdir, step_idx):
108-
logger = TensorBoardLogger(tmpdir)
109+
def test_tensorboard_log_metrics(tmp_path, step_idx):
110+
logger = TensorBoardLogger(tmp_path)
109111
metrics = {"float": 0.3, "int": 1, "FloatTensor": torch.tensor(0.1), "IntTensor": torch.tensor(1)}
110112
logger.log_metrics(metrics, step_idx)
111113

112114

113-
def test_tensorboard_log_hyperparams(tmpdir):
114-
logger = TensorBoardLogger(tmpdir)
115+
def test_tensorboard_log_hyperparams(tmp_path):
116+
logger = TensorBoardLogger(tmp_path)
115117
hparams = {
116118
"float": 0.3,
117119
"int": 1,
@@ -127,8 +129,8 @@ def test_tensorboard_log_hyperparams(tmpdir):
127129
logger.log_hyperparams(hparams)
128130

129131

130-
def test_tensorboard_log_hparams_and_metrics(tmpdir):
131-
logger = TensorBoardLogger(tmpdir, default_hp_metric=False)
132+
def test_tensorboard_log_hparams_and_metrics(tmp_path):
133+
logger = TensorBoardLogger(tmp_path, default_hp_metric=False)
132134
hparams = {
133135
"float": 0.3,
134136
"int": 1,
@@ -146,15 +148,15 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
146148

147149

148150
@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
149-
def test_tensorboard_log_graph(tmpdir, example_input_array):
151+
def test_tensorboard_log_graph(tmp_path, example_input_array):
150152
"""Test that log graph works with both model.example_input_array and if array is passed externally."""
151153
# TODO(fabric): Test both nn.Module and LightningModule
152154
# TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks
153155
model = BoringModel()
154156
if example_input_array is not None:
155157
model.example_input_array = None
156158

157-
logger = TensorBoardLogger(tmpdir)
159+
logger = TensorBoardLogger(tmp_path)
158160
logger._experiment = Mock()
159161
logger.log_graph(model, example_input_array)
160162
if example_input_array is not None:
@@ -169,11 +171,11 @@ def test_tensorboard_log_graph(tmpdir, example_input_array):
169171

170172

171173
@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))
172-
def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
174+
def test_tensorboard_log_graph_warning_no_example_input_array(tmp_path):
173175
"""Test that log graph throws warning if model.example_input_array is None."""
174176
model = BoringModel()
175177
model.example_input_array = None
176-
logger = TensorBoardLogger(tmpdir, log_graph=True)
178+
logger = TensorBoardLogger(tmp_path, log_graph=True)
177179
with pytest.warns(
178180
UserWarning,
179181
match="Could not log computational graph to TensorBoard: The `model.example_input_array` .* was not given",
@@ -187,22 +189,22 @@ def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
187189
logger.log_graph(model)
188190

189191

190-
def test_tensorboard_finalize(monkeypatch, tmpdir):
192+
def test_tensorboard_finalize(monkeypatch, tmp_path):
191193
"""Test that the SummaryWriter closes in finalize."""
192194
if _TENSORBOARD_AVAILABLE:
193195
import torch.utils.tensorboard as tb
194196
else:
195197
import tensorboardX as tb
196198

197199
monkeypatch.setattr(tb, "SummaryWriter", Mock())
198-
logger = TensorBoardLogger(root_dir=tmpdir)
200+
logger = TensorBoardLogger(root_dir=tmp_path)
199201
assert logger._experiment is None
200202
logger.finalize("any")
201203

202204
# no log calls, no experiment created -> nothing to flush
203205
logger.experiment.assert_not_called()
204206

205-
logger = TensorBoardLogger(root_dir=tmpdir)
207+
logger = TensorBoardLogger(root_dir=tmp_path)
206208
logger.log_metrics({"flush_me": 11.1}) # trigger creation of an experiment
207209
logger.finalize("any")
208210

@@ -212,10 +214,10 @@ def test_tensorboard_finalize(monkeypatch, tmpdir):
212214

213215

214216
@mock.patch("lightning.fabric.loggers.tensorboard.log")
215-
def test_tensorboard_with_symlink(log, tmpdir):
217+
def test_tensorboard_with_symlink(log, tmp_path):
216218
"""Tests a specific failure case when tensorboard logger is used with empty name, symbolic link ``save_dir``, and
217219
relative paths."""
218-
os.chdir(tmpdir) # need to use relative paths
220+
os.chdir(tmp_path) # need to use relative paths
219221
source = os.path.join(".", "lightning_logs")
220222
dest = os.path.join(".", "sym_lightning_logs")
221223

@@ -228,10 +230,10 @@ def test_tensorboard_with_symlink(log, tmpdir):
228230
log.warning.assert_not_called()
229231

230232

231-
def test_tensorboard_missing_folder_warning(tmpdir, caplog):
233+
def test_tensorboard_missing_folder_warning(tmp_path, caplog):
232234
"""Verify that the logger throws a warning for invalid directory."""
233235
name = "fake_dir"
234-
logger = TensorBoardLogger(root_dir=tmpdir, name=name)
236+
logger = TensorBoardLogger(root_dir=tmp_path, name=name)
235237

236238
with caplog.at_level(logging.WARNING):
237239
assert logger.version == 0

0 commit comments

Comments
 (0)