Skip to content

Commit c292788

Browse files
SeanNarenpre-commit-ci[bot]
authored andcommitted
[bugfix] Minor improvements to apply_to_collection and type signature of log_dict (#7851)
* minor fixeS * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 009e05d)
1 parent 8a5a56b commit c292788

File tree

3 files changed

+26
-7
lines changed

3 files changed

+26
-7
lines changed

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from argparse import Namespace
2626
from functools import partial
2727
from pathlib import Path
28-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
28+
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
2929

3030
import torch
3131
from torch import ScriptModule, Tensor
@@ -347,7 +347,7 @@ def log(
347347

348348
def log_dict(
349349
self,
350-
dictionary: dict,
350+
dictionary: Mapping[str, Any],
351351
prog_bar: bool = False,
352352
logger: bool = True,
353353
on_step: Optional[bool] = None,

pytorch_lightning/utilities/apply_func.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import operator
1515
from abc import ABC
16+
from collections import OrderedDict
1617
from collections.abc import Mapping, Sequence
1718
from copy import copy
1819
from functools import partial
@@ -85,10 +86,12 @@ def apply_to_collection(
8586

8687
# Recursively apply to collection items
8788
if isinstance(data, Mapping):
88-
return elem_type({
89-
k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
90-
for k, v in data.items()
91-
})
89+
return elem_type(
90+
OrderedDict({
91+
k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
92+
for k, v in data.items()
93+
})
94+
)
9295

9396
if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple
9497
return elem_type(

tests/utilities/test_apply_func.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import numbers
15-
from collections import namedtuple
15+
from collections import namedtuple, OrderedDict
1616

1717
import numpy as np
1818
import torch
@@ -76,3 +76,19 @@ def test_recursive_application_to_collection():
7676

7777
assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a tensor'
7878
assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result'
79+
80+
# mapping support
81+
reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x))
82+
assert reduced == {'a': '1', 'b': '2'}
83+
reduced = apply_to_collection(OrderedDict([('b', 2), ('a', 1)]), int, lambda x: str(x))
84+
assert reduced == OrderedDict([('b', '2'), ('a', '1')])
85+
86+
# custom mappings
87+
class _CustomCollection(dict):
88+
89+
def __init__(self, initial_dict):
90+
super().__init__(initial_dict)
91+
92+
to_reduce = _CustomCollection({'a': 1, 'b': 2, 'c': 3})
93+
reduced = apply_to_collection(to_reduce, int, lambda x: str(x))
94+
assert reduced == _CustomCollection({'a': '1', 'b': '2', 'c': '3'})

0 commit comments

Comments
 (0)