|
1 | 1 | """Base logger and its helper handlers.""" |
2 | 2 |
|
| 3 | +import collections.abc as collections |
3 | 4 | import numbers |
4 | 5 | import warnings |
5 | 6 | from abc import ABCMeta, abstractmethod |
@@ -145,30 +146,64 @@ def _setup_output_metrics_state_attrs( |
145 | 146 |
|
146 | 147 | metrics_state_attrs_dict: Dict[Any, Union[str, float, numbers.Number]] = OrderedDict() |
147 | 148 |
|
148 | | - def key_tuple_tf(tag: str, name: str, *args: str) -> Tuple[str, ...]: |
149 | | - return (tag, name) + args |
| 149 | + def key_tuple_fn(parent_key: Union[str, Tuple[str, ...]], *args: str) -> Tuple[str, ...]: |
| 150 | + if parent_key is None or isinstance(parent_key, str): |
| 151 | + return (parent_key,) + args |
| 152 | + return parent_key + args |
150 | 153 |
|
151 | | - def key_str_tf(tag: str, name: str, *args: str) -> str: |
152 | | - return "/".join((tag, name) + args) |
| 154 | + def key_str_fn(parent_key: str, *args: str) -> str: |
| 155 | + args_str = "/".join(args) |
| 156 | + return f"{parent_key}/{args_str}" |
153 | 157 |
|
154 | | - key_tf = key_tuple_tf if key_tuple else key_str_tf |
| 158 | + key_fn = key_tuple_fn if key_tuple else key_str_fn |
155 | 159 |
|
156 | | - for name, value in metrics_state_attrs.items(): |
| 160 | + def handle_value_fn( |
| 161 | + value: Union[str, int, float, numbers.Number, torch.Tensor] |
| 162 | + ) -> Union[None, str, float, numbers.Number]: |
157 | 163 | if isinstance(value, numbers.Number): |
158 | | - metrics_state_attrs_dict[key_tf(self.tag, name)] = value |
| 164 | + return value |
159 | 165 | elif isinstance(value, torch.Tensor) and value.ndimension() == 0: |
160 | | - metrics_state_attrs_dict[key_tf(self.tag, name)] = value.item() |
161 | | - elif isinstance(value, torch.Tensor) and value.ndimension() == 1: |
162 | | - for i, v in enumerate(value): |
163 | | - metrics_state_attrs_dict[key_tf(self.tag, name, str(i))] = v.item() |
| 166 | + return value.item() |
164 | 167 | else: |
165 | 168 | if isinstance(value, str) and log_text: |
166 | | - metrics_state_attrs_dict[key_tf(self.tag, name)] = value |
| 169 | + return value |
167 | 170 | else: |
168 | 171 | warnings.warn(f"Logger output_handler can not log metrics value type {type(value)}") |
| 172 | + return None |
| 173 | + |
| 174 | + metrics_state_attrs_dict = _flatten_dict(metrics_state_attrs, key_fn, handle_value_fn, parent_key=self.tag) |
169 | 175 | return metrics_state_attrs_dict |
170 | 176 |
|
171 | 177 |
|
| 178 | +def _flatten_dict( |
| 179 | + in_dict: collections.Mapping, |
| 180 | + key_fn: Callable, |
| 181 | + value_fn: Callable, |
| 182 | + parent_key: Optional[Union[str, Tuple[str, ...]]] = None, |
| 183 | +) -> Dict: |
| 184 | + items = {} |
| 185 | + for key, value in in_dict.items(): |
| 186 | + new_key = key_fn(parent_key, key) |
| 187 | + if isinstance(value, collections.Mapping): |
| 188 | + items.update(_flatten_dict(value, key_fn, value_fn, new_key)) |
| 189 | + elif any( |
| 190 | + [ |
| 191 | + isinstance(value, tuple) and hasattr(value, "_fields"), # namedtuple |
| 192 | + not isinstance(value, str) and isinstance(value, collections.Sequence), |
| 193 | + ] |
| 194 | + ): |
| 195 | + for i, item in enumerate(value): |
| 196 | + items.update(_flatten_dict({str(i): item}, key_fn, value_fn, new_key)) |
| 197 | + elif isinstance(value, torch.Tensor) and value.ndimension() == 1: |
| 198 | + for i, item in enumerate(value): |
| 199 | + items.update(_flatten_dict({str(i): item.item()}, key_fn, value_fn, new_key)) |
| 200 | + else: |
| 201 | + new_value = value_fn(value) |
| 202 | + if new_value is not None: |
| 203 | + items[new_key] = new_value |
| 204 | + return items |
| 205 | + |
| 206 | + |
172 | 207 | class BaseWeightsScalarHandler(BaseWeightsHandler): |
173 | 208 | """ |
174 | 209 | Helper handler to log model's weights or gradients as scalars. |
|
0 commit comments