Skip to content

Commit 21508cd

Browse files
authored
Introduce Classification Factory and Simplify Model Imports (#4456)
* add factory for classification * add missing files * minor * fix imports * fix imports in tests 2 * fix ruff * fix unit test * update factory. Reply comments * add literal to other backbones
1 parent c383371 commit 21508cd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+857
-299
lines changed

src/otx/backend/native/models/__init__.py

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,11 @@
55

66
from .anomaly import Padim, Stfpm, Uflow
77
from .classification import (
8-
EfficientNetHLabelCls,
9-
EfficientNetMulticlassCls,
10-
EfficientNetMultilabelCls,
11-
MobileNetV3HLabelCls,
12-
MobileNetV3MulticlassCls,
13-
MobileNetV3MultilabelCls,
14-
TimmModelHLabelCls,
15-
TimmModelMulticlassCls,
16-
TimmModelMultilabelCls,
17-
TVModelHLabelCls,
18-
TVModelMulticlassCls,
19-
TVModelMultilabelCls,
20-
VisionTransformerHLabelCls,
21-
VisionTransformerMulticlassCls,
22-
VisionTransformerMultilabelCls,
8+
EfficientNet,
9+
MobileNetV3,
10+
TimmModel,
11+
TVModel,
12+
VisionTransformer,
2313
)
2414
from .detection import ATSS, RTDETR, SSD, DFine, RTMDet
2515
from .instance_segmentation import MaskRCNN, MaskRCNNTV, RTMDetInst
@@ -30,21 +20,11 @@
3020
"Padim",
3121
"Stfpm",
3222
"Uflow",
33-
"EfficientNetHLabelCls",
34-
"EfficientNetMulticlassCls",
35-
"EfficientNetMultilabelCls",
36-
"MobileNetV3HLabelCls",
37-
"MobileNetV3MulticlassCls",
38-
"MobileNetV3MultilabelCls",
39-
"TimmModelHLabelCls",
40-
"TimmModelMulticlassCls",
41-
"TimmModelMultilabelCls",
42-
"TVModelHLabelCls",
43-
"TVModelMulticlassCls",
44-
"TVModelMultilabelCls",
45-
"VisionTransformerHLabelCls",
46-
"VisionTransformerMulticlassCls",
47-
"VisionTransformerMultilabelCls",
23+
"EfficientNet",
24+
"TimmModel",
25+
"MobileNetV3",
26+
"TVModel",
27+
"VisionTransformer",
4828
"ATSS",
4929
"DFine",
5030
"SSD",

src/otx/backend/native/models/classification/__init__.py

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,42 +3,18 @@
33

44
"""Module for OTX classification models."""
55

6-
from .hlabel_models import (
7-
EfficientNetHLabelCls,
8-
MobileNetV3HLabelCls,
9-
TimmModelHLabelCls,
10-
TVModelHLabelCls,
11-
VisionTransformerHLabelCls,
12-
)
13-
from .multiclass_models import (
14-
EfficientNetMulticlassCls,
15-
MobileNetV3MulticlassCls,
16-
TimmModelMulticlassCls,
17-
TVModelMulticlassCls,
18-
VisionTransformerMulticlassCls,
19-
)
20-
from .multilabel_models import (
21-
EfficientNetMultilabelCls,
22-
MobileNetV3MultilabelCls,
23-
TimmModelMultilabelCls,
24-
TVModelMultilabelCls,
25-
VisionTransformerMultilabelCls,
6+
from .factory import (
7+
EfficientNet,
8+
MobileNetV3,
9+
TimmModel,
10+
TVModel,
11+
VisionTransformer,
2612
)
2713

2814
__all__ = [
29-
"EfficientNetMulticlassCls",
30-
"TimmModelMulticlassCls",
31-
"MobileNetV3MulticlassCls",
32-
"TVModelMulticlassCls",
33-
"VisionTransformerMulticlassCls",
34-
"EfficientNetHLabelCls",
35-
"TimmModelHLabelCls",
36-
"MobileNetV3HLabelCls",
37-
"TVModelHLabelCls",
38-
"VisionTransformerHLabelCls",
39-
"EfficientNetMultilabelCls",
40-
"TimmModelMultilabelCls",
41-
"MobileNetV3MultilabelCls",
42-
"TVModelMultilabelCls",
43-
"VisionTransformerMultilabelCls",
15+
"EfficientNet",
16+
"TimmModel",
17+
"MobileNetV3",
18+
"TVModel",
19+
"VisionTransformer",
4420
]

src/otx/backend/native/models/classification/backbones/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77
from .mobilenet_v3 import MobileNetV3Backbone
88
from .timm import TimmBackbone
99
from .torchvision import TorchvisionBackbone
10-
from .vision_transformer import VisionTransformer
10+
from .vision_transformer import VisionTransformerBackbone
1111

12-
__all__ = ["EfficientNetBackbone", "TimmBackbone", "MobileNetV3Backbone", "VisionTransformer", "TorchvisionBackbone"]
12+
__all__ = [
13+
"EfficientNetBackbone",
14+
"TimmBackbone",
15+
"MobileNetV3Backbone",
16+
"VisionTransformerBackbone",
17+
"TorchvisionBackbone",
18+
]

src/otx/backend/native/models/classification/backbones/efficientnet.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
404404
return self.conv(x)
405405

406406

407-
class EfficientNet(nn.Module):
408-
"""EfficientNet.
407+
class EfficientNetFeatureExtractor(nn.Module):
408+
"""Implementation of the EfficientNet Feature Extractor.
409409
410410
Args:
411411
channels : list of list of int. Number of output channels for each unit.
@@ -611,7 +611,7 @@ def __new__(
611611
input_size: tuple[int, int] | None = None,
612612
pretrained: bool = True,
613613
**kwargs,
614-
) -> EfficientNet:
614+
) -> EfficientNetFeatureExtractor:
615615
"""Create a new instance of the EfficientNet class.
616616
617617
Args:
@@ -621,7 +621,7 @@ def __new__(
621621
**kwargs: Additional keyword arguments to be passed to the EfficientNet constructor.
622622
623623
Returns:
624-
EfficientNet: The created EfficientNet model instance.
624+
EfficientNetFeatureExtractor: The created EfficientNetFeatureExtractor model instance.
625625
"""
626626
origin_input_size, depth_factor, width_factor = cls.EFFICIENTNET_CFG[model_name].values()
627627
input_size = input_size or origin_input_size
@@ -657,7 +657,7 @@ def __new__(
657657
if width_factor > 1.0:
658658
final_block_channels = round_channels(final_block_channels * width_factor)
659659

660-
model = EfficientNet(
660+
model = EfficientNetFeatureExtractor(
661661
channels=channels,
662662
init_block_channels=init_block_channels,
663663
final_block_channels=final_block_channels,

src/otx/backend/native/models/classification/backbones/mobilenet_v3.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,8 @@ def forward(
268268
return self.extract_features(x)
269269

270270

271-
class MobileNetV3(MobileNetV3Base):
272-
"""MobileNetV3 constructor.
271+
class MobileNetV3FeatureExtractor(MobileNetV3Base):
272+
"""MobileNetV3FeatureExtractor constructor.
273273
274274
Args:
275275
layer_cfgs (list): List of layer configurations.
@@ -396,7 +396,7 @@ def __new__(
396396
width_mult: float = 1.0,
397397
pretrained: bool = True,
398398
**kwargs,
399-
) -> MobileNetV3:
399+
) -> MobileNetV3FeatureExtractor:
400400
"""Create a new instance of the MobileNetV3 class.
401401
402402
Args:
@@ -412,7 +412,7 @@ def __new__(
412412
msg = f"Unknown MobileNetV3 model: {model_name}"
413413
raise ValueError(msg)
414414

415-
model = MobileNetV3(
415+
model = MobileNetV3FeatureExtractor(
416416
layer_cfgs=cls.MV3_CFG[model_name]["layer_cfgs"],
417417
width_mult=width_mult,
418418
**kwargs,

src/otx/backend/native/models/classification/backbones/vision_transformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import numpy as np
3636

3737

38-
class VisionTransformer(BaseModule):
38+
class VisionTransformerBackbone(BaseModule):
3939
"""Implementation of Vision Transformer from Timm.
4040
4141
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
@@ -553,7 +553,7 @@ def forward(
553553
@torch.no_grad()
554554
def _load_npz_weights( # noqa: C901
555555
self,
556-
model: VisionTransformer,
556+
model: VisionTransformerBackbone,
557557
checkpoint_path: str,
558558
prefix: str = "",
559559
) -> None:

0 commit comments

Comments
 (0)