Skip to content

Commit 7900105

Browse files
vorkBordapre-commit-ci[bot]
authored
Nicer logging of list of dicts for hyper parameters (#19963)
* Add support for handling lists of dictionaries in logging * Apply suggestions from code review * Update CHANGELOG.md --------- Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 40c682e commit 7900105

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77
## [unreleased] - YYYY-MM-DD
88

9+
### Added
10+
11+
- Added logging support for list of dicts without collapsing to a single key ([#19957](https://github.com/Lightning-AI/pytorch-lightning/issues/19957))
12+
13+
914
### Removed
1015

11-
Removed legacy supoport for `lightning run model`. Use `fabric run` instead. ([#20588](https://github.com/Lightning-AI/pytorch-lightning/pull/20588))
16+
- Removed legacy supoport for `lightning run model`. Use `fabric run` instead. ([#20588](https://github.com/Lightning-AI/pytorch-lightning/pull/20588))
17+
1218

1319
## [2.5.0] - 2024-12-19
1420

src/lightning/fabric/utilities/logger.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent
9191
{'a/b': 123}
9292
>>> _flatten_dict({5: {'a': 123}})
9393
{'5/a': 123}
94+
>>> _flatten_dict({"dl": [{"a": 1, "c": 3}, {"b": 2, "d": 5}], "l": [1, 2, 3, 4]})
95+
{'dl/0/a': 1, 'dl/0/c': 3, 'dl/1/b': 2, 'dl/1/d': 5, 'l': [1, 2, 3, 4]}
9496
9597
"""
9698
result: dict[str, Any] = {}
@@ -103,6 +105,10 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent
103105

104106
if isinstance(v, MutableMapping):
105107
result = {**result, **_flatten_dict(v, parent_key=new_key, delimiter=delimiter)}
108+
# Also handle the case where v is a list of dictionaries
109+
elif isinstance(v, list) and all(isinstance(item, MutableMapping) for item in v):
110+
for i, item in enumerate(v):
111+
result = {**result, **_flatten_dict(item, parent_key=f"{new_key}/{i}", delimiter=delimiter)}
106112
else:
107113
result[new_key] = v
108114
return result

tests/tests_fabric/utilities/test_logger.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ def test_flatten_dict():
6464
assert params["c/8"] == "foo"
6565
assert params["c/9/10"] == "bar"
6666

67+
# Test list of nested dicts flattening
68+
params = {"dl": [{"a": 1, "c": 3}, {"b": 2, "d": 5}], "l": [1, 2, 3, 4]}
69+
params = _flatten_dict(params)
70+
71+
assert params == {"dl/0/a": 1, "dl/0/c": 3, "dl/1/b": 2, "dl/1/d": 5, "l": [1, 2, 3, 4]}
72+
6773
# Test flattening of argparse Namespace
6874
params = Namespace(a=1, b=2)
6975
wrapping_dict = {"params": params}

0 commit comments

Comments
 (0)