Commit f7e4799

Add mmcls.VisionTransformer backbone support (#1908)

* Add mmcls transformer backbones
* Fix VisionTransformer output check
* Add changes
* Disable recording forward hooks in inferrer
* Remove unused import
1 parent abe6aae commit f7e4799

File tree: 7 files changed, +29 −7 lines


otx/algorithms/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -2,3 +2,5 @@
 
 # Copyright (C) 2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
+
+TRANSFORMER_BACKBONES = ["VisionTransformer", "T2T_ViT", "Conformer"]
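The new module-level constant gives the rest of the codebase a single place to test whether a backbone needs transformer-specific handling. A minimal sketch of that membership check, with an illustrative backbone type string (the real callers split the scope with mmcv's Registry.split_scope_key):

from otx.algorithms import TRANSFORMER_BACKBONES

# Registry keys look like "mmcls.VisionTransformer"; strip the scope prefix
# before the membership test, as merge_backbone does below.
backbone_type = "mmcls.VisionTransformer"  # illustrative value
backend, backbone_class = backbone_type.split(".", 1)

if backend == "mmcls" and backbone_class in TRANSFORMER_BACKBONES:
    print(f"{backbone_class} takes the transformer-specific config path")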

otx/algorithms/classification/configs/configuration.yaml

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ learning_parameters:
         stable. A larger batch size has higher memory requirements.
       editable: true
       header: Batch size
-      max_value: 512
+      max_value: 2048
       min_value: 1
       type: INTEGER
       ui_rules:
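The raised ceiling can be verified straight from the YAML. A small sketch, assuming PyYAML is installed, the script runs from the repo root, and the batch_size group nests directly under learning_parameters as the hunk header suggests:

import yaml

with open("otx/algorithms/classification/configs/configuration.yaml") as f:
    cfg = yaml.safe_load(f)

batch_size = cfg["learning_parameters"]["batch_size"]  # assumed key layout
assert batch_size["max_value"] == 2048  # raised from 512 by this commit
assert batch_size["min_value"] == 1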

otx/algorithms/common/configs/training_base.py

Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ class BaseLearningParameters(ParameterGroup):
     batch_size = configurable_integer(
         default_value=5,
         min_value=1,
-        max_value=512,
+        max_value=2048,
         header="Batch size",
         description="The number of training samples seen in each iteration of training. Increasing this value "
         "improves training time and may make the training more stable. A larger batch size has higher "

otx/cli/builder/builder.py

Lines changed: 9 additions & 2 deletions

@@ -28,6 +28,7 @@
 from mmcv.utils import Registry, build_from_cfg
 from torch import nn
 
+from otx.algorithms import TRANSFORMER_BACKBONES
 from otx.api.entities.model_template import TaskType
 from otx.cli.utils.importing import (
     get_backbone_list,
@@ -101,8 +102,8 @@ def update_backbone_args(backbone_config: dict, registry: Registry, backend: str):
 
 def update_channels(model_config: MPAConfig, out_channels: Any):
     """Update in_channel of head or neck."""
-    if hasattr(model_config.model, "neck"):
-        if model_config.model.neck.type == "GlobalAveragePooling":
+    if hasattr(model_config.model, "neck") and model_config.model.neck:
+        if model_config.model.neck.get("type", None) == "GlobalAveragePooling":
             model_config.model.neck.pop("in_channels", None)
         else:
             print(f"\tUpdate model.neck.in_channels: {out_channels}")
@@ -212,6 +213,12 @@ def merge_backbone(
     out_channels = -1
     if hasattr(model_config.model, "head"):
         model_config.model.head.in_channels = -1
+    # TODO: This is a hard coded part of the Transformer backbone and needs to be refactored.
+    if backend == "mmcls" and backbone_class in TRANSFORMER_BACKBONES:
+        if hasattr(model_config.model, "neck"):
+            model_config.model.neck = None
+        if hasattr(model_config.model, "head"):
+            model_config.model.head["type"] = "VisionTransformerClsHead"
     else:
         # Need to update in/out channel configuration here
         out_channels = get_backbone_out_channels(backbone)
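Taken together, these hunks make update_channels tolerate a neck that has been nulled out and let merge_backbone reroute transformer backbones to VisionTransformerClsHead. A minimal sketch of the resulting config mutation on a toy mmcv config (all values illustrative):

from mmcv import Config

cfg = Config(dict(model=dict(
    backbone=dict(type="mmcls.VisionTransformer"),
    neck=dict(type="GlobalAveragePooling"),
    head=dict(type="LinearClsHead", in_channels=1280),
)))

# What merge_backbone now does for a transformer backbone:
cfg.model.head.in_channels = -1  # defer channel inference to train time
cfg.model.neck = None            # ViT needs no pooling neck
cfg.model.head["type"] = "VisionTransformerClsHead"

# Why update_channels needed the extra guard: the "neck" attribute still
# exists after being set to None, so only the truthiness check skips it.
assert hasattr(cfg.model, "neck") and not cfg.model.neck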

otx/cli/builder/supported_backbone/mmcls.json

Lines changed: 3 additions & 3 deletions

@@ -11,7 +11,7 @@
       "options": {
         "arch": ["tiny", "small", "base"]
       },
-      "available": []
+      "available": ["CLASSIFICATION"]
     },
     "mmcls.ConvMixer": {
       "required": ["arch"],
@@ -287,7 +287,7 @@
     "mmcls.T2T_ViT": {
       "required": [],
       "options": {},
-      "available": []
+      "available": ["CLASSIFICATION"]
     },
     "mmcls.TIMMBackbone": {
       "required": ["model_name"],
@@ -341,7 +341,7 @@
           "deit-base"
         ]
       },
-      "available": []
+      "available": ["CLASSIFICATION"]
     }
   }
 }
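The "available" field is what the CLI reads to decide which backbones to offer for a task. A small sketch that lists classification-ready entries; the exact top-level key of the JSON is not visible in these hunks, so the unwrapping below is an assumption:

import json

with open("otx/cli/builder/supported_backbone/mmcls.json") as f:
    data = json.load(f)

# Assumed layout: one top-level key wrapping "mmcls.<Backbone>" entries.
entries = next(iter(data.values())) if len(data) == 1 else data

classification_ready = [
    name for name, meta in entries.items()
    if "CLASSIFICATION" in meta.get("available", [])
]
print(classification_ready)  # now includes the three transformer backbones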

otx/mpa/cls/inferrer.py

Lines changed: 5 additions & 0 deletions

@@ -11,6 +11,7 @@
 from mmcls.datasets import build_dataset as mmcls_build_dataset
 from mmcv import Config, ConfigDict
 
+from otx.algorithms import TRANSFORMER_BACKBONES
 from otx.algorithms.common.adapters.mmcv.utils import (
     build_data_parallel,
     build_dataloader,
@@ -53,6 +54,10 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):
         model_builder = kwargs.get("model_builder", None)
         dump_features = kwargs.get("dump_features", False)
         dump_saliency_map = kwargs.get("dump_saliency_map", False)
+        # TODO: It looks like we need to modify that code in an appropriate way.
+        if model_cfg.model.head.get("type", None) == "VisionTransformerClsHead":
+            dump_features = False
+            dump_saliency_map = False
         eval = kwargs.get("eval", False)
         outputs = self.infer(
             cfg,
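Per the commit message, forward-hook recording is disabled when a ViT head is configured, presumably because the feature and saliency dumping path expects CNN-style feature maps. A reduced sketch of the guard, using a plain dict where the real code reads an mmcv config:

# Reduced, dict-based illustration of the guard added to run(); this helper
# is hypothetical and not part of the inferrer.
def resolve_dump_flags(model_cfg: dict, **kwargs):
    dump_features = kwargs.get("dump_features", False)
    dump_saliency_map = kwargs.get("dump_saliency_map", False)
    if model_cfg["model"]["head"].get("type", None) == "VisionTransformerClsHead":
        dump_features = False      # skip forward-hook feature recording
        dump_saliency_map = False  # skip saliency map dumping
    return dump_features, dump_saliency_map

cfg = {"model": {"head": {"type": "VisionTransformerClsHead"}}}
assert resolve_dump_flags(cfg, dump_features=True) == (False, False)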

otx/mpa/cls/stage.py

Lines changed: 8 additions & 0 deletions

@@ -8,6 +8,7 @@
 import torch
 from mmcv import ConfigDict, build_from_cfg
 
+from otx.algorithms import TRANSFORMER_BACKBONES
 from otx.algorithms.classification.adapters.mmcls.utils.builder import build_classifier
 from otx.mpa.stage import Stage
 from otx.mpa.utils.config_utils import recursively_update_cfg, update_or_add_custom_hook
@@ -89,6 +90,13 @@ def configure_in_channel(cfg):
         output = layer(torch.rand([1] + list(input_shape)))
         if isinstance(output, (tuple, list)):
             output = output[-1]
+
+        if layer.__class__.__name__ in TRANSFORMER_BACKBONES and isinstance(output, (tuple, list)):
+            # mmcls.VisionTransformer outputs Tuple[List[...]] and the last index of List is the final logit.
+            _, output = output
+            if cfg.model.head.type != "VisionTransformerClsHead":
+                raise ValueError(f"{layer.__class__.__name__} needs VisionTransformerClsHead as head")
+
         in_channels = output.shape[1]
         if cfg.model.get("neck") is not None:
             if cfg.model.neck.get("in_channels") is not None:
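configure_in_channel probes the backbone with a dummy tensor and reads output.shape[1]. A self-contained sketch of that probe with a stand-in module that mimics the Tuple[List[...]] output shape mentioned in the comment above; the toy shapes are assumptions, not mmcls internals:

import torch
from torch import nn

class ToyViT(nn.Module):
    """Stand-in for mmcls.VisionTransformer: returns a tuple of per-stage
    outputs whose last element is a [patch_tokens, cls_token] pair."""

    def forward(self, x):
        batch = x.shape[0]
        patch_tokens = torch.rand(batch, 768, 14, 14)
        cls_token = torch.rand(batch, 768)
        return ([patch_tokens, cls_token],)

layer = ToyViT()
output = layer(torch.rand(1, 3, 224, 224))

# Generic unwrapping (pre-existing): keep the last stage output.
if isinstance(output, (tuple, list)):
    output = output[-1]

# Transformer-specific unwrapping (this commit): take the cls token.
if isinstance(output, (tuple, list)):
    _, output = output

print(output.shape[1])  # 768 -> written into cfg.model.head.in_channels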
