|
28 | 28 | from mmcv.utils import Registry, build_from_cfg |
29 | 29 | from torch import nn |
30 | 30 |
|
| 31 | +from otx.algorithms import TRANSFORMER_BACKBONES |
31 | 32 | from otx.api.entities.model_template import TaskType |
32 | 33 | from otx.cli.utils.importing import ( |
33 | 34 | get_backbone_list, |
@@ -101,8 +102,8 @@ def update_backbone_args(backbone_config: dict, registry: Registry, backend: str |
101 | 102 |
|
102 | 103 | def update_channels(model_config: MPAConfig, out_channels: Any): |
103 | 104 | """Update in_channel of head or neck.""" |
104 | | - if hasattr(model_config.model, "neck"): |
105 | | - if model_config.model.neck.type == "GlobalAveragePooling": |
| 105 | + if hasattr(model_config.model, "neck") and model_config.model.neck: |
| 106 | + if model_config.model.neck.get("type", None) == "GlobalAveragePooling": |
106 | 107 | model_config.model.neck.pop("in_channels", None) |
107 | 108 | else: |
108 | 109 | print(f"\tUpdate model.neck.in_channels: {out_channels}") |
@@ -212,6 +213,12 @@ def merge_backbone( |
212 | 213 | out_channels = -1 |
213 | 214 | if hasattr(model_config.model, "head"): |
214 | 215 | model_config.model.head.in_channels = -1 |
| 216 | + # TODO: This is a hard coded part of the Transformer backbone and needs to be refactored. |
| 217 | + if backend == "mmcls" and backbone_class in TRANSFORMER_BACKBONES: |
| 218 | + if hasattr(model_config.model, "neck"): |
| 219 | + model_config.model.neck = None |
| 220 | + if hasattr(model_config.model, "head"): |
| 221 | + model_config.model.head["type"] = "VisionTransformerClsHead" |
215 | 222 | else: |
216 | 223 | # Need to update in/out channel configuration here |
217 | 224 | out_channels = get_backbone_out_channels(backbone) |
|
0 commit comments