Skip to content

Commit 69a2ef3

Browse files
committed
Merge remote-tracking branch 'origin/feat/logger_dict' into feat/logger_dict
2 parents 41f4311 + ec39fe5 commit 69a2ef3

File tree

3 files changed

+45
-15
lines changed

3 files changed

+45
-15
lines changed

src/lightning/pytorch/loggers/utilities.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import TYPE_CHECKING, Any, Callable, Optional, SupportsIndex, TypeVar, Union
1919

2020
from torch import Tensor
21-
from typing_extensions import Self
21+
from typing_extensions import Self, overload
2222

2323
import lightning.pytorch as pl
2424
from lightning.pytorch.callbacks import Checkpoint
@@ -108,6 +108,7 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None:
108108

109109

110110
_T = TypeVar("_T")
111+
_PT = TypeVar("_PT")
111112

112113

113114
class _ListMap(list[_T]):
@@ -139,27 +140,28 @@ class _ListMap(list[_T]):
139140
140141
"""
141142

142-
def __init__(self, __iterable: Union[Mapping[str, _T], Iterable[_T]] = None):
143+
def __init__(self, __iterable: Optional[Union[Mapping[str, _T], Iterable[_T]]] = None):
143144
if isinstance(__iterable, Mapping):
144145
# super inits list with values
145146
if any(not isinstance(x, str) for x in __iterable):
146147
raise TypeError("When providing a Mapping, all keys must be of type str.")
147148
super().__init__(__iterable.values())
148-
self._dict = dict(zip(__iterable.keys(), range(len(__iterable))))
149+
_dict = dict(zip(__iterable.keys(), range(len(__iterable))))
149150
else:
150151
default_dict = {}
151152
if isinstance(__iterable, _ListMap):
152153
default_dict = __iterable._dict.copy()
153154
super().__init__(() if __iterable is None else __iterable)
154-
self._dict: dict = default_dict
155+
_dict: dict = default_dict
156+
self._dict = _dict
155157

156158
def __eq__(self, other: Any) -> bool:
157159
list_eq = list.__eq__(self, other)
158160
if isinstance(other, _ListMap):
159161
return list_eq and self._dict == other._dict
160162
return list_eq
161163

162-
def copy(self):
164+
def copy(self) -> Self:
163165
new_listmap = _ListMap(self)
164166
new_listmap._dict = self._dict.copy()
165167
return new_listmap
@@ -171,7 +173,16 @@ def extend(self, __iterable: Iterable[_T]) -> None:
171173
self._dict[key] = idx + offset
172174
super().extend(__iterable)
173175

174-
def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[Any] = None) -> _T:
176+
@overload
177+
def pop(self, key: SupportsIndex = -1, /) -> _T: ...
178+
179+
@overload
180+
def pop(self, key: str, /, default: _T) -> _T: ...
181+
182+
@overload
183+
def pop(self, key: str, default: _PT, /) -> Union[_T, _PT]: ...
184+
185+
def pop(self, key=-1, default=None):
175186
if isinstance(key, int):
176187
ret = list.pop(self, key)
177188
for str_key, idx in list(self._dict.items()):
@@ -211,7 +222,7 @@ def sort(
211222
reverse: bool = False,
212223
) -> None:
213224
# Create a mapping from item to its name(s)
214-
item_to_names = {}
225+
item_to_names: dict[_T, list[int]] = {}
215226
for name, idx in self._dict.items():
216227
item = self[idx]
217228
item_to_names.setdefault(item, []).append(name)
@@ -225,8 +236,13 @@ def sort(
225236
new_dict[name] = idx
226237
self._dict = new_dict
227238

228-
# --- List-like interface ---
229-
def __getitem__(self, key: Union[int, slice, str]) -> _T:
239+
@overload
240+
def __getitem__(self, key: Union[SupportsIndex, str], /) -> _T: ...
241+
242+
@overload
243+
def __getitem__(self, key: slice, /) -> list[_T]: ...
244+
245+
def __getitem__(self, key, /):
230246
if isinstance(key, str):
231247
return self[self._dict[key]]
232248
return list.__getitem__(self, key)
@@ -245,7 +261,13 @@ def __iadd__(self, other: Union[list[_T], Self]) -> Self:
245261

246262
return super().__iadd__(other)
247263

248-
def __setitem__(self, key: Union[SupportsIndex, slice, str], value: _T) -> None:
264+
@overload
265+
def __setitem__(self, key: Union[SupportsIndex, str], value: _T, /) -> None: ...
266+
267+
@overload
268+
def __setitem__(self, key: slice, value: Iterable[_T], /) -> None: ...
269+
270+
def __setitem__(self, key, value, /) -> None:
249271
if isinstance(key, (int, slice)):
250272
# replace element by index
251273
return list.__setitem__(self, key, value)
@@ -259,14 +281,14 @@ def __setitem__(self, key: Union[SupportsIndex, slice, str], value: _T) -> None:
259281
return None
260282
raise TypeError("Key must be int or str")
261283

262-
def __contains__(self, item: Union[_T, str]) -> bool:
284+
def __contains__(self, item: Union[object, str]) -> bool:
263285
if isinstance(item, str):
264286
return item in self._dict
265287
return list.__contains__(self, item)
266288

267289
# --- Dict-like interface ---
268290

269-
def __delitem__(self, key: Union[int, slice, str]) -> None:
291+
def __delitem__(self, key: Union[SupportsIndex, slice, str]) -> None:
270292
if isinstance(key, (int, slice)):
271293
list.__delitem__(self, key)
272294
for _key in key.indices(len(self)) if isinstance(key, slice) else [key]:
@@ -294,7 +316,13 @@ def items(self) -> ItemsView[str, _T]:
294316
d = {k: self[v] for k, v in self._dict.items()}
295317
return d.items()
296318

297-
def get(self, __key: str, default: Optional[Any] = None) -> _T:
319+
@overload
320+
def get(self, __key: str) -> Optional[_T]: ...
321+
322+
@overload
323+
def get(self, __key: str, default: _PT) -> Union[_T, _PT]: ...
324+
325+
def get(self, __key: str, default=None):
298326
if __key in self._dict:
299327
return self[self._dict[__key]]
300328
return default
@@ -308,6 +336,6 @@ def reverse(self) -> None:
308336
self._dict[key] = len(self) - 1 - idx
309337
list.reverse(self)
310338

311-
def clear(self):
339+
def clear(self) -> None:
312340
self._dict.clear()
313341
list.clear(self)

src/lightning/pytorch/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1643,7 +1643,7 @@ def loggers(self) -> _ListMap[Logger]:
16431643
return self._loggers
16441644

16451645
@loggers.setter
1646-
def loggers(self, loggers: Optional[Union[list[Logger], Mapping[str, Logger]]]) -> None:
1646+
def loggers(self, loggers: Optional[Union[list[Logger], Mapping[str, Logger], _ListMap[Logger]]]) -> None:
16471647
self._loggers = _ListMap(loggers)
16481648

16491649
@property

tests/tests_pytorch/loggers/test_utilities.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def test_listmap_getitem():
8787
lm = _ListMap([1, 2])
8888
assert lm[0] == 1
8989
assert lm[1] == 2
90+
assert lm[-1] == 2
91+
assert lm[0:2] == [1, 2]
9092

9193

9294
def test_listmap_setitem():

0 commit comments

Comments
 (0)