Skip to content

Commit 67b888f

Browse files
committed
add reverse impl.
1 parent c309642 commit 67b888f

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

src/lightning/pytorch/loggers/utilities.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
# limitations under the License.
1414
"""Utilities for loggers."""
1515

16-
from collections.abc import ItemsView, KeysView, Mapping, ValuesView
16+
from collections.abc import ItemsView, Iterable, KeysView, Mapping, ValuesView
1717
from pathlib import Path
18-
from typing import Any, Optional, TypeVar, Union
18+
from typing import Any, Optional, SupportsIndex, TypeVar, Union
1919

2020
from torch import Tensor
2121
from typing_extensions import Self
@@ -110,7 +110,7 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None:
110110
class _ListMap(list[_T]):
111111
"""A hybrid container for loggers allowing both index and name access."""
112112

113-
def __init__(self, loggers: Union[list[_T], Mapping[str, _T]] = None):
113+
def __init__(self, loggers: Union[Iterable[_T], Mapping[str, _T]] = None):
114114
if isinstance(loggers, Mapping):
115115
# super inits list with values
116116
if any(not isinstance(x, str) for x in loggers):
@@ -145,7 +145,7 @@ def __iadd__(self, other: Union[list[_T], Self]) -> Self:
145145
# todo
146146
return list.__iadd__(self, other)
147147

148-
def __setitem__(self, key: Union[int, slice, str], value: _T) -> None:
148+
def __setitem__(self, key: Union[SupportsIndex, slice, str], value: _T) -> None:
149149
if isinstance(key, (int, slice)):
150150
# replace element by index
151151
return list.__setitem__(self, key, value)
@@ -192,7 +192,7 @@ def items(self) -> ItemsView[str, _T]:
192192
return d.items()
193193

194194
# --- List and Dict interface ---
195-
def pop(self, key: Union[int, str] = -1, default: Optional[Any] = None) -> _T:
195+
def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[Any] = None) -> _T:
196196
if isinstance(key, int):
197197
ret = list.pop(self, key)
198198
for str_key, idx in list(self._dict.items()):
@@ -210,3 +210,15 @@ def pop(self, key: Union[int, str] = -1, default: Optional[Any] = None) -> _T:
210210
def __repr__(self) -> str:
211211
ret = super().__repr__()
212212
return f"_ListMap({ret}, keys={list(self._dict.keys())})"
213+
214+
def __reversed__(self) -> Iterable[_T]:
215+
return reversed(list(self))
216+
217+
def reverse(self) -> None:
218+
for key, idx in self._dict.items():
219+
self._dict[key] = len(self) - 1 - idx
220+
list.reverse(self)
221+
222+
def clear(self):
223+
self._dict.clear()
224+
list.clear(self)

tests/tests_pytorch/loggers/test_utilities.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,30 @@ def test_listmap_remove():
122122
assert 2 not in lm
123123

124124

125+
def test_listmap_reverse():
126+
"""Test reversing the collection."""
127+
lm = _ListMap({"1": 1, "2": 2, "3": 3})
128+
lm.reverse()
129+
assert lm == [3, 2, 1]
130+
for (key, value), expected in zip(lm.items(), [("1", 1), ("2", 2), ("3", 3)]):
131+
assert (key, value) == expected
132+
133+
134+
def test_listmap_reversed():
135+
"""Test reversed iterator of the collection."""
136+
lm = _ListMap({"1": 1, "2": 2, "3": 3})
137+
rev_lm = list(reversed(lm))
138+
assert rev_lm == [3, 2, 1]
139+
140+
141+
def test_listmap_clear():
142+
"""Test clearing the collection."""
143+
lm = _ListMap({"1": 1, "2": 2, "3": 3})
144+
lm.clear()
145+
assert len(lm) == 0
146+
assert len(lm.keys()) == 0
147+
148+
125149
# Dict type properties tests
126150
def test_listmap_keys():
127151
lm = _ListMap({

0 commit comments

Comments
 (0)