Skip to content

Commit d9f8e6c

Browse files
authored
Revert h-cls head to linear (#4221)
* Revert h-cls head to linear * Update changelog
1 parent c4b74c7 commit d9f8e6c

File tree

6 files changed

+18
-31
lines changed

6 files changed

+18
-31
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ All notable changes to this project will be documented in this file.
7272
(<https://github.com/openvinotoolkit/training_extensions/pull/4199>)
7373
- Fix label info on loading checkpoint
7474
(<https://github.com/openvinotoolkit/training_extensions/pull/4200>)
75+
- Revert h-cls head to linear one
76+
(<https://github.com/openvinotoolkit/training_extensions/pull/4221>)
7577

7678
## \[2.2.2\]
7779

src/otx/algo/classification/efficientnet.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from otx.algo.classification.backbones.efficientnet import EFFICIENTNET_VERSION, EfficientNetBackbone
1515
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
1616
from otx.algo.classification.heads import (
17-
HierarchicalCBAMClsHead,
17+
HierarchicalLinearClsHead,
1818
LinearClsHead,
1919
MultiLabelLinearClsHead,
2020
SemiSLLinearClsHead,
@@ -272,11 +272,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
272272

273273
return HLabelClassifier(
274274
backbone=backbone,
275-
neck=nn.Identity(),
276-
head=HierarchicalCBAMClsHead(
277-
in_channels=backbone.num_features,
278-
**copied_head_config,
279-
),
275+
neck=GlobalAveragePooling(dim=2),
276+
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
280277
multiclass_loss=nn.CrossEntropyLoss(),
281278
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
282279
)

src/otx/algo/classification/mobilenet_v3.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from otx.algo.classification.backbones import MobileNetV3Backbone
1616
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
1717
from otx.algo.classification.heads import (
18-
HierarchicalCBAMClsHead,
18+
HierarchicalLinearClsHead,
1919
LinearClsHead,
2020
MultiLabelNonLinearClsHead,
2121
SemiSLLinearClsHead,
@@ -314,15 +314,13 @@ def _build_model(self, head_config: dict) -> nn.Module:
314314

315315
copied_head_config = copy(head_config)
316316
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
317-
317+
in_channels = MobileNetV3Backbone.MV3_CFG[self.mode]["out_channels"]
318318
backbone = MobileNetV3Backbone(mode=self.mode, input_size=self.input_size)
319+
319320
return HLabelClassifier(
320321
backbone=backbone,
321-
neck=nn.Identity(),
322-
head=HierarchicalCBAMClsHead(
323-
in_channels=MobileNetV3Backbone.MV3_CFG[self.mode]["out_channels"],
324-
**copied_head_config,
325-
),
322+
neck=GlobalAveragePooling(dim=2),
323+
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=in_channels),
326324
multiclass_loss=nn.CrossEntropyLoss(),
327325
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
328326
)

src/otx/algo/classification/timm_model.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from otx.algo.classification.backbones.timm import TimmBackbone
1616
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
1717
from otx.algo.classification.heads import (
18-
HierarchicalCBAMClsHead,
18+
HierarchicalLinearClsHead,
1919
LinearClsHead,
2020
MultiLabelLinearClsHead,
2121
SemiSLLinearClsHead,
@@ -348,11 +348,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
348348
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
349349
return HLabelClassifier(
350350
backbone=backbone,
351-
neck=nn.Identity(),
352-
head=HierarchicalCBAMClsHead(
353-
in_channels=backbone.num_features,
354-
**copied_head_config,
355-
),
351+
neck=GlobalAveragePooling(dim=2),
352+
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
356353
multiclass_loss=nn.CrossEntropyLoss(),
357354
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
358355
)

src/otx/algo/classification/torchvision_model.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from otx.algo.classification.backbones.torchvision import TorchvisionBackbone, TVModelType
1515
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
1616
from otx.algo.classification.heads import (
17-
HierarchicalCBAMClsHead,
17+
HierarchicalLinearClsHead,
1818
LinearClsHead,
1919
MultiLabelLinearClsHead,
2020
SemiSLLinearClsHead,
@@ -315,11 +315,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
315315
backbone = TorchvisionBackbone(backbone=self.backbone, pretrained=self.pretrained)
316316
return HLabelClassifier(
317317
backbone=backbone,
318-
neck=nn.Identity(),
319-
head=HierarchicalCBAMClsHead(
320-
in_channels=backbone.in_features,
321-
**head_config,
322-
),
318+
neck=GlobalAveragePooling(dim=2),
319+
head=HierarchicalLinearClsHead(**head_config, in_channels=backbone.in_features),
323320
multiclass_loss=nn.CrossEntropyLoss(),
324321
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
325322
)

src/otx/algo/classification/vit.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from otx.algo.classification.backbones.vision_transformer import VIT_ARCH_TYPE, VisionTransformer
2020
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
2121
from otx.algo.classification.heads import (
22-
HierarchicalCBAMClsHead,
22+
HierarchicalLinearClsHead,
2323
MultiLabelLinearClsHead,
2424
SemiSLVisionTransformerClsHead,
2525
VisionTransformerClsHead,
@@ -466,11 +466,7 @@ def _build_model(self, head_config: dict) -> nn.Module:
466466
return HLabelClassifier(
467467
backbone=vit_backbone,
468468
neck=None,
469-
head=HierarchicalCBAMClsHead(
470-
in_channels=vit_backbone.embed_dim,
471-
step_size=1,
472-
**head_config,
473-
),
469+
head=HierarchicalLinearClsHead(**head_config, in_channels=vit_backbone.embed_dim),
474470
multiclass_loss=nn.CrossEntropyLoss(),
475471
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
476472
init_cfg=init_cfg,

0 commit comments

Comments
 (0)