Skip to content

Commit 0906811

Browse files
authored
Fix H-label classification (#2377)
* Fix h-label issue * Update unit tests * Make black happy * Fix unit tests * Make black happy * Fix update heads information func * Update the logic: consider the loss per batch
1 parent 43eb838 commit 0906811

File tree

8 files changed

+108
-38
lines changed

8 files changed

+108
-38
lines changed

src/otx/algorithms/classification/adapters/mmcls/configurer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ def configure_model(self, cfg, ir_options): # noqa: C901
132132
cfg.model.arch_type = cfg.model.type
133133
cfg.model.type = super_type
134134

135+
# Hierarchical
136+
if cfg.model.get("hierarchical"):
137+
assert cfg.data.train.hierarchical_info == cfg.data.val.hierarchical_info == cfg.data.test.hierarchical_info
138+
cfg.model.head.hierarchical_info = cfg.data.train.hierarchical_info
139+
135140
# OV-plugin
136141
ir_model_path = ir_options.get("ir_model_path")
137142
if ir_model_path:

src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,7 @@ def load_annotations(self):
309309
if item_labels:
310310
num_cls_heads = self.hierarchical_info["num_multiclass_heads"]
311311

312-
class_indices = [0] * (
313-
self.hierarchical_info["num_multiclass_heads"] + self.hierarchical_info["num_multilabel_classes"]
314-
)
312+
class_indices = [0] * (num_cls_heads + self.hierarchical_info["num_multilabel_classes"])
315313
for j in range(num_cls_heads):
316314
class_indices[j] = -1
317315
for otx_lbl in item_labels:
@@ -329,6 +327,19 @@ def load_annotations(self):
329327
self.gt_labels.append(class_indices)
330328
self.gt_labels = np.array(self.gt_labels)
331329

330+
self._update_heads_information()
331+
332+
def _update_heads_information(self):
333+
"""Update heads information to find the empty heads.
334+
335+
If there are no annotations at a specific head, this should be filtered out to calculate loss correctly.
336+
"""
337+
num_cls_heads = self.hierarchical_info["num_multiclass_heads"]
338+
for head_idx in range(num_cls_heads):
339+
labels_in_head = self.gt_labels[:, head_idx] # type: ignore[call-overload]
340+
if max(labels_in_head) < 0:
341+
self.hierarchical_info["empty_multiclass_head_indices"].append(head_idx)
342+
332343
@staticmethod
333344
def mean_top_k_accuracy(scores, labels, k=1):
334345
"""Return mean of top-k accuracy."""

src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_linear_cls_head.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,22 +87,26 @@ def forward_train(self, cls_score, gt_label, **kwargs):
8787
cls_score = self.fc(cls_score)
8888

8989
losses = dict(loss=0.0)
90+
num_effective_heads_in_batch = 0
9091
for i in range(self.hierarchical_info["num_multiclass_heads"]):
91-
head_gt = gt_label[:, i]
92-
head_logits = cls_score[
93-
:,
94-
self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[
95-
"head_idx_to_logits_range"
96-
][str(i)][1],
97-
]
98-
valid_mask = head_gt >= 0
99-
head_gt = head_gt[valid_mask].long()
100-
head_logits = head_logits[valid_mask, :]
101-
multiclass_loss = self.loss(head_logits, head_gt)
102-
losses["loss"] += multiclass_loss
92+
if i not in self.hierarchical_info["empty_multiclass_head_indices"]:
93+
head_gt = gt_label[:, i]
94+
head_logits = cls_score[
95+
:,
96+
self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[
97+
"head_idx_to_logits_range"
98+
][str(i)][1],
99+
]
100+
valid_mask = head_gt >= 0
101+
head_gt = head_gt[valid_mask].long()
102+
if len(head_gt) > 0:
103+
head_logits = head_logits[valid_mask, :]
104+
multiclass_loss = self.loss(head_logits, head_gt)
105+
losses["loss"] += multiclass_loss
106+
num_effective_heads_in_batch += 1
103107

104108
if self.hierarchical_info["num_multiclass_heads"] > 1:
105-
losses["loss"] /= self.hierarchical_info["num_multiclass_heads"]
109+
losses["loss"] /= num_effective_heads_in_batch
106110

107111
if self.compute_multilabel_loss:
108112
head_gt = gt_label[:, self.hierarchical_info["num_multiclass_heads"] :]

src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_non_linear_cls_head.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -117,22 +117,26 @@ def forward_train(self, cls_score, gt_label, **kwargs):
117117
cls_score = self.classifier(cls_score)
118118

119119
losses = dict(loss=0.0)
120+
num_effective_heads_in_batch = 0
120121
for i in range(self.hierarchical_info["num_multiclass_heads"]):
121-
head_gt = gt_label[:, i]
122-
head_logits = cls_score[
123-
:,
124-
self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[
125-
"head_idx_to_logits_range"
126-
][str(i)][1],
127-
]
128-
valid_mask = head_gt >= 0
129-
head_gt = head_gt[valid_mask].long()
130-
head_logits = head_logits[valid_mask, :]
131-
multiclass_loss = self.loss(head_logits, head_gt)
132-
losses["loss"] += multiclass_loss
122+
if i not in self.hierarchical_info["empty_multiclass_head_indices"]:
123+
head_gt = gt_label[:, i]
124+
head_logits = cls_score[
125+
:,
126+
self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[
127+
"head_idx_to_logits_range"
128+
][str(i)][1],
129+
]
130+
valid_mask = head_gt >= 0
131+
head_gt = head_gt[valid_mask].long()
132+
if len(head_gt) > 0:
133+
head_logits = head_logits[valid_mask, :]
134+
multiclass_loss = self.loss(head_logits, head_gt)
135+
losses["loss"] += multiclass_loss
136+
num_effective_heads_in_batch += 1
133137

134138
if self.hierarchical_info["num_multiclass_heads"] > 1:
135-
losses["loss"] /= self.hierarchical_info["num_multiclass_heads"]
139+
losses["loss"] /= num_effective_heads_in_batch
136140

137141
if self.compute_multilabel_loss:
138142
head_gt = gt_label[:, self.hierarchical_info["num_multiclass_heads"] :]

src/otx/algorithms/classification/utils/cls_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def get_multihead_class_info(label_schema: LabelSchemaEntity): # pylint: disabl
6262
"class_to_group_idx": class_to_idx,
6363
"all_groups": exclusive_groups + single_label_groups,
6464
"label_to_idx": label_to_idx,
65+
"empty_multiclass_head_indices": [],
6566
}
6667
return mixed_cls_heads_info
6768

tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,32 @@ def test_metric_hierarchical_adapter(self):
142142
dataset = OTXHierarchicalClsDataset(
143143
otx_dataset=self.dataset, labels=self.dataset.get_labels(), hierarchical_info=class_info
144144
)
145-
146145
results = np.zeros((len(dataset), dataset.num_classes))
147146
metrics = dataset.evaluate(results)
148147

149148
assert len(metrics) > 0
150149
assert metrics["accuracy"] > 0
150+
151+
@e2e_pytest_unit
152+
def test_hierarchical_with_empty_heads(self):
153+
self.task_environment, self.dataset = init_environment(
154+
self.hyper_parameters, self.model_template, False, True, self.dataset_len
155+
)
156+
class_info = get_multihead_class_info(self.task_environment.label_schema)
157+
dataset = OTXHierarchicalClsDataset(
158+
otx_dataset=self.dataset, labels=self.dataset.get_labels(), hierarchical_info=class_info
159+
)
160+
pseudo_gt_labels = []
161+
pseudo_head_idx = 0
162+
for label in dataset.gt_labels:
163+
pseudo_gt_label = label
164+
pseudo_gt_label[pseudo_head_idx] = -1
165+
pseudo_gt_labels.append(pseudo_gt_label)
166+
pseudo_gt_labels = np.array(pseudo_gt_labels)
167+
168+
from copy import deepcopy
169+
170+
pseudo_dataset = deepcopy(dataset)
171+
pseudo_dataset.gt_labels = pseudo_gt_labels
172+
pseudo_dataset._update_heads_information()
173+
assert pseudo_dataset.hierarchical_info["empty_multiclass_head_indices"][pseudo_head_idx] == 0

tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@ def head_type(self) -> None:
2424

2525
@pytest.fixture(autouse=True)
2626
def setup(self, head_type) -> None:
27-
self.num_classes = 3
28-
self.head_dim = 5
27+
self.num_classes = 6
28+
self.head_dim = 10
2929
self.cls_heads_info = {
30-
"num_multiclass_heads": 1,
31-
"num_multilabel_classes": 1,
32-
"head_idx_to_logits_range": {"0": (0, 2)},
33-
"num_single_label_classes": 2,
30+
"num_multiclass_heads": 3,
31+
"num_multilabel_classes": 0,
32+
"head_idx_to_logits_range": {"0": (0, 2), "1": (2, 4), "2": (4, 6)},
33+
"num_single_label_classes": 6,
34+
"empty_multiclass_head_indices": [],
3435
}
3536
self.loss = dict(type="CrossEntropyLoss", use_sigmoid=False, reduction="mean", loss_weight=1.0)
3637
self.multilabel_loss = dict(type=AsymmetricLossWithIgnore.__name__, reduction="sum")
@@ -43,13 +44,23 @@ def setup(self, head_type) -> None:
4344
)
4445
self.default_head.init_weights()
4546
self.default_input = torch.ones((2, self.head_dim))
46-
self.default_gt = torch.zeros((2, 2))
47+
self.default_gt = torch.zeros((2, 3))
4748

4849
@e2e_pytest_unit
4950
def test_forward(self) -> None:
5051
result = self.default_head.forward_train(self.default_input, self.default_gt)
5152
assert "loss" in result
52-
assert result["loss"] >= 0
53+
assert result["loss"] >= 0 and not torch.isnan(result["loss"])
54+
55+
empty_head_gt_full = torch.tensor([[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
56+
result_include_empty_full = self.default_head.forward_train(self.default_input, empty_head_gt_full)
57+
assert "loss" in result_include_empty_full
58+
assert result_include_empty_full["loss"] >= 0 and not torch.isnan(result_include_empty_full["loss"])
59+
60+
empty_head_gt_partial = torch.tensor([[0.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
61+
result_include_empty_partial = self.default_head.forward_train(self.default_input, empty_head_gt_partial)
62+
assert "loss" in result_include_empty_partial
63+
assert result_include_empty_partial["loss"] >= 0 and not torch.isnan(result_include_empty_partial["loss"])
5364

5465
@e2e_pytest_unit
5566
def test_simple_test(self) -> None:

tests/unit/algorithms/classification/adapters/mmcls/test_configurer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ def setup(self) -> None:
2323
self.model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model.py"))
2424
self.data_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "data_pipeline.py"))
2525

26+
self.multilabel_model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model_multilabel.py"))
27+
self.hierarchical_model_cfg = MPAConfig.fromfile(
28+
os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model_hierarchical.py")
29+
)
30+
2631
@e2e_pytest_unit
2732
def test_configure(self, mocker):
2833
mock_cfg_base = mocker.patch.object(ClassificationConfigurer, "configure_base")
@@ -119,6 +124,12 @@ def test_configure_model(self):
119124
assert self.model_cfg.model_task
120125
assert self.model_cfg.model.head.in_channels == 960
121126

127+
multilabel_model_cfg = self.multilabel_model_cfg
128+
self.configurer.configure_model(multilabel_model_cfg, ir_options)
129+
130+
h_label_model_cfg = self.hierarchical_model_cfg
131+
self.configurer.configure_model(h_label_model_cfg, ir_options)
132+
122133
@e2e_pytest_unit
123134
def test_configure_model_not_classification_task(self):
124135
ir_options = {"ir_model_path": {"ir_weight_path": "", "ir_weight_init": ""}}

0 commit comments

Comments
 (0)