diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 992dbbb4fbff9..a8d5d825e8707 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -6,9 +6,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [unreleased] - YYYY-MM-DD +### Added + +- Added logging support for list of dicts without collapsing to a single key ([#19957](https://github.com/Lightning-AI/pytorch-lightning/issues/19957)) + + ### Removed -Removed legacy supoport for `lightning run model`. Use `fabric run` instead. ([#20588](https://github.com/Lightning-AI/pytorch-lightning/pull/20588)) +- Removed legacy supoport for `lightning run model`. Use `fabric run` instead. ([#20588](https://github.com/Lightning-AI/pytorch-lightning/pull/20588)) + ## [2.5.0] - 2024-12-19 diff --git a/src/lightning/fabric/utilities/logger.py b/src/lightning/fabric/utilities/logger.py index dd2b0a3663fc9..04b9069dd0788 100644 --- a/src/lightning/fabric/utilities/logger.py +++ b/src/lightning/fabric/utilities/logger.py @@ -91,6 +91,8 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent {'a/b': 123} >>> _flatten_dict({5: {'a': 123}}) {'5/a': 123} + >>> _flatten_dict({"dl": [{"a": 1, "c": 3}, {"b": 2, "d": 5}], "l": [1, 2, 3, 4]}) + {'dl/0/a': 1, 'dl/0/c': 3, 'dl/1/b': 2, 'dl/1/d': 5, 'l': [1, 2, 3, 4]} """ result: dict[str, Any] = {} @@ -103,6 +105,10 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent if isinstance(v, MutableMapping): result = {**result, **_flatten_dict(v, parent_key=new_key, delimiter=delimiter)} + # Also handle the case where v is a list of dictionaries + elif isinstance(v, list) and all(isinstance(item, MutableMapping) for item in v): + for i, item in enumerate(v): + result = {**result, **_flatten_dict(item, parent_key=f"{new_key}/{i}", delimiter=delimiter)} else: result[new_key] = v return result diff --git a/tests/tests_fabric/utilities/test_logger.py b/tests/tests_fabric/utilities/test_logger.py index 26823143102a7..4cb55c4cb68d8 100644 --- a/tests/tests_fabric/utilities/test_logger.py +++ b/tests/tests_fabric/utilities/test_logger.py @@ -64,6 +64,12 @@ def test_flatten_dict(): assert params["c/8"] == "foo" assert params["c/9/10"] == "bar" + # Test list of nested dicts flattening + params = {"dl": [{"a": 1, "c": 3}, {"b": 2, "d": 5}], "l": [1, 2, 3, 4]} + params = _flatten_dict(params) + + assert params == {"dl/0/a": 1, "dl/0/c": 3, "dl/1/b": 2, "dl/1/d": 5, "l": [1, 2, 3, 4]} + # Test flattening of argparse Namespace params = Namespace(a=1, b=2) wrapping_dict = {"params": params}