Skip to content

Commit 8063b62

Browse files
authored
[OTE][release]Fix ignore labels handling in multilabel cls (#1627)
* Fix ignore handling in mpa dataset * Always add ignore to label schema * Update MPA
1 parent c79db6a commit 8063b62

File tree

3 files changed

+7
-16
lines changed

3 files changed

+7
-16
lines changed

external/model-preparation-algorithm/mpa_tasks/extensions/datasets/mpa_cls_dataset.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -169,21 +169,17 @@ def get_indices(self, new_classes):
169169
def load_annotations(self):
170170
include_empty = self.empty_label in self.labels
171171
for i, _ in enumerate(self.ote_dataset):
172-
class_indices = []
173172
item_labels = self.ote_dataset[i].get_roi_labels(self.labels, include_empty=include_empty)
174173
ignored_labels = self.ote_dataset[i].ignored_labels
174+
onehot_indices = np.zeros(len(self.labels))
175175
if item_labels:
176176
for ote_lbl in item_labels:
177177
if ote_lbl not in ignored_labels:
178-
class_indices.append(self.label_names.index(ote_lbl.name))
178+
onehot_indices[self.label_names.index(ote_lbl.name)] = 1
179179
else:
180-
class_indices.append(-1)
181-
else: # this supposed to happen only on inference stage or if we have a negative in multilabel data
182-
class_indices.append(-1)
183-
onehot_indices = np.zeros(len(self.labels))
184-
for idx in class_indices:
185-
if idx != -1: # TODO: handling ignored label?
186-
onehot_indices[idx] = 1
180+
# during training we filter ignored classes out,
181+
# during validation mmcv's mAP also filters -1 labels
182+
onehot_indices[self.label_names.index(ote_lbl.name)] = -1
187183
self.gt_labels.append(onehot_indices)
188184
self.gt_labels = np.array(self.gt_labels)
189185

ote_cli/ote_cli/utils/io.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -220,13 +220,8 @@ def add_subtask_labels(dataset, info):
220220
for stask in subtask: # if has several subtasks
221221
add_subtask_labels(dataset, stask)
222222

223+
label_schema.add_group(empty_group) # empty group is always added in geti
223224
for info in hierarchy_info:
224-
if info[
225-
"task_type"
226-
] == "multi-label" and emptylabel not in label_schema.get_labels(
227-
include_empty=True
228-
):
229-
label_schema.add_group(empty_group)
230225
add_subtask_labels(dataset, info)
231226
else:
232227
main_group = LabelGroup(

0 commit comments

Comments
 (0)