Skip to content

Commit 16dd688

Browse files
author
Galina Zalesskaya
authored
[release 1.5.0] DeiT: enable tests + add ViTFeatureVectorHook (#2630)
Add ViT feature vector hook
1 parent 6d3dd34 commit 16dd688

File tree

4 files changed

+19
-67
lines changed

4 files changed

+19
-67
lines changed

src/otx/algorithms/classification/adapters/mmcls/task.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from mmcv.runner import wrap_fp16_model
2020
from mmcv.utils import Config, ConfigDict
2121

22-
from otx.algorithms import TRANSFORMER_BACKBONES
2322
from otx.algorithms.classification.adapters.mmcls.utils.exporter import (
2423
ClassificationExporter,
2524
)
@@ -31,6 +30,7 @@
3130
EigenCamHook,
3231
FeatureVectorHook,
3332
ReciproCAMHook,
33+
ViTFeatureVectorHook,
3434
ViTReciproCAMHook,
3535
)
3636
from otx.algorithms.common.adapters.mmcv.utils import (
@@ -225,7 +225,6 @@ def _infer_model(
225225
)
226226
)
227227

228-
dump_features = True
229228
dump_saliency_map = not inference_parameters.is_evaluation if inference_parameters else True
230229

231230
self._init_task()
@@ -274,16 +273,16 @@ def hook(module, inp, outp):
274273
forward_explainer_hook: Union[nullcontext, BaseRecordingForwardHook]
275274
if model_type == "VisionTransformer":
276275
forward_explainer_hook = ViTReciproCAMHook(feature_model)
277-
elif (
278-
not dump_saliency_map or model_type in TRANSFORMER_BACKBONES
279-
): # TODO: remove latter "or" condition after resolving Issue#2098
276+
elif not dump_saliency_map:
280277
forward_explainer_hook = nullcontext()
281278
else:
282279
forward_explainer_hook = ReciproCAMHook(feature_model)
283-
if (
284-
not dump_features or model_type in TRANSFORMER_BACKBONES
285-
): # TODO: remove latter "or" condition after resolving Issue#2098
286-
feature_vector_hook: Union[nullcontext, BaseRecordingForwardHook] = nullcontext()
280+
281+
feature_vector_hook: Union[nullcontext, BaseRecordingForwardHook]
282+
if model_type == "VisionTransformer":
283+
feature_vector_hook = ViTFeatureVectorHook(feature_model)
284+
elif not dump_saliency_map:
285+
feature_vector_hook = nullcontext()
287286
else:
288287
feature_vector_hook = FeatureVectorHook(feature_model)
289288

@@ -533,11 +532,6 @@ def _export_model(self, precision: ModelPrecision, export_format: ExportType, du
533532
export_options["precision"] = str(precision)
534533
export_options["type"] = str(export_format)
535534

536-
# [TODO] Enable dump_features for ViT backbones
537-
model_type = cfg.model.backbone.type.split(".")[-1] # mmcls.VisionTransformer => VisionTransformer
538-
if model_type in TRANSFORMER_BACKBONES:
539-
dump_features = False
540-
541535
export_options["deploy_cfg"]["dump_features"] = dump_features
542536
if dump_features:
543537
output_names = export_options["deploy_cfg"]["ir_config"]["output_names"]

src/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from __future__ import annotations
1717

1818
from abc import ABC
19-
from typing import List, Optional, Sequence, Union
19+
from typing import List, Optional, Sequence, Tuple, Union
2020

2121
import numpy as np
2222
import torch
@@ -172,6 +172,16 @@ def func(feature_map: Union[torch.Tensor, Sequence[torch.Tensor]], fpn_idx: int
172172
return feature_vector
173173

174174

175+
class ViTFeatureVectorHook(BaseRecordingForwardHook):
176+
"""FeatureVectorHook for transformer-based classifiers."""
177+
178+
@staticmethod
179+
def func(features: Tuple[List[torch.Tensor]], fpn_idx: int = -1) -> torch.Tensor:
180+
"""Generate the feature vector for transformer-based classifiers by returning the cls token."""
181+
_, cls_token = features[0]
182+
return cls_token
183+
184+
175185
class ReciproCAMHook(BaseRecordingForwardHook):
176186
"""Implementation of recipro-cam for class-wise saliency map.
177187

tests/e2e/cli/classification/test_classification.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,6 @@ def test_otx_resume(self, template, tmp_dir_path):
137137
@pytest.mark.parametrize("template", templates, ids=templates_ids)
138138
@pytest.mark.parametrize("dump_features", [True, False])
139139
def test_otx_export(self, template, tmp_dir_path, dump_features):
140-
if template.name == "DeiT-Tiny" and dump_features:
141-
pytest.skip(reason="Issue#2098 ViT template does not support dump_features.")
142140
tmp_dir_path = tmp_dir_path / "multi_class_cls"
143141
otx_export_testing(template, tmp_dir_path, dump_features)
144142

@@ -160,17 +158,13 @@ def test_otx_eval(self, template, tmp_dir_path):
160158
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
161159
@pytest.mark.parametrize("template", templates, ids=templates_ids)
162160
def test_otx_explain(self, template, tmp_dir_path):
163-
if template.name == "DeiT-Tiny":
164-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
165161
tmp_dir_path = tmp_dir_path / "multi_class_cls"
166162
otx_explain_testing(template, tmp_dir_path, otx_dir, args)
167163

168164
@e2e_pytest_component
169165
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
170166
@pytest.mark.parametrize("template", templates, ids=templates_ids)
171167
def test_otx_explain_openvino(self, template, tmp_dir_path):
172-
if template.name == "DeiT-Tiny":
173-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
174168
tmp_dir_path = tmp_dir_path / "multi_class_cls"
175169
otx_explain_openvino_testing(template, tmp_dir_path, otx_dir, args)
176170

@@ -383,8 +377,6 @@ def test_otx_resume(self, template, tmp_dir_path):
383377
@pytest.mark.parametrize("template", templates, ids=templates_ids)
384378
@pytest.mark.parametrize("dump_features", [True, False])
385379
def test_otx_export(self, template, tmp_dir_path, dump_features):
386-
if template.name == "DeiT-Tiny" and dump_features:
387-
pytest.skip(reason="Issue#2098 ViT template does not support dump_features.")
388380
tmp_dir_path = tmp_dir_path / "multi_label_cls"
389381
otx_export_testing(template, tmp_dir_path, dump_features)
390382

@@ -399,8 +391,6 @@ def test_otx_eval(self, template, tmp_dir_path):
399391
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
400392
@pytest.mark.parametrize("template", templates, ids=templates_ids)
401393
def test_otx_explain(self, template, tmp_dir_path):
402-
if template.name == "DeiT-Tiny":
403-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
404394
tmp_dir_path = tmp_dir_path / "multi_label_cls"
405395
otx_explain_testing(template, tmp_dir_path, otx_dir, args_m)
406396

@@ -546,8 +536,6 @@ def test_otx_resume(self, template, tmp_dir_path):
546536
@pytest.mark.parametrize("template", templates, ids=templates_ids)
547537
@pytest.mark.parametrize("dump_features", [True, False])
548538
def test_otx_export(self, template, tmp_dir_path, dump_features):
549-
if template.name == "DeiT-Tiny" and dump_features:
550-
pytest.skip(reason="Issue#2098 ViT template does not support dump_features.")
551539
tmp_dir_path = tmp_dir_path / "h_label_cls"
552540
otx_export_testing(template, tmp_dir_path, dump_features)
553541

@@ -562,8 +550,6 @@ def test_otx_eval(self, template, tmp_dir_path):
562550
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
563551
@pytest.mark.parametrize("template", templates, ids=templates_ids)
564552
def test_otx_explain(self, template, tmp_dir_path):
565-
if template.name == "DeiT-Tiny":
566-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
567553
tmp_dir_path = tmp_dir_path / "h_label_cls"
568554
otx_explain_testing(template, tmp_dir_path, otx_dir, args_h)
569555

tests/integration/cli/classification/test_classification.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,6 @@ def test_otx_resume(self, template, tmp_dir_path):
124124
@pytest.mark.parametrize("template", templates, ids=templates_ids)
125125
@pytest.mark.parametrize("dump_features", [True, False])
126126
def test_otx_export(self, template, tmp_dir_path, dump_features):
127-
if template.name == "DeiT-Tiny" and dump_features:
128-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
129127
tmp_dir_path = tmp_dir_path / "multi_class_cls"
130128
otx_export_testing(template, tmp_dir_path, dump_features, check_ir_meta=True)
131129

@@ -150,48 +148,36 @@ def test_otx_eval(self, template, tmp_dir_path):
150148
@e2e_pytest_component
151149
@pytest.mark.parametrize("template", templates, ids=templates_ids)
152150
def test_otx_explain(self, template, tmp_dir_path):
153-
if template.name == "DeiT-Tiny":
154-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
155151
tmp_dir_path = tmp_dir_path / "multi_class_cls"
156152
otx_explain_testing(template, tmp_dir_path, otx_dir, args)
157153

158154
@e2e_pytest_component
159155
@pytest.mark.parametrize("template", templates, ids=templates_ids)
160156
def test_otx_explain_all_classes(self, template, tmp_dir_path):
161-
if template.name == "DeiT-Tiny":
162-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
163157
tmp_dir_path = tmp_dir_path / "multi_class_cls"
164158
otx_explain_testing_all_classes(template, tmp_dir_path, otx_dir, args)
165159

166160
@e2e_pytest_component
167161
@pytest.mark.parametrize("template", templates, ids=templates_ids)
168162
def test_otx_explain_process_saliency_maps(self, template, tmp_dir_path):
169-
if template.name == "DeiT-Tiny":
170-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
171163
tmp_dir_path = tmp_dir_path / "multi_class_cls"
172164
otx_explain_testing_process_saliency_maps(template, tmp_dir_path, otx_dir, args)
173165

174166
@e2e_pytest_component
175167
@pytest.mark.parametrize("template", templates, ids=templates_ids)
176168
def test_otx_explain_openvino(self, template, tmp_dir_path):
177-
if template.name == "DeiT-Tiny":
178-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
179169
tmp_dir_path = tmp_dir_path / "multi_class_cls"
180170
otx_explain_openvino_testing(template, tmp_dir_path, otx_dir, args)
181171

182172
@e2e_pytest_component
183173
@pytest.mark.parametrize("template", templates, ids=templates_ids)
184174
def test_otx_explain_all_classes_openvino(self, template, tmp_dir_path):
185-
if template.name == "DeiT-Tiny":
186-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
187175
tmp_dir_path = tmp_dir_path / "multi_class_cls"
188176
otx_explain_all_classes_openvino_testing(template, tmp_dir_path, otx_dir, args)
189177

190178
@e2e_pytest_component
191179
@pytest.mark.parametrize("template", templates, ids=templates_ids)
192180
def test_otx_explain_process_saliency_maps_openvino(self, template, tmp_dir_path):
193-
if template.name == "DeiT-Tiny":
194-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
195181
tmp_dir_path = tmp_dir_path / "multi_class_cls"
196182
otx_explain_process_saliency_maps_openvino_testing(template, tmp_dir_path, otx_dir, args)
197183

@@ -365,48 +351,36 @@ def test_otx_eval(self, template, tmp_dir_path):
365351
@e2e_pytest_component
366352
@pytest.mark.parametrize("template", default_templates, ids=default_templates_ids)
367353
def test_otx_explain(self, template, tmp_dir_path):
368-
if template.name == "DeiT-Tiny":
369-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
370354
tmp_dir_path = tmp_dir_path / "multi_label_cls"
371355
otx_explain_testing(template, tmp_dir_path, otx_dir, args_m)
372356

373357
@e2e_pytest_component
374358
@pytest.mark.parametrize("template", default_templates, ids=default_templates_ids)
375359
def test_otx_explain_all_classes(self, template, tmp_dir_path):
376-
if template.name == "DeiT-Tiny":
377-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
378360
tmp_dir_path = tmp_dir_path / "multi_label_cls"
379361
otx_explain_testing_all_classes(template, tmp_dir_path, otx_dir, args_m)
380362

381363
@e2e_pytest_component
382364
@pytest.mark.parametrize("template", default_templates, ids=default_templates_ids)
383365
def test_otx_explain_process_saliency_maps(self, template, tmp_dir_path):
384-
if template.name == "DeiT-Tiny":
385-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
386366
tmp_dir_path = tmp_dir_path / "multi_label_cls"
387367
otx_explain_testing_process_saliency_maps(template, tmp_dir_path, otx_dir, args_m)
388368

389369
@e2e_pytest_component
390370
@pytest.mark.parametrize("template", default_templates, ids=default_templates_ids)
391371
def test_otx_explain_openvino(self, template, tmp_dir_path):
392-
if template.name == "DeiT-Tiny":
393-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
394372
tmp_dir_path = tmp_dir_path / "multi_label_cls"
395373
otx_explain_openvino_testing(template, tmp_dir_path, otx_dir, args_m)
396374

397375
@e2e_pytest_component
398376
@pytest.mark.parametrize("template", default_templates, ids=default_templates_ids)
399377
def test_otx_explain_all_classes_openvino(self, template, tmp_dir_path):
400-
if template.name == "DeiT-Tiny":
401-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
402378
tmp_dir_path = tmp_dir_path / "multi_label_cls"
403379
otx_explain_all_classes_openvino_testing(template, tmp_dir_path, otx_dir, args_m)
404380

405381
@e2e_pytest_component
406382
@pytest.mark.parametrize("template", default_templates, ids=default_templates_ids)
407383
def test_otx_explain_process_saliency_maps_openvino(self, template, tmp_dir_path):
408-
if template.name == "DeiT-Tiny":
409-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
410384
tmp_dir_path = tmp_dir_path / "multi_label_cls"
411385
otx_explain_process_saliency_maps_openvino_testing(template, tmp_dir_path, otx_dir, args_m)
412386

@@ -502,48 +476,36 @@ def test_otx_eval(self, template, tmp_dir_path):
502476
@e2e_pytest_component
503477
@pytest.mark.parametrize("template", default_templates, ids=default_templates_ids)
504478
def test_otx_explain(self, template, tmp_dir_path):
505-
if template.name == "DeiT-Tiny":
506-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
507479
tmp_dir_path = tmp_dir_path / "h_label_cls"
508480
otx_explain_testing(template, tmp_dir_path, otx_dir, args_h)
509481

510482
@e2e_pytest_component
511483
@pytest.mark.parametrize("template", default_templates, ids=default_templates_ids)
512484
def test_otx_explain_all_classes(self, template, tmp_dir_path):
513-
if template.name == "DeiT-Tiny":
514-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
515485
tmp_dir_path = tmp_dir_path / "h_label_cls"
516486
otx_explain_testing_all_classes(template, tmp_dir_path, otx_dir, args_h)
517487

518488
@e2e_pytest_component
519489
@pytest.mark.parametrize("template", default_templates, ids=default_templates_ids)
520490
def test_otx_explain_process_saliency_maps(self, template, tmp_dir_path):
521-
if template.name == "DeiT-Tiny":
522-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
523491
tmp_dir_path = tmp_dir_path / "h_label_cls"
524492
otx_explain_testing_process_saliency_maps(template, tmp_dir_path, otx_dir, args_h)
525493

526494
@e2e_pytest_component
527495
@pytest.mark.parametrize("template", default_templates, ids=default_templates_ids)
528496
def test_otx_explain_openvino(self, template, tmp_dir_path):
529-
if template.name == "DeiT-Tiny":
530-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
531497
tmp_dir_path = tmp_dir_path / "h_label_cls"
532498
otx_explain_openvino_testing(template, tmp_dir_path, otx_dir, args_h)
533499

534500
@e2e_pytest_component
535501
@pytest.mark.parametrize("template", default_templates, ids=default_templates_ids)
536502
def test_otx_explain_all_classes_openvino(self, template, tmp_dir_path):
537-
if template.name == "DeiT-Tiny":
538-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
539503
tmp_dir_path = tmp_dir_path / "h_label_cls"
540504
otx_explain_all_classes_openvino_testing(template, tmp_dir_path, otx_dir, args_h)
541505

542506
@e2e_pytest_component
543507
@pytest.mark.parametrize("template", default_templates, ids=default_templates_ids)
544508
def test_otx_explain_process_saliency_maps_openvino(self, template, tmp_dir_path):
545-
if template.name == "DeiT-Tiny":
546-
pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.")
547509
tmp_dir_path = tmp_dir_path / "h_label_cls"
548510
otx_explain_process_saliency_maps_openvino_testing(template, tmp_dir_path, otx_dir, args_h)
549511

0 commit comments

Comments
 (0)