Skip to content

Commit 645cf7b

Browse files
sungmanceunwoosh
andauthored
Resolve label deletion issue (#2315)
* fix label detection issue --------- Co-authored-by: Eunwoo Shin <[email protected]>
1 parent 9fd5871 commit 645cf7b

File tree

16 files changed

+546
-23
lines changed

16 files changed

+546
-23
lines changed

otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,10 @@ def evaluate(
416416
)
417417

418418
eval_results["MHAcc"] = total_acc
419-
eval_results["avgClsAcc"] = total_acc_sl / self.hierarchical_info["num_multiclass_heads"]
419+
if self.hierarchical_info["num_multiclass_heads"] > 0:
420+
eval_results["avgClsAcc"] = total_acc_sl / self.hierarchical_info["num_multiclass_heads"]
421+
else:
422+
eval_results["avgClsAcc"] = total_acc_sl
420423
eval_results["mAP"] = mAP_value
421424
eval_results["accuracy"] = total_acc
422425

otx/algorithms/classification/adapters/mmcls/models/classifiers/sam_classifier.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@
1616
logger = get_logger()
1717

1818

19+
def is_hierarchical_chkpt(chkpt: dict):
20+
"""Detect whether previous checkpoint is hierarchical or not."""
21+
for k, v in chkpt.items():
22+
if "fc" in k:
23+
return True
24+
return False
25+
26+
1927
@CLASSIFIERS.register_module()
2028
class SAMImageClassifier(SAMClassifierMixin, ClsLossDynamicsTrackingMixin, ImageClassifier):
2129
"""SAM-enabled ImageClassifier."""
@@ -193,11 +201,19 @@ def load_state_dict_pre_hook(module, state_dict, prefix, *args, **kwargs): # no
193201
def load_state_dict_mixing_hook(
194202
model, model_classes, chkpt_classes, chkpt_dict, prefix, *args, **kwargs
195203
): # pylint: disable=unused-argument, too-many-branches, too-many-locals
196-
"""Modify input state_dict according to class name matching before weight loading."""
204+
"""Modify input state_dict according to class name matching before weight loading.
205+
206+
If previous training is hierarchical training,
207+
then the current training should be hierarchical training. vice versa.
208+
209+
"""
197210
backbone_type = type(model.backbone).__name__
198211
if backbone_type not in ["OTXMobileNetV3", "OTXEfficientNet", "OTXEfficientNetV2"]:
199212
return
200213

214+
if model.hierarchical != is_hierarchical_chkpt(chkpt_dict):
215+
return
216+
201217
# Dst to src mapping index
202218
model_classes = list(model_classes)
203219
chkpt_classes = list(chkpt_classes)
@@ -249,13 +265,15 @@ def load_state_dict_mixing_hook(
249265
continue
250266

251267
# Mix weights
252-
chkpt_param = chkpt_dict[chkpt_name]
253-
for module, c in enumerate(model2chkpt):
254-
if c >= 0:
255-
model_param[module].copy_(chkpt_param[c])
268+
# NOTE: Label mix is not supported for H-label classification.
269+
if not model.hierarchical:
270+
chkpt_param = chkpt_dict[chkpt_name]
271+
for module, c in enumerate(model2chkpt):
272+
if c >= 0:
273+
model_param[module].copy_(chkpt_param[c])
256274

257-
# Replace checkpoint weight by mixed weights
258-
chkpt_dict[chkpt_name] = model_param
275+
# Replace checkpoint weight by mixed weights
276+
chkpt_dict[chkpt_name] = model_param
259277

260278
def extract_feat(self, img):
261279
"""Directly extract features from the backbone + neck.

otx/algorithms/classification/task.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
from otx.api.entities.inference_parameters import (
4848
default_progress_callback as default_infer_progress_callback,
4949
)
50+
from otx.api.entities.label import LabelEntity
51+
from otx.api.entities.label_schema import LabelGroup
5052
from otx.api.entities.metadata import FloatMetadata, FloatType
5153
from otx.api.entities.metrics import (
5254
CurveMetric,
@@ -127,16 +129,22 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str]
127129
if self._task_environment.model is not None:
128130
self._load_model()
129131

132+
def _is_multi_label(self, label_groups: List[LabelGroup], all_labels: List[LabelEntity]):
133+
"""Check whether the current training mode is multi-label or not."""
134+
# NOTE: In the current Geti, multi-label should have `___` symbol for all group names.
135+
find_multilabel_symbol = ["___" in getattr(i, "name", "") for i in label_groups]
136+
return (
137+
(len(label_groups) > 1) and (len(label_groups) == len(all_labels)) and (False not in find_multilabel_symbol)
138+
)
139+
130140
def _set_train_mode(self):
131-
self._multilabel = len(self._task_environment.label_schema.get_groups(False)) > 1 and len(
132-
self._task_environment.label_schema.get_groups(False)
133-
) == len(
134-
self._task_environment.get_labels(include_empty=False)
135-
) # noqa:E127
141+
label_groups = self._task_environment.label_schema.get_groups(include_empty=False)
142+
all_labels = self._task_environment.label_schema.get_labels(include_empty=False)
143+
144+
self._multilabel = self._is_multi_label(label_groups, all_labels)
136145
if self._multilabel:
137146
logger.info("Classification mode: multilabel")
138-
139-
if not self._multilabel and len(self._task_environment.label_schema.get_groups(False)) > 1:
147+
elif len(label_groups) > 1:
140148
logger.info("Classification mode: hierarchical")
141149
self._hierarchical = True
142150
self._hierarchical_info = get_hierarchical_info(self._task_environment.label_schema)
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
{
2+
"info": {},
3+
"categories": {
4+
"label": {
5+
"labels": [
6+
{
7+
"name": "right",
8+
"parent": "triangle",
9+
"attributes": []
10+
},
11+
{
12+
"name": "multi a",
13+
"parent": "triangle",
14+
"attributes": []
15+
},
16+
{
17+
"name": "equilateral",
18+
"parent": "triangle",
19+
"attributes": []
20+
},
21+
{
22+
"name": "square",
23+
"parent": "rectangle",
24+
"attributes": []
25+
},
26+
{
27+
"name": "triangle",
28+
"parent": "",
29+
"attributes": []
30+
},
31+
{
32+
"name": "non_square",
33+
"parent": "rectangle",
34+
"attributes": []
35+
},
36+
{
37+
"name": "rectangle",
38+
"parent": "",
39+
"attributes": []
40+
}
41+
],
42+
"label_groups": [
43+
{
44+
"name": "shape",
45+
"group_type": "exclusive",
46+
"labels": ["rectangle", "triangle"]
47+
},
48+
{
49+
"name": "rectangle default",
50+
"group_type": "exclusive",
51+
"labels": ["non_square", "square"]
52+
},
53+
{
54+
"name": "triangle default",
55+
"group_type": "exclusive",
56+
"labels": ["equilateral", "right"]
57+
},
58+
{
59+
"name": "shape___multiple example___multi a",
60+
"group_type": "exclusive",
61+
"labels": ["multi a"]
62+
}
63+
],
64+
"attributes": []
65+
},
66+
"mask": {
67+
"colormap": [
68+
{
69+
"label_id": 0,
70+
"r": 129,
71+
"g": 64,
72+
"b": 123
73+
},
74+
{
75+
"label_id": 1,
76+
"r": 91,
77+
"g": 105,
78+
"b": 255
79+
},
80+
{
81+
"label_id": 2,
82+
"r": 91,
83+
"g": 105,
84+
"b": 255
85+
},
86+
{
87+
"label_id": 3,
88+
"r": 255,
89+
"g": 86,
90+
"b": 98
91+
},
92+
{
93+
"label_id": 4,
94+
"r": 204,
95+
"g": 148,
96+
"b": 218
97+
},
98+
{
99+
"label_id": 5,
100+
"r": 0,
101+
"g": 251,
102+
"b": 87
103+
},
104+
{
105+
"label_id": 6,
106+
"r": 84,
107+
"g": 143,
108+
"b": 173
109+
}
110+
]
111+
}
112+
},
113+
"items": [
114+
{
115+
"id": "a",
116+
"annotations": [
117+
{
118+
"id": 0,
119+
"type": "label",
120+
"attributes": {},
121+
"group": 0,
122+
"label_id": 4
123+
},
124+
{
125+
"id": 0,
126+
"type": "label",
127+
"attributes": {},
128+
"group": 0,
129+
"label_id": 5
130+
},
131+
{
132+
"id": 0,
133+
"type": "label",
134+
"attributes": {},
135+
"group": 0,
136+
"label_id": 1
137+
}
138+
],
139+
"image": {
140+
"path": "a.jpg",
141+
"size": [10, 5]
142+
},
143+
"media": {
144+
"path": ""
145+
}
146+
},
147+
{
148+
"id": "b",
149+
"annotations": [
150+
{
151+
"id": 0,
152+
"type": "label",
153+
"attributes": {},
154+
"group": 0,
155+
"label_id": 6
156+
},
157+
{
158+
"id": 0,
159+
"type": "label",
160+
"attributes": {},
161+
"group": 0,
162+
"label_id": 5
163+
},
164+
{
165+
"id": 0,
166+
"type": "label",
167+
"attributes": {},
168+
"group": 0,
169+
"label_id": 2
170+
}
171+
],
172+
"image": {
173+
"path": "b.jpg",
174+
"size": [10, 5]
175+
},
176+
"media": {
177+
"path": ""
178+
}
179+
}
180+
]
181+
}

0 commit comments

Comments
 (0)