Skip to content
1 change: 1 addition & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added sanitization for classes before logging them as hyperparameters ([#19771](https://github.com/Lightning-AI/pytorch-lightning/pull/19771))
- Added logging support for list of dicts without collapsing to a single key ([#19957](https://github.com/Lightning-AI/pytorch-lightning/issues/19957))
- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI ([#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560))
- Added the ability to explicitly mark forward methods in Fabric via `_FabricModule.mark_forward_method()` ([#19690](https://github.com/Lightning-AI/pytorch-lightning/pull/19690))
- Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708))
Expand Down
6 changes: 6 additions & 0 deletions src/lightning/fabric/utilities/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent
{'a/b': 123}
>>> _flatten_dict({5: {'a': 123}})
{'5/a': 123}
>>> _flatten_dict({"dl": [{"a": 1, "c": 3}, {"b": 2, "d": 5}], "l": [1, 2, 3, 4]})
{'dl/0/a': 1, 'dl/0/c': 3, 'dl/1/b': 2, 'dl/1/d': 5, 'l': [1, 2, 3, 4]}

"""
result: Dict[str, Any] = {}
Expand All @@ -101,6 +103,10 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent

if isinstance(v, MutableMapping):
result = {**result, **_flatten_dict(v, parent_key=new_key, delimiter=delimiter)}
# Also handle the case where v is a list of dictionaries
elif isinstance(v, list) and all(isinstance(item, MutableMapping) for item in v):
for i, item in enumerate(v):
result = {**result, **_flatten_dict(item, parent_key=f"{new_key}/{i}", delimiter=delimiter)}
else:
result[new_key] = v
return result
Expand Down
11 changes: 11 additions & 0 deletions tests/tests_fabric/utilities/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ def test_flatten_dict():
assert params["c/8"] == "foo"
assert params["c/9/10"] == "bar"

# Test list of nested dicts flattening
params = {"dl": [{"a": 1, "c": 3}, {"b": 2, "d": 5}], "l": [1, 2, 3, 4]}
params = _flatten_dict(params)

assert "dl" not in params
assert params["dl/0/a"] == 1
assert params["dl/0/c"] == 3
assert params["dl/1/b"] == 2
assert params["dl/1/d"] == 5
assert params["l"] == [1, 2, 3, 4]

# Test flattening of argparse Namespace
params = Namespace(a=1, b=2)
wrapping_dict = {"params": params}
Expand Down