Skip to content

Commit 172ceb3

Browse files
committed
fix: fix mypy
1 parent c371b20 commit 172ceb3

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

src/lightning/pytorch/loggers/utilities.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import TYPE_CHECKING, Any, Callable, Optional, SupportsIndex, TypeVar, Union
1919

2020
from torch import Tensor
21-
from typing_extensions import Self
21+
from typing_extensions import Self, overload
2222

2323
import lightning.pytorch as pl
2424
from lightning.pytorch.callbacks import Checkpoint
@@ -108,6 +108,7 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None:
108108

109109

110110
_T = TypeVar("_T")
111+
_PT = TypeVar("_PT")
111112

112113

113114
class _ListMap(list[_T]):
@@ -139,19 +140,20 @@ class _ListMap(list[_T]):
139140
140141
"""
141142

142-
def __init__(self, __iterable: Union[Mapping[str, _T], Iterable[_T]] = None):
143+
def __init__(self, __iterable: Optional[Union[Mapping[str, _T], Iterable[_T]]] = None):
143144
if isinstance(__iterable, Mapping):
144145
# super inits list with values
145146
if any(not isinstance(x, str) for x in __iterable):
146147
raise TypeError("When providing a Mapping, all keys must be of type str.")
147148
super().__init__(__iterable.values())
148-
self._dict = dict(zip(__iterable.keys(), range(len(__iterable))))
149+
_dict = dict(zip(__iterable.keys(), range(len(__iterable))))
149150
else:
150151
default_dict = {}
151152
if isinstance(__iterable, _ListMap):
152153
default_dict = __iterable._dict.copy()
153154
super().__init__(() if __iterable is None else __iterable)
154-
self._dict: dict = default_dict
155+
_dict: dict = default_dict
156+
self._dict = _dict
155157

156158
def __eq__(self, other: Any) -> bool:
157159
list_eq = list.__eq__(self, other)
@@ -171,7 +173,7 @@ def extend(self, __iterable: Iterable[_T]) -> None:
171173
self._dict[key] = idx + offset
172174
super().extend(__iterable)
173175

174-
def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[Any] = None) -> _T:
176+
def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[_PT] = None) -> Union[_T, _PT]:
175177
if isinstance(key, int):
176178
ret = list.pop(self, key)
177179
for str_key, idx in list(self._dict.items()):
@@ -211,7 +213,7 @@ def sort(
211213
reverse: bool = False,
212214
) -> None:
213215
# Create a mapping from item to its name(s)
214-
item_to_names = {}
216+
item_to_names: dict[_T, list[int]] = {}
215217
for name, idx in self._dict.items():
216218
item = self[idx]
217219
item_to_names.setdefault(item, []).append(name)
@@ -225,8 +227,13 @@ def sort(
225227
new_dict[name] = idx
226228
self._dict = new_dict
227229

228-
# --- List-like interface ---
229-
def __getitem__(self, key: Union[int, slice, str]) -> _T:
230+
@overload
231+
def __getitem__(self, key: Union[SupportsIndex, str], /) -> _T: ...
232+
233+
@overload
234+
def __getitem__(self, key: slice, /) -> list[_T]: ...
235+
236+
def __getitem__(self, key, /):
230237
if isinstance(key, str):
231238
return self[self._dict[key]]
232239
return list.__getitem__(self, key)
@@ -245,7 +252,13 @@ def __iadd__(self, other: Union[list[_T], Self]) -> Self:
245252

246253
return super().__iadd__(other)
247254

248-
def __setitem__(self, key: Union[SupportsIndex, slice, str], value: _T) -> None:
255+
@overload
256+
def __setitem__(self, key: Union[SupportsIndex, str], value: _T, /) -> None: ...
257+
258+
@overload
259+
def __setitem__(self, key: slice, value: Iterable[_T], /) -> None: ...
260+
261+
def __setitem__(self, key, value, /) -> None:
249262
if isinstance(key, (int, slice)):
250263
# replace element by index
251264
return list.__setitem__(self, key, value)
@@ -259,7 +272,7 @@ def __setitem__(self, key: Union[SupportsIndex, slice, str], value: _T) -> None:
259272
return None
260273
raise TypeError("Key must be int or str")
261274

262-
def __contains__(self, item: Union[_T, str]) -> bool:
275+
def __contains__(self, item: Union[object, str]) -> bool:
263276
if isinstance(item, str):
264277
return item in self._dict
265278
return list.__contains__(self, item)

src/lightning/pytorch/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1643,7 +1643,7 @@ def loggers(self) -> _ListMap[Logger]:
16431643
return self._loggers
16441644

16451645
@loggers.setter
1646-
def loggers(self, loggers: Optional[Union[list[Logger], Mapping[str, Logger]]]) -> None:
1646+
def loggers(self, loggers: Optional[Union[list[Logger], Mapping[str, Logger], _ListMap[Logger]]]) -> None:
16471647
self._loggers = _ListMap(loggers)
16481648

16491649
@property

0 commit comments

Comments
 (0)