Skip to content

Commit a123ca7

Browse files
author
Songki Choi
authored
Fix graph metric order and label issues (#2356)
* Fix graph metric going backward issue * Add license notice * Fix pre-commit issue * Add rename items & logic for metric --------- Signed-off-by: Songki Choi <[email protected]>
1 parent ca5f74d commit a123ca7

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

src/otx/algorithms/common/adapters/mmcv/hooks/logger_hook.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
"""Logger hooks."""
2+
3+
# Copyright (C) 2023 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
26
from collections import defaultdict
37
from typing import Any, Dict, Optional
48

@@ -29,6 +33,19 @@ def __repr__(self):
2933
points.append(f"({x},{y})")
3034
return "curve[" + ",".join(points) + "]"
3135

36+
_TAGS_TO_SKIP = (
37+
"accuracy_top-1",
38+
"current_iters",
39+
"decode.acc_seg",
40+
"decode.loss_ce_ignore",
41+
)
42+
43+
_TAGS_TO_RENAME = {
44+
"train/time": "train/time (sec/iter)",
45+
"train/data_time": "train/data_time (sec/iter)",
46+
"val/accuracy": "val/accuracy (%)",
47+
}
48+
3249
def __init__(
3350
self,
3451
curves: Optional[Dict[Any, Curve]] = None,
@@ -43,12 +60,13 @@ def __init__(
4360
@master_only
4461
def log(self, runner: BaseRunner):
4562
"""Log function for OTXLoggerHook."""
46-
tags = self.get_loggable_tags(runner, allow_text=False, tags_to_skip=())
63+
tags = self.get_loggable_tags(runner, allow_text=False, tags_to_skip=self._TAGS_TO_SKIP)
4764
if runner.max_epochs is not None:
4865
normalized_iter = self.get_iter(runner) / runner.max_iters * runner.max_epochs
4966
else:
5067
normalized_iter = self.get_iter(runner)
5168
for tag, value in tags.items():
69+
tag = self._TAGS_TO_RENAME.get(tag, tag)
5270
curve = self.curves[tag]
5371
# Remove duplicates.
5472
if len(curve.x) > 0 and curve.x[-1] == normalized_iter:
@@ -57,6 +75,11 @@ def log(self, runner: BaseRunner):
5775
curve.x.append(normalized_iter)
5876
curve.y.append(value)
5977

78+
def before_run(self, runner: BaseRunner):
79+
"""Called before_run in OTXLoggerHook."""
80+
super().before_run(runner)
81+
self.curves.clear()
82+
6083
def after_train_epoch(self, runner: BaseRunner):
6184
"""Called after_train_epoch in OTXLoggerHook."""
6285
# Iteration counter is increased right after the last iteration in the epoch,

0 commit comments

Comments
 (0)