Skip to content

Commit 0d8f725

Browse files
committed
refactor
1 parent 085f167 commit 0d8f725

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/lightning/pytorch/loggers/utilities.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,16 @@ def __init__(self, loggers: Union[Iterable[_T], Mapping[str, _T]] = None):
121121
super().__init__(loggers.values())
122122
self._dict = dict(zip(loggers.keys(), range(len(loggers))))
123123
else:
124+
default_dict = {}
125+
if isinstance(loggers, _ListMap):
126+
default_dict = loggers._dict.copy()
124127
super().__init__(() if loggers is None else loggers)
125-
self._dict: dict = {}
128+
self._dict: dict = default_dict
126129

127130
def __eq__(self, other: Any) -> bool:
128131
list_eq = list.__eq__(self, other)
129132
if isinstance(other, _ListMap):
130-
dict_eq = self._dict == other._dict
131-
return list_eq and dict_eq
133+
return list_eq and self._dict == other._dict
132134
return list_eq
133135

134136
def copy(self):

tests/tests_pytorch/loggers/test_utilities.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,16 @@ def test_version(tmp_path):
4242
[1, 2],
4343
{1, 2},
4444
range(2),
45+
_ListMap({"a": 1, "b": 2}),
4546
],
4647
)
4748
def test_listmap_init(args):
4849
"""Test initialization with different iterable types."""
4950
lm = _ListMap(args)
5051
assert len(lm) == len(args)
5152
assert isinstance(lm, list)
53+
if isinstance(args, _ListMap):
54+
assert lm == args
5255

5356

5457
def test_listmap_append():

0 commit comments

Comments
 (0)