Skip to content

Commit e205e61

Browse files
authored
Added support for logging dicts/iterables (#3369)
Fixes #3294
1 parent 795dccb commit e205e61

17 files changed

+439
-20
lines changed

ignite/handlers/base_logger.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Base logger and its helper handlers."""
22

3+
import collections.abc as collections
34
import numbers
45
import warnings
56
from abc import ABCMeta, abstractmethod
@@ -145,30 +146,64 @@ def _setup_output_metrics_state_attrs(
145146

146147
metrics_state_attrs_dict: Dict[Any, Union[str, float, numbers.Number]] = OrderedDict()
147148

148-
def key_tuple_tf(tag: str, name: str, *args: str) -> Tuple[str, ...]:
149-
return (tag, name) + args
149+
def key_tuple_fn(parent_key: Union[str, Tuple[str, ...]], *args: str) -> Tuple[str, ...]:
150+
if parent_key is None or isinstance(parent_key, str):
151+
return (parent_key,) + args
152+
return parent_key + args
150153

151-
def key_str_tf(tag: str, name: str, *args: str) -> str:
152-
return "/".join((tag, name) + args)
154+
def key_str_fn(parent_key: str, *args: str) -> str:
155+
args_str = "/".join(args)
156+
return f"{parent_key}/{args_str}"
153157

154-
key_tf = key_tuple_tf if key_tuple else key_str_tf
158+
key_fn = key_tuple_fn if key_tuple else key_str_fn
155159

156-
for name, value in metrics_state_attrs.items():
160+
def handle_value_fn(
161+
value: Union[str, int, float, numbers.Number, torch.Tensor]
162+
) -> Union[None, str, float, numbers.Number]:
157163
if isinstance(value, numbers.Number):
158-
metrics_state_attrs_dict[key_tf(self.tag, name)] = value
164+
return value
159165
elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
160-
metrics_state_attrs_dict[key_tf(self.tag, name)] = value.item()
161-
elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
162-
for i, v in enumerate(value):
163-
metrics_state_attrs_dict[key_tf(self.tag, name, str(i))] = v.item()
166+
return value.item()
164167
else:
165168
if isinstance(value, str) and log_text:
166-
metrics_state_attrs_dict[key_tf(self.tag, name)] = value
169+
return value
167170
else:
168171
warnings.warn(f"Logger output_handler can not log metrics value type {type(value)}")
172+
return None
173+
174+
metrics_state_attrs_dict = _flatten_dict(metrics_state_attrs, key_fn, handle_value_fn, parent_key=self.tag)
169175
return metrics_state_attrs_dict
170176

171177

178+
def _flatten_dict(
179+
in_dict: collections.Mapping,
180+
key_fn: Callable,
181+
value_fn: Callable,
182+
parent_key: Optional[Union[str, Tuple[str, ...]]] = None,
183+
) -> Dict:
184+
items = {}
185+
for key, value in in_dict.items():
186+
new_key = key_fn(parent_key, key)
187+
if isinstance(value, collections.Mapping):
188+
items.update(_flatten_dict(value, key_fn, value_fn, new_key))
189+
elif any(
190+
[
191+
isinstance(value, tuple) and hasattr(value, "_fields"), # namedtuple
192+
not isinstance(value, str) and isinstance(value, collections.Sequence),
193+
]
194+
):
195+
for i, item in enumerate(value):
196+
items.update(_flatten_dict({str(i): item}, key_fn, value_fn, new_key))
197+
elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
198+
for i, item in enumerate(value):
199+
items.update(_flatten_dict({str(i): item.item()}, key_fn, value_fn, new_key))
200+
else:
201+
new_value = value_fn(value)
202+
if new_value is not None:
203+
items[new_key] = new_value
204+
return items
205+
206+
172207
class BaseWeightsScalarHandler(BaseWeightsHandler):
173208
"""
174209
Helper handler to log model's weights or gradients as scalars.

ignite/handlers/clearml_logger.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ class ClearMLLogger(BaseLogger):
109109
log_handler=WeightsScalarHandler(model)
110110
)
111111
112+
Note:
113+
:class:`~ignite.handlers.clearml_logger.OutputHandler` can handle
114+
metrics, state attributes and engine output values of the following format:
115+
- scalar values (i.e. int, float)
116+
- 0d and 1d pytorch tensors
117+
- dicts and list/tuples of previous types
118+
112119
"""
113120

114121
def __init__(self, **kwargs: Any):
@@ -342,9 +349,10 @@ def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str,
342349
for key, value in metrics.items():
343350
if len(key) == 2:
344351
logger.clearml_logger.report_scalar(title=key[0], series=key[1], iteration=global_step, value=value)
345-
elif len(key) == 3:
352+
elif len(key) >= 3:
353+
series = "/".join(key[2:])
346354
logger.clearml_logger.report_scalar(
347-
title=f"{key[0]}/{key[1]}", series=key[2], iteration=global_step, value=value
355+
title=f"{key[0]}/{key[1]}", series=series, iteration=global_step, value=value
348356
)
349357

350358

ignite/handlers/mlflow_logger.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,49 @@ class MLflowLogger(BaseLogger):
8484
optimizer=optimizer,
8585
param_name='lr' # optional
8686
)
87+
88+
Note:
89+
:class:`~ignite.handlers.mlflow_logger.OutputHandler` can handle
90+
metrics, state attributes and engine output values of the following format:
91+
- scalar values (i.e. int, float)
92+
- 0d and 1d pytorch tensors
93+
- dicts and list/tuples of previous types
94+
95+
.. code-block:: python
96+
97+
# !!! This is not a runnable code !!!
98+
evalutator.state.metrics = {
99+
"a": 0,
100+
"dict_value": {
101+
"a": 111,
102+
"c": {"d": 23, "e": [123, 234]},
103+
},
104+
"list_value": [12, 13, {"aa": 33, "bb": 44}],
105+
"tuple_value": (112, 113, {"aaa": 33, "bbb": 44}),
106+
}
107+
108+
handler = OutputHandler(
109+
tag="tag",
110+
metric_names="all",
111+
)
112+
113+
handler(evaluator, mlflow_logger, event_name=Events.EPOCH_COMPLETED)
114+
# Behind it would call `mlflow_logger.log_metrics` on
115+
# {
116+
# "tag/a": 0,
117+
# "tag/dict_value/a": 111,
118+
# "tag/dict_value/c/d": 23,
119+
# "tag/dict_value/c/e/0": 123,
120+
# "tag/dict_value/c/e/1": 234,
121+
# "tag/list_value/0": 12,
122+
# "tag/list_value/1": 13,
123+
# "tag/list_value/2/aa": 33,
124+
# "tag/list_value/2/bb": 44,
125+
# "tag/tuple_value/0": 112,
126+
# "tag/tuple_value/1": 113,
127+
# "tag/tuple_value/2/aaa": 33,
128+
# "tag/tuple_value/2/bbb": 44,
129+
# }
87130
"""
88131

89132
def __init__(self, tracking_uri: Optional[str] = None):

ignite/handlers/neptune_logger.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,13 @@ def score_function(engine):
153153
output_transform=lambda loss: {"loss": loss},
154154
)
155155
156+
Note:
157+
:class:`~ignite.handlers.neptune_logger.OutputHandler` can handle
158+
metrics, state attributes and engine output values of the following format:
159+
- scalar values (i.e. int, float)
160+
- 0d and 1d pytorch tensors
161+
- dicts and list/tuples of previous types
162+
156163
"""
157164

158165
def __getattr__(self, attr: Any) -> Any:

ignite/handlers/polyaxon_logger.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ class PolyaxonLogger(BaseLogger):
9292
)
9393
# to manually end a run
9494
plx_logger.close()
95+
96+
Note:
97+
:class:`~ignite.handlers.polyaxon_logger.OutputHandler` can handle
98+
metrics, state attributes and engine output values of the following format:
99+
- scalar values (i.e. int, float)
100+
- 0d and 1d pytorch tensors
101+
- dicts and list/tuples of previous types
95102
"""
96103

97104
def __init__(self, *args: Any, **kwargs: Any):

ignite/handlers/tensorboard_logger.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,49 @@ class TensorboardLogger(BaseLogger):
145145
output_transform=lambda loss: {"loss": loss}
146146
)
147147
148+
Note:
149+
:class:`~ignite.handlers.tensorboard_logger.OutputHandler` can handle
150+
metrics, state attributes and engine output values of the following format:
151+
- scalar values (i.e. int, float)
152+
- 0d and 1d pytorch tensors
153+
- dicts and list/tuples of previous types
154+
155+
.. code-block:: python
156+
157+
# !!! This is not a runnable code !!!
158+
evalutator.state.metrics = {
159+
"a": 0,
160+
"dict_value": {
161+
"a": 111,
162+
"c": {"d": 23, "e": [123, 234]},
163+
},
164+
"list_value": [12, 13, {"aa": 33, "bb": 44}],
165+
"tuple_value": (112, 113, {"aaa": 33, "bbb": 44}),
166+
}
167+
168+
handler = OutputHandler(
169+
tag="tag",
170+
metric_names="all",
171+
)
172+
173+
handler(evaluator, tb_logger, event_name=Events.EPOCH_COMPLETED)
174+
# Behind it would call `tb_logger.writer.add_scalar` on
175+
# {
176+
# "tag/a": 0,
177+
# "tag/dict_value/a": 111,
178+
# "tag/dict_value/c/d": 23,
179+
# "tag/dict_value/c/e/0": 123,
180+
# "tag/dict_value/c/e/1": 234,
181+
# "tag/list_value/0": 12,
182+
# "tag/list_value/1": 13,
183+
# "tag/list_value/2/aa": 33,
184+
# "tag/list_value/2/bb": 44,
185+
# "tag/tuple_value/0": 112,
186+
# "tag/tuple_value/1": 113,
187+
# "tag/tuple_value/2/aaa": 33,
188+
# "tag/tuple_value/2/bbb": 44,
189+
# }
190+
148191
"""
149192

150193
def __init__(self, *args: Any, **kwargs: Any):

ignite/handlers/tqdm_logger.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,7 @@ def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[str, E
298298
rendered_metrics = self._setup_output_metrics_state_attrs(engine, log_text=True)
299299
metrics = OrderedDict()
300300
for key, value in rendered_metrics.items():
301-
key = "_".join(key[1:]) # tqdm has tag as description
302-
301+
key = "_".join(key[1:]) # skip tag as tqdm has tag as description
303302
metrics[key] = value
304303

305304
if metrics:

ignite/handlers/visdom_logger.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,13 @@ class VisdomLogger(BaseLogger):
137137
output_transform=lambda loss: {"loss": loss}
138138
)
139139
140+
Note:
141+
:class:`~ignite.handlers.visdom_logger.OutputHandler` can handle
142+
metrics, state attributes and engine output values of the following format:
143+
- scalar values (i.e. int, float)
144+
- 0d and 1d pytorch tensors
145+
- dicts and list/tuples of previous types
146+
140147
.. versionchanged:: 0.4.7
141148
accepts an optional list of `state_attributes`
142149
"""

ignite/handlers/wandb_logger.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ def score_function(engine):
120120
)
121121
evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {'model': model})
122122
123+
Note:
124+
:class:`~ignite.handlers.wandb_logger.OutputHandler` can handle
125+
metrics, state attributes and engine output values of the following format:
126+
- scalar values (i.e. int, float)
127+
- 0d and 1d pytorch tensors
128+
- dicts and list/tuples of previous types
123129
124130
"""
125131

tests/ignite/handlers/test_base_logger.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,44 @@ def test_base_output_handler_setup_output_metrics():
103103
metrics = handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False)
104104
assert metrics == {"tag/a": 0, "tag/b": 1}
105105

106+
# metrics with mappings, iterables
107+
true_metrics = {
108+
"a": 0,
109+
"b": "1",
110+
"dict_value": {
111+
"a": 111,
112+
"b": "222",
113+
"c": {"d": 23, "e": [123, 234]},
114+
"f": [{"g": 11, "h": (321, 432)}, 778],
115+
},
116+
"list_value": [12, "13", {"aa": 33, "bb": 44}],
117+
"tuple_value": (112, "113", {"aaa": 33, "bbb": 44}),
118+
}
119+
engine.state = State(metrics=true_metrics)
120+
handler = DummyOutputHandler("tag", metric_names="all", output_transform=None)
121+
metrics = handler._setup_output_metrics_state_attrs(engine=engine, key_tuple=False, log_text=True)
122+
assert metrics == {
123+
"tag/a": 0,
124+
"tag/b": "1",
125+
"tag/dict_value/a": 111,
126+
"tag/dict_value/b": "222",
127+
"tag/dict_value/c/d": 23,
128+
"tag/dict_value/c/e/0": 123,
129+
"tag/dict_value/c/e/1": 234,
130+
"tag/dict_value/f/0/g": 11,
131+
"tag/dict_value/f/0/h/0": 321,
132+
"tag/dict_value/f/0/h/1": 432,
133+
"tag/dict_value/f/1": 778,
134+
"tag/list_value/0": 12,
135+
"tag/list_value/1": "13",
136+
"tag/list_value/2/aa": 33,
137+
"tag/list_value/2/bb": 44,
138+
"tag/tuple_value/0": 112,
139+
"tag/tuple_value/1": "113",
140+
"tag/tuple_value/2/aaa": 33,
141+
"tag/tuple_value/2/bbb": 44,
142+
}
143+
106144

107145
def test_base_output_handler_setup_output_state_attrs():
108146
engine = Engine(lambda engine, batch: None)

0 commit comments

Comments
 (0)