Skip to content

Commit 3e9e398

Browse files
committed
implement list methods.
1 parent b3a3a70 commit 3e9e398

File tree

2 files changed

+128
-40
lines changed

2 files changed

+128
-40
lines changed

src/lightning/pytorch/loggers/utilities.py

Lines changed: 86 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@
1515

1616
from collections.abc import ItemsView, Iterable, KeysView, Mapping, ValuesView
1717
from pathlib import Path
18-
from typing import Any, Optional, SupportsIndex, TypeVar, Union
18+
from typing import TYPE_CHECKING, Any, Callable, Optional, SupportsIndex, TypeVar, Union
1919

2020
from torch import Tensor
2121
from typing_extensions import Self
2222

2323
import lightning.pytorch as pl
2424
from lightning.pytorch.callbacks import Checkpoint
2525

26+
if TYPE_CHECKING:
27+
from _typeshed import SupportsRichComparison
28+
2629

2730
def _version(loggers: list[Any], separator: str = "_") -> Union[int, str]:
2831
if len(loggers) == 1:
@@ -122,28 +125,96 @@ def __init__(self, loggers: Union[Iterable[_T], Mapping[str, _T]] = None):
122125
self._dict: dict = {}
123126

124127
def __eq__(self, other: Any) -> bool:
125-
self_list = list(self)
128+
list_eq = list.__eq__(self, other)
126129
if isinstance(other, _ListMap):
127-
return self_list == list(other) and self._dict == other._dict
128-
if isinstance(other, list):
129-
return self_list == other
130-
return False
130+
dict_eq = self._dict == other._dict
131+
return list_eq and dict_eq
132+
return list_eq
133+
134+
def copy(self):
135+
new_listmap = _ListMap(self)
136+
new_listmap._dict = self._dict.copy()
137+
return new_listmap
138+
139+
def extend(self, __iterable: Iterable[_T]) -> None:
140+
if isinstance(__iterable, _ListMap):
141+
offset = len(self)
142+
for key, idx in __iterable._dict.items():
143+
self._dict[key] = idx + offset
144+
super().extend(__iterable)
145+
146+
def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[Any] = None) -> _T:
147+
if isinstance(key, int):
148+
ret = list.pop(self, key)
149+
for str_key, idx in list(self._dict.items()):
150+
if idx == key:
151+
self._dict.pop(str_key)
152+
elif idx > key:
153+
self._dict[str_key] = idx - 1
154+
return ret
155+
if isinstance(key, str):
156+
if key not in self._dict:
157+
return default
158+
return self.pop(self._dict[key])
159+
raise TypeError("Key must be int or str")
160+
161+
def insert(self, index: SupportsIndex, __object: _T) -> None:
162+
for key, idx in self._dict.items():
163+
if idx >= index:
164+
self._dict[key] = idx + 1
165+
list.insert(self, index, __object)
166+
167+
def remove(self, __object: _T) -> None:
168+
idx = self.index(__object)
169+
name = None
170+
for key, val in self._dict.items():
171+
if val == idx:
172+
name = key
173+
elif val > idx:
174+
self._dict[key] = val - 1
175+
if name:
176+
self._dict.pop(name, None)
177+
list.remove(self, __object)
178+
179+
def sort(
180+
self,
181+
*,
182+
key: Optional[Callable[[_T], "SupportsRichComparison"]] = None,
183+
reverse: bool = False,
184+
) -> None:
185+
# Create a mapping from item to its name(s)
186+
item_to_names = {}
187+
for name, idx in self._dict.items():
188+
item = self[idx]
189+
item_to_names.setdefault(item, []).append(name)
190+
# Sort the list
191+
list.sort(self, key=key, reverse=reverse)
192+
# Update _dict with new indices
193+
new_dict = {}
194+
for idx, item in enumerate(self):
195+
if item in item_to_names:
196+
for name in item_to_names[item]:
197+
new_dict[name] = idx
198+
self._dict = new_dict
131199

132200
# --- List-like interface ---
133201
def __getitem__(self, key: Union[int, slice, str]) -> _T:
134-
if isinstance(key, (int, slice)):
135-
return list.__getitem__(self, key)
136202
if isinstance(key, str):
137-
return list.__getitem__(self, self._dict[key])
138-
raise TypeError("Key must be int / slice (for index) or str (for name).")
203+
return self[self._dict[key]]
204+
return list.__getitem__(self, key)
139205

140-
def __add__(self, other: Union[list[_T], Self]) -> list[_T]:
141-
# todo
142-
return list.__add__(self, other)
206+
def __add__(self, other: Union[list[_T], Self]) -> Self:
207+
new_listmap = self.copy()
208+
new_listmap += other
209+
return new_listmap
143210

144211
def __iadd__(self, other: Union[list[_T], Self]) -> Self:
145-
# todo
146-
return list.__iadd__(self, other)
212+
if isinstance(other, _ListMap):
213+
offset = len(self)
214+
for key, idx in other._dict.items():
215+
self._dict[key] = idx + offset
216+
217+
return super().__iadd__(other)
147218

148219
def __setitem__(self, key: Union[SupportsIndex, slice, str], value: _T) -> None:
149220
if isinstance(key, (int, slice)):
@@ -192,20 +263,6 @@ def items(self) -> ItemsView[str, _T]:
192263
return d.items()
193264

194265
# --- List and Dict interface ---
195-
def pop(self, key: Union[SupportsIndex, str] = -1, default: Optional[Any] = None) -> _T:
196-
if isinstance(key, int):
197-
ret = list.pop(self, key)
198-
for str_key, idx in list(self._dict.items()):
199-
if idx == key:
200-
self._dict.pop(str_key)
201-
elif idx > key:
202-
self._dict[str_key] = idx - 1
203-
return ret
204-
if isinstance(key, str):
205-
if key not in self._dict:
206-
return default
207-
return self.pop(self._dict[key])
208-
raise TypeError("Key must be int or str")
209266

210267
def __repr__(self) -> str:
211268
ret = super().__repr__()

tests/tests_pytorch/loggers/test_utilities.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -98,20 +98,24 @@ def test_listmap_setitem():
9898
def test_listmap_add():
9999
"""Test adding two collections together."""
100100
lm1 = _ListMap([1, 2])
101-
lm2 = _ListMap([3, 4])
101+
lm2 = _ListMap({"3": 3, "5": 5})
102102
combined = lm1 + lm2
103-
assert isinstance(combined, list)
103+
assert isinstance(combined, _ListMap)
104104
assert len(combined) == 4
105-
assert combined[0] == 1
106-
assert combined[1] == 2
107-
assert combined[2] == 3
108-
assert combined[3] == 4
105+
assert combined is not lm1
106+
assert combined == [1, 2, 3, 5]
107+
assert combined["3"] == 3
108+
assert combined["5"] == 5
109109

110-
combined += lm1
111-
assert isinstance(combined, list)
112-
assert len(combined) == 6
113-
for item, expected in zip(combined, [1, 2, 3, 4, 1, 2]):
114-
assert item == expected
110+
ori_lm1_id = id(lm1)
111+
112+
lm1 += lm2
113+
assert ori_lm1_id == id(lm1)
114+
assert isinstance(lm1, _ListMap)
115+
assert len(lm1) == 4
116+
assert lm1 == [1, 2, 3, 5]
117+
assert lm1["3"] == 3
118+
assert lm1["5"] == 5
115119

116120

117121
def test_listmap_remove():
@@ -190,6 +194,13 @@ def test_listmap_dict_pop():
190194
assert "b" not in lm
191195
assert len(lm) == 2
192196

197+
value = lm.pop(0)
198+
assert value == 1
199+
assert lm["c"] == 3 # still accessible by key
200+
assert len(lm) == 1
201+
with pytest.raises(KeyError):
202+
lm["a"] # "a" was removed
203+
193204

194205
def test_listmap_dict_setitem():
195206
lm = _ListMap({
@@ -201,3 +212,23 @@ def test_listmap_dict_setitem():
201212
lm["c"] = 3
202213
assert lm["c"] == 3
203214
assert len(lm) == 3
215+
216+
217+
def test_listmap_sort():
218+
lm = _ListMap({"b": 1, "c": 3, "a": 2, "z": -7})
219+
220+
lm.extend([-1, -2, 5])
221+
lm.sort(key=lambda x: abs(x))
222+
assert lm == [1, -1, 2, -2, 3, 5, -7]
223+
assert lm["a"] == 2
224+
assert lm["b"] == 1
225+
assert lm["c"] == 3
226+
assert lm["z"] == -7
227+
228+
lm = _ListMap({"b": 1, "c": 3, "a": 2, "z": -7})
229+
lm.sort(reverse=True)
230+
assert lm == [3, 2, 1, -7]
231+
assert lm["a"] == 2
232+
assert lm["b"] == 1
233+
assert lm["c"] == 3
234+
assert lm["z"] == -7

0 commit comments

Comments
 (0)