Skip to content

Commit b3a3a70

Browse files
committed
Merge remote-tracking branch 'origin/feat/logger_dict' into feat/logger_dict
2 parents e5a38ed + 67b888f commit b3a3a70

File tree

4 files changed

+62
-34
lines changed

4 files changed

+62
-34
lines changed

src/lightning/pytorch/loggers/utilities.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# limitations under the License.
1414
"""Utilities for loggers."""
1515

16-
from collections.abc import Mapping
16+
from collections.abc import ItemsView, Iterable, KeysView, Mapping, ValuesView
1717
from pathlib import Path
18-
from typing import Any, Optional, Self, TypeVar, Union
18+
from typing import Any, Optional, SupportsIndex, TypeVar, Union
1919

2020
from torch import Tensor
21+
from typing_extensions import Self
2122

2223
import lightning.pytorch as pl
2324
from lightning.pytorch.callbacks import Checkpoint
@@ -109,7 +110,7 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None:
109110
class _ListMap(list[_T]):
110111
"""A hybrid container for loggers allowing both index and name access."""
111112

112-
def __init__(self, loggers: Union[list[_T], Mapping[str, _T]] = None):
113+
def __init__(self, loggers: Union[Iterable[_T], Mapping[str, _T]] = None):
113114
if isinstance(loggers, Mapping):
114115
# super inits list with values
115116
if any(not isinstance(x, str) for x in loggers):
@@ -120,7 +121,7 @@ def __init__(self, loggers: Union[list[_T], Mapping[str, _T]] = None):
120121
super().__init__(() if loggers is None else loggers)
121122
self._dict: dict = {}
122123

123-
def __eq__(self, other):
124+
def __eq__(self, other: Any) -> bool:
124125
self_list = list(self)
125126
if isinstance(other, _ListMap):
126127
return self_list == list(other) and self._dict == other._dict
@@ -144,7 +145,7 @@ def __iadd__(self, other: Union[list[_T], Self]) -> Self:
144145
# todo
145146
return list.__iadd__(self, other)
146147

147-
def __setitem__(self, key, value):
148+
def __setitem__(self, key: Union[SupportsIndex, slice, str], value: _T) -> None:
148149
if isinstance(key, (int, slice)):
149150
# replace element by index
150151
return list.__setitem__(self, key, value)
@@ -158,14 +159,14 @@ def __setitem__(self, key, value):
158159
return None
159160
raise TypeError("Key must be int or str")
160161

161-
def __contains__(self, item):
162+
def __contains__(self, item: Union[_T, str]) -> bool:
162163
if isinstance(item, str):
163164
return item in self._dict
164165
return list.__contains__(self, item)
165166

166167
# --- Dict-like interface ---
167168

168-
def __delitem__(self, key):
169+
def __delitem__(self, key: Union[int, slice, str]) -> None:
169170
if isinstance(key, (int, slice)):
170171
loggers = list.__getitem__(self, key)
171172
super(list, self).__delitem__(key)
@@ -179,19 +180,19 @@ def __delitem__(self, key):
179180
else:
180181
raise TypeError("Key must be int or str")
181182

182-
def keys(self):
183+
def keys(self) -> KeysView[str]:
183184
return self._dict.keys()
184185

185-
def values(self):
186+
def values(self) -> ValuesView[_T]:
186187
d = {k: self[v] for k, v in self._dict.items()}
187188
return d.values()
188189

189-
def items(self):
190+
def items(self) -> ItemsView[str, _T]:
190191
d = {k: self[v] for k, v in self._dict.items()}
191192
return d.items()
192193

193194
# --- List and Dict interface ---
194-
def pop(self, key: Union[int, str] = -1, default: Optional[Any] = None) -> _T:
195+
def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[Any] = None) -> _T:
195196
if isinstance(key, int):
196197
ret = list.pop(self, key)
197198
for str_key, idx in list(self._dict.items()):
@@ -206,6 +207,18 @@ def pop(self, key: Union[int, str] = -1, default: Optional[Any] = None) -> _T:
206207
return self.pop(self._dict[key])
207208
raise TypeError("Key must be int or str")
208209

209-
def __repr__(self):
210+
def __repr__(self) -> str:
210211
ret = super().__repr__()
211212
return f"_ListMap({ret}, keys={list(self._dict.keys())})"
213+
214+
def __reversed__(self) -> Iterable[_T]:
215+
return reversed(list(self))
216+
217+
def reverse(self) -> None:
218+
for key, idx in self._dict.items():
219+
self._dict[key] = len(self) - 1 - idx
220+
list.reverse(self)
221+
222+
def clear(self):
223+
self._dict.clear()
224+
list.clear(self)

src/lightning/pytorch/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def __init__(
494494
setup._init_profiler(self, profiler)
495495

496496
# init logger flags
497-
self._loggers: list[Logger]
497+
self._loggers: _ListMap[Logger]
498498
self._logger_connector.on_trainer_init(logger, log_every_n_steps)
499499

500500
# init debugging flags

tests/tests_pytorch/loggers/test_utilities.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,30 @@ def test_listmap_remove():
122122
assert 2 not in lm
123123

124124

125+
def test_listmap_reverse():
126+
"""Test reversing the collection."""
127+
lm = _ListMap({"1": 1, "2": 2, "3": 3})
128+
lm.reverse()
129+
assert lm == [3, 2, 1]
130+
for (key, value), expected in zip(lm.items(), [("1", 1), ("2", 2), ("3", 3)]):
131+
assert (key, value) == expected
132+
133+
134+
def test_listmap_reversed():
135+
"""Test reversed iterator of the collection."""
136+
lm = _ListMap({"1": 1, "2": 2, "3": 3})
137+
rev_lm = list(reversed(lm))
138+
assert rev_lm == [3, 2, 1]
139+
140+
141+
def test_listmap_clear():
142+
"""Test clearing the collection."""
143+
lm = _ListMap({"1": 1, "2": 2, "3": 3})
144+
lm.clear()
145+
assert len(lm) == 0
146+
assert len(lm.keys()) == 0
147+
148+
125149
# Dict type properties tests
126150
def test_listmap_keys():
127151
lm = _ListMap({

tests/tests_pytorch/trainer/properties/test_loggers.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,21 @@ def test_trainer_loggers_property():
2323
"""Test for correct initialization of loggers in Trainer."""
2424
logger1 = CustomLogger()
2525
logger2 = CustomLogger()
26-
logger3 = CustomLogger()
26+
CustomLogger()
2727

2828
# trainer.loggers should be a copy of the input list
2929
trainer = Trainer(logger=[logger1, logger2])
3030

3131
assert trainer.loggers == [logger1, logger2]
32-
assert trainer.logger_map == {}
3332

3433
# trainer.loggers should create a list of size 1
3534
trainer = Trainer(logger=logger1)
3635

3736
assert trainer.logger == logger1
3837
assert trainer.loggers == [logger1]
39-
assert trainer.logger_map == {}
4038

4139
trainer.loggers.append(logger2)
4240
assert trainer.loggers == [logger1, logger2]
43-
assert trainer.logger_map == {}
4441

4542
# trainer.loggers should be a list of size 1 holding the default logger
4643
trainer = Trainer(logger=True)
@@ -51,19 +48,15 @@ def test_trainer_loggers_property():
5148
trainer = Trainer(logger={"log1": logger1, "log2": logger2})
5249
assert trainer.logger == logger1
5350
assert trainer.loggers == [logger1, logger2]
54-
assert isinstance(trainer.logger_map, dict)
55-
assert trainer.logger_map == {"log1": logger1, "log2": logger2}
56-
57-
trainer.loggers.append(logger3)
58-
assert trainer.loggers == [logger1, logger2, logger3]
59-
assert trainer.logger_map == {"log1": logger1, "log2": logger2}
51+
assert trainer.loggers["log1"] is logger1
52+
assert trainer.loggers["log2"] is logger2
6053

6154

6255
def test_trainer_loggers_setters():
6356
"""Test the behavior of setters for trainer.logger and trainer.loggers."""
6457
logger1 = CustomLogger()
6558
logger2 = CustomLogger()
66-
logger3 = CustomLogger()
59+
CustomLogger()
6760

6861
trainer = Trainer()
6962
assert type(trainer.logger) is TensorBoardLogger
@@ -78,46 +71,45 @@ def test_trainer_loggers_setters():
7871
assert trainer.logger is None
7972
assert trainer.loggers == []
8073
assert isinstance(trainer.loggers, list)
81-
assert trainer.logger_map == {}
8274

8375
# Test setters for trainer.loggers
8476
trainer.loggers = [logger1, logger2]
8577
assert trainer.loggers == [logger1, logger2]
8678
assert isinstance(trainer.loggers, list)
87-
assert trainer.logger_map == {}
8879

8980
trainer.loggers = [logger1]
9081
assert trainer.loggers == [logger1]
9182
assert trainer.logger == logger1
92-
assert trainer.logger_map == {}
9383

9484
trainer.loggers = []
9585
assert trainer.loggers == []
9686
assert trainer.logger is None
9787
assert isinstance(trainer.loggers, list)
98-
assert trainer.logger_map == {}
9988

10089
trainer.loggers = None
10190
assert trainer.loggers == []
10291
assert trainer.logger is None
10392
assert isinstance(trainer.loggers, list)
104-
assert trainer.logger_map == {}
93+
94+
trainer.loggers = {}
95+
assert trainer.loggers == []
96+
assert trainer.logger is None
97+
assert isinstance(trainer.loggers, list)
10598

10699
trainer.loggers = {"log1": logger1, "log2": logger2}
107100
assert trainer.loggers == [logger1, logger2]
108101
assert isinstance(trainer.loggers, list)
109-
assert isinstance(trainer.logger_map, dict)
110-
assert trainer.logger_map == {"log1": logger1, "log2": logger2}
111102

112-
trainer.loggers.append(logger3)
113-
assert trainer.logger_map == {"log1": logger1, "log2": logger2}
103+
assert trainer.loggers["log1"] is logger1
104+
assert trainer.loggers["log2"] is logger2
114105

115106

116107
@pytest.mark.parametrize(
117108
"logger_value",
118109
[
119110
False,
120111
[],
112+
{},
121113
],
122114
)
123115
def test_no_logger(tmp_path, logger_value):
@@ -130,4 +122,3 @@ def test_no_logger(tmp_path, logger_value):
130122
assert trainer.logger is None
131123
assert trainer.loggers == []
132124
assert trainer.log_dir == str(tmp_path)
133-
assert trainer.logger_map == {}

0 commit comments

Comments
 (0)