Skip to content

Commit c6e2952

Browse files
authored
Update Label Info handling (#4127)
* Update h-cls info * Revert h-cls head to linear one * Cosmetic changes * Add arrow-specific labels management logic for cls * Update export logic * Update label info usage * Update unit tests * Fix linter * Fix unit tests * Fix linter * Consider multilabel scenario in h-cls * Update dataset docstring * Add unit tests * Don't preprocess h-cls dataset for arrow * Fix missing labels in multilabel training * Revert hcls head for effnet b0 * Update converter to pick up cls task
1 parent cf035f6 commit c6e2952

29 files changed

+278
-82
lines changed

src/otx/algo/classification/efficientnet.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from otx.algo.classification.backbones.efficientnet import EFFICIENTNET_VERSION, OTXEfficientNet
1515
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
1616
from otx.algo.classification.heads import (
17-
HierarchicalCBAMClsHead,
17+
HierarchicalLinearClsHead,
1818
LinearClsHead,
1919
MultiLabelLinearClsHead,
2020
SemiSLLinearClsHead,
@@ -272,11 +272,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
272272

273273
return HLabelClassifier(
274274
backbone=backbone,
275-
neck=nn.Identity(),
276-
head=HierarchicalCBAMClsHead(
277-
in_channels=backbone.num_features,
278-
**copied_head_config,
279-
),
275+
neck=GlobalAveragePooling(dim=2),
276+
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
280277
multiclass_loss=nn.CrossEntropyLoss(),
281278
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
282279
)

src/otx/algo/classification/mobilenet_v3.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from otx.algo.classification.backbones import OTXMobileNetV3
1616
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
1717
from otx.algo.classification.heads import (
18-
HierarchicalCBAMClsHead,
18+
HierarchicalLinearClsHead,
1919
LinearClsHead,
2020
MultiLabelNonLinearClsHead,
2121
SemiSLLinearClsHead,
@@ -313,14 +313,12 @@ def _build_model(self, head_config: dict) -> nn.Module:
313313

314314
copied_head_config = copy(head_config)
315315
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
316+
in_channels = 960 if self.mode == "large" else 576
316317

317318
return HLabelClassifier(
318319
backbone=OTXMobileNetV3(mode=self.mode, input_size=self.input_size),
319-
neck=nn.Identity(),
320-
head=HierarchicalCBAMClsHead(
321-
in_channels=960,
322-
**copied_head_config,
323-
),
320+
neck=GlobalAveragePooling(dim=2),
321+
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=in_channels),
324322
multiclass_loss=nn.CrossEntropyLoss(),
325323
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
326324
)

src/otx/algo/classification/timm_model.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType
1616
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
1717
from otx.algo.classification.heads import (
18-
HierarchicalCBAMClsHead,
1918
LinearClsHead,
2019
MultiLabelLinearClsHead,
2120
SemiSLLinearClsHead,
2221
)
2322
from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore
23+
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
2424
from otx.algo.classification.necks.gap import GlobalAveragePooling
2525
from otx.algo.classification.utils import get_classification_layers
2626
from otx.algo.utils.support_otx_v1 import OTXv1Helper
@@ -272,11 +272,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
272272
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
273273
return HLabelClassifier(
274274
backbone=backbone,
275-
neck=nn.Identity(),
276-
head=HierarchicalCBAMClsHead(
277-
in_channels=backbone.num_features,
278-
**copied_head_config,
279-
),
275+
neck=GlobalAveragePooling(dim=2),
276+
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
280277
multiclass_loss=nn.CrossEntropyLoss(),
281278
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
282279
)

src/otx/algo/classification/torchvision_model.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from otx.algo.classification.backbones.torchvision import TorchvisionBackbone, TVModelType
1515
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
1616
from otx.algo.classification.heads import (
17-
HierarchicalCBAMClsHead,
1817
LinearClsHead,
1918
MultiLabelLinearClsHead,
2019
SemiSLLinearClsHead,
2120
)
2221
from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
22+
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
2323
from otx.algo.classification.necks.gap import GlobalAveragePooling
2424
from otx.algo.classification.utils import get_classification_layers
2525
from otx.core.data.entity.classification import (
@@ -315,11 +315,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
315315
backbone = TorchvisionBackbone(backbone=self.backbone, pretrained=self.pretrained)
316316
return HLabelClassifier(
317317
backbone=backbone,
318-
neck=nn.Identity(),
319-
head=HierarchicalCBAMClsHead(
320-
in_channels=backbone.in_features,
321-
**head_config,
322-
),
318+
neck=GlobalAveragePooling(dim=2),
319+
head=HierarchicalLinearClsHead(**head_config, in_channels=backbone.in_features),
323320
multiclass_loss=nn.CrossEntropyLoss(),
324321
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
325322
)

src/otx/algo/classification/vit.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
from otx.algo.classification.backbones.vision_transformer import VIT_ARCH_TYPE, VisionTransformer
2020
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
2121
from otx.algo.classification.heads import (
22-
HierarchicalCBAMClsHead,
2322
MultiLabelLinearClsHead,
2423
SemiSLVisionTransformerClsHead,
2524
VisionTransformerClsHead,
2625
)
2726
from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
27+
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
2828
from otx.algo.classification.utils import get_classification_layers
2929
from otx.algo.explain.explain_algo import ViTReciproCAM, feature_vector_fn
3030
from otx.algo.utils.support_otx_v1 import OTXv1Helper
@@ -466,11 +466,7 @@ def _build_model(self, head_config: dict) -> nn.Module:
466466
return HLabelClassifier(
467467
backbone=vit_backbone,
468468
neck=None,
469-
head=HierarchicalCBAMClsHead(
470-
in_channels=vit_backbone.embed_dim,
471-
step_size=1,
472-
**head_config,
473-
),
469+
head=HierarchicalLinearClsHead(**head_config, in_channels=vit_backbone.embed_dim),
474470
multiclass_loss=nn.CrossEntropyLoss(),
475471
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
476472
init_cfg=init_cfg,

src/otx/core/data/dataset/action_classification.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
image_color_channel: ImageColorChannel = ImageColorChannel.BGR,
3838
stack_images: bool = True,
3939
to_tv_image: bool = True,
40+
data_format: str = "",
4041
) -> None:
4142
super().__init__(
4243
dm_subset,

src/otx/core/data/dataset/anomaly.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(
5757
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
5858
stack_images: bool = True,
5959
to_tv_image: bool = True,
60+
data_format: str = "",
6061
) -> None:
6162
self.task_type = task_type
6263
super().__init__(

src/otx/core/data/dataset/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class OTXDataset(Dataset, Generic[T_OTXDataEntity]):
7070
max_refetch: Maximum number of images to fetch in cache
7171
image_color_channel: Color channel of images
7272
stack_images: Whether or not to stack images in collate function in OTXBatchData entity.
73+
data_format: Source data format, which was originally passed to datumaro (could be arrow for instance).
7374
7475
"""
7576

@@ -83,6 +84,7 @@ def __init__(
8384
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
8485
stack_images: bool = True,
8586
to_tv_image: bool = True,
87+
data_format: str = "",
8688
) -> None:
8789
self.dm_subset = dm_subset
8890
self.transforms = transforms
@@ -92,8 +94,11 @@ def __init__(
9294
self.image_color_channel = image_color_channel
9395
self.stack_images = stack_images
9496
self.to_tv_image = to_tv_image
97+
self.data_format = data_format
9598

96-
if self.dm_subset.categories():
99+
if self.dm_subset.categories() and data_format == "arrow":
100+
self.label_info = LabelInfo.from_dm_label_groups_arrow(self.dm_subset.categories()[AnnotationType.label])
101+
elif self.dm_subset.categories():
97102
self.label_info = LabelInfo.from_dm_label_groups(self.dm_subset.categories()[AnnotationType.label])
98103
else:
99104
self.label_info = NullLabelInfo()

src/otx/core/data/dataset/classification.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,21 @@ def _get_item_impl(self, index: int) -> MultilabelClsDataEntity | None:
8080
ignored_labels: list[int] = [] # This should be assigned from item
8181
img_data, img_shape, _ = self._get_img_data_and_shape(img)
8282

83-
label_anns = []
83+
label_ids = set()
8484
for ann in item.annotations:
85+
# multilabel information is stored in 'multi_label_ids' attribute when the source format is arrow
86+
if "multi_label_ids" in ann.attributes:
87+
for lbl_idx in ann.attributes["multi_label_ids"]:
88+
label_ids.add(lbl_idx)
89+
8590
if isinstance(ann, Label):
86-
label_anns.append(ann)
91+
label_ids.add(ann.label)
8792
else:
8893
# If the annotation is not Label, it should be converted to Label.
8994
# For Chained Task: Detection (Bbox) -> Classification (Label)
9095
label = Label(label=ann.label)
91-
if label not in label_anns:
92-
label_anns.append(label)
93-
labels = torch.as_tensor([ann.label for ann in label_anns])
96+
label_ids.add(label.label)
97+
labels = torch.as_tensor(list(label_ids))
9498

9599
entity = MultilabelClsDataEntity(
96100
image=img_data,
@@ -128,13 +132,22 @@ def __init__(self, **kwargs) -> None:
128132
self.dm_categories = self.dm_subset.categories()[AnnotationType.label]
129133

130134
# Hlabel classification used HLabelInfo to insert the HLabelData.
131-
self.label_info = HLabelInfo.from_dm_label_groups(self.dm_categories)
135+
if self.data_format == "arrow":
136+
# arrow format stores label IDs as names, have to deal with that here
137+
self.label_info = HLabelInfo.from_dm_label_groups_arrow(self.dm_categories)
138+
else:
139+
self.label_info = HLabelInfo.from_dm_label_groups(self.dm_categories)
140+
141+
self.id_to_name_mapping = dict(zip(self.label_info.label_ids, self.label_info.label_names))
142+
self.id_to_name_mapping[""] = ""
143+
132144
if self.label_info.num_multiclass_heads == 0:
133145
msg = "The number of multiclass heads should be larger than 0."
134146
raise ValueError(msg)
135147

136-
for dm_item in self.dm_subset:
137-
self._add_ancestors(dm_item.annotations)
148+
if self.data_format != "arrow":
149+
for dm_item in self.dm_subset:
150+
self._add_ancestors(dm_item.annotations)
138151

139152
def _add_ancestors(self, label_anns: list[Label]) -> None:
140153
"""Add ancestors recursively if some labels miss the ancestor information.
@@ -149,14 +162,16 @@ def _add_ancestors(self, label_anns: list[Label]) -> None:
149162
"""
150163

151164
def _label_idx_to_name(idx: int) -> str:
152-
return self.label_info.label_names[idx]
165+
return self.dm_categories[idx].name
153166

154167
def _label_name_to_idx(name: str) -> int:
155168
indices = [idx for idx, val in enumerate(self.label_info.label_names) if val == name]
156169
return indices[0]
157170

158171
def _get_label_group_idx(label_name: str) -> int:
159172
if isinstance(self.label_info, HLabelInfo):
173+
if self.data_format == "arrow":
174+
return self.label_info.class_to_group_idx[self.id_to_name_mapping[label_name]][0]
160175
return self.label_info.class_to_group_idx[label_name][0]
161176
msg = f"self.label_info should have HLabelInfo type, got {type(self.label_info)}"
162177
raise ValueError(msg)
@@ -197,17 +212,22 @@ def _get_item_impl(self, index: int) -> HlabelClsDataEntity | None:
197212
ignored_labels: list[int] = [] # This should be assigned from item
198213
img_data, img_shape, _ = self._get_img_data_and_shape(img)
199214

200-
label_anns = []
215+
label_ids = set()
201216
for ann in item.annotations:
217+
# in the h-cls scenario multilabel information is stored in the 'multi_label_ids' attribute
218+
if "multi_label_ids" in ann.attributes:
219+
for lbl_idx in ann.attributes["multi_label_ids"]:
220+
label_ids.add(lbl_idx)
221+
202222
if isinstance(ann, Label):
203-
label_anns.append(ann)
223+
label_ids.add(ann.label)
204224
else:
205225
# If the annotation is not Label, it should be converted to Label.
206226
# For Chained Task: Detection (Bbox) -> Classification (Label)
207227
label = Label(label=ann.label)
208-
if label not in label_anns:
209-
label_anns.append(label)
210-
hlabel_labels = self._convert_label_to_hlabel_format(label_anns, ignored_labels)
228+
label_ids.add(label.label)
229+
230+
hlabel_labels = self._convert_label_to_hlabel_format([Label(label=idx) for idx in label_ids], ignored_labels)
211231

212232
entity = HlabelClsDataEntity(
213233
image=img_data,
@@ -256,18 +276,18 @@ def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_label
256276
class_indices[i] = -1
257277

258278
for ann in label_anns:
259-
ann_name = self.dm_categories.items[ann.label].name
260-
ann_parent = self.dm_categories.items[ann.label].parent
279+
if self.data_format == "arrow":
280+
# skip unknown labels (for instance, the empty one)
281+
if self.dm_categories.items[ann.label].name not in self.id_to_name_mapping:
282+
continue
283+
ann_name = self.id_to_name_mapping[self.dm_categories.items[ann.label].name]
284+
else:
285+
ann_name = self.dm_categories.items[ann.label].name
261286
group_idx, in_group_idx = self.label_info.class_to_group_idx[ann_name]
262-
(parent_group_idx, parent_in_group_idx) = (
263-
self.label_info.class_to_group_idx[ann_parent] if ann_parent else (None, None)
264-
)
265287

266288
if group_idx < num_multiclass_heads:
267289
class_indices[group_idx] = in_group_idx
268-
if parent_group_idx is not None and parent_in_group_idx is not None:
269-
class_indices[parent_group_idx] = parent_in_group_idx
270-
elif not ignored_labels or ann.label not in ignored_labels:
290+
elif ann.label not in ignored_labels:
271291
class_indices[num_multiclass_heads + in_group_idx] = 1
272292
else:
273293
class_indices[num_multiclass_heads + in_group_idx] = -1

src/otx/core/data/dataset/keypoint_detection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,11 @@ def __init__(
5454
self.dm_subset = self._get_single_bbox_dataset(dm_subset)
5555

5656
if self.dm_subset.categories():
57+
kp_labels = self.dm_subset.categories()[AnnotationType.points][0].labels
5758
self.label_info = LabelInfo(
58-
label_names=self.dm_subset.categories()[AnnotationType.points][0].labels,
59+
label_names=kp_labels,
5960
label_groups=[],
61+
label_ids=[str(i) for i in range(len(kp_labels))],
6062
)
6163
else:
6264
self.label_info = NullLabelInfo()

0 commit comments

Comments
 (0)