Skip to content

Commit 9d0d39d

Browse files
committed
fix: mypy
1 parent a2709c2 commit 9d0d39d

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

src/lightning/pytorch/loggers/utilities.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ class _ListMap(list[_T]):
141141
"""
142142

143143
def __init__(self, __iterable: Optional[Union[Mapping[str, _T], Iterable[_T]]] = None):
144+
_dict: dict[str, int]
144145
if isinstance(__iterable, Mapping):
145146
# super inits list with values
146147
if any(not isinstance(x, str) for x in __iterable):
@@ -177,7 +178,7 @@ def extend(self, __iterable: Iterable[_T]) -> None:
177178
def pop(self, key: SupportsIndex = -1, /) -> _T: ...
178179

179180
@overload
180-
def pop(self, key: str, /, default: _T) -> _T: ...
181+
def pop(self, key: Union[str, SupportsIndex], default: _T, /) -> _T: ...
181182

182183
@overload
183184
def pop(self, key: str, default: _PT, /) -> Union[_T, _PT]: ...
@@ -222,14 +223,14 @@ def sort(
222223
reverse: bool = False,
223224
) -> None:
224225
# Create a mapping from item to its name(s)
225-
item_to_names: dict[_T, list[int]] = {}
226+
item_to_names: dict[_T, list[str]] = {}
226227
for name, idx in self._dict.items():
227228
item = self[idx]
228229
item_to_names.setdefault(item, []).append(name)
229230
# Sort the list
230231
super().sort(key=key, reverse=reverse)
231232
# Update _dict with new indices
232-
new_dict = {}
233+
new_dict: dict[str, int] = {}
233234
for idx, item in enumerate(self):
234235
if item in item_to_names:
235236
for name in item_to_names[item]:
@@ -242,12 +243,12 @@ def __getitem__(self, key: Union[SupportsIndex, str], /) -> _T: ...
242243
@overload
243244
def __getitem__(self, key: slice, /) -> list[_T]: ...
244245

245-
def __getitem__(self, key, /):
246+
def __getitem__(self, key):
246247
if isinstance(key, str):
247248
return self[self._dict[key]]
248249
return list.__getitem__(self, key)
249250

250-
def __add__(self, other: Union[list[_T], Self]) -> Self:
251+
def __add__(self, other: Union[list[_T], "_ListMap[_T]"]) -> "_ListMap[_T]":
251252
new_listmap = self.copy()
252253
new_listmap += other
253254
return new_listmap
@@ -267,7 +268,7 @@ def __setitem__(self, key: Union[SupportsIndex, str], value: _T, /) -> None: ...
267268
@overload
268269
def __setitem__(self, key: slice, value: Iterable[_T], /) -> None: ...
269270

270-
def __setitem__(self, key, value, /) -> None:
271+
def __setitem__(self, key, value):
271272
if isinstance(key, (int, slice)):
272273
# replace element by index
273274
return super().__setitem__(key, value)
@@ -289,14 +290,17 @@ def __contains__(self, item: Union[object, str]) -> bool:
289290
# --- Dict-like interface ---
290291

291292
def __delitem__(self, key: Union[SupportsIndex, slice, str]) -> None:
293+
index: Union[SupportsIndex, slice]
292294
if isinstance(key, str):
293295
if key not in self._dict:
294296
raise KeyError(f"Key '{key}' not found.")
295-
key: int = self._dict[key]
297+
index = self._dict[key]
298+
else:
299+
index = key
296300

297-
if isinstance(key, (int, slice)):
298-
super().__delitem__(key)
299-
for _key in key.indices(len(self)) if isinstance(key, slice) else [key]:
301+
if isinstance(index, (int, slice)):
302+
super().__delitem__(index)
303+
for _key in index.indices(len(self)) if isinstance(index, slice) else [index]:
300304
# update indices in the dict
301305
for str_key, idx in list(self._dict.items()):
302306
if idx == _key:
@@ -310,20 +314,18 @@ def keys(self) -> KeysView[str]:
310314
return self._dict.keys()
311315

312316
def values(self) -> ValuesView[_T]:
313-
d = {k: self[v] for k, v in self._dict.items()}
314-
return d.values()
317+
return {k: self[v] for k, v in self._dict.items()}.values()
315318

316319
def items(self) -> ItemsView[str, _T]:
317-
d = {k: self[v] for k, v in self._dict.items()}
318-
return d.items()
320+
return {k: self[v] for k, v in self._dict.items()}.items()
319321

320322
@overload
321323
def get(self, __key: str) -> Optional[_T]: ...
322324

323325
@overload
324326
def get(self, __key: str, default: _PT) -> Union[_T, _PT]: ...
325327

326-
def get(self, __key: str, default=None):
328+
def get(self, __key, default=None):
327329
if __key in self._dict:
328330
return self[self._dict[__key]]
329331
return default

0 commit comments

Comments
 (0)