Commit 21aeb8d

sovrasov and JihwanEom authored
[FEATURE] Add Semi-SL multilabel classification algorithm (#1805)
* Add multilabel semi sl cls
* Add semi sl multilabel recipies and modules
* Fix linters
* Fix multilabel classifier
* Fix linting issues
* Remove obsolete base model path in classificaiton templates
* Add cli test
* Update ssl configs
* Update multilabel semisl head to be consistent with smeisl hooks
* Add semi sl mlc
* Fix number of lalbeled images in semisl docs
* Fix multilabel semisl configs
* Add unit tests for mlc ssl heads
* Add tests for ssl mlc classifier
* Update licenses in tests
* Fix a typo in docstring

  Co-authored-by: Jihwan Eom <[email protected]>
* Avoid extra multilabel model copying in config manager
* Move semisl helpers to utils
* Add a missing arg to docstring
* Update docs
* Fix linters
* Update MLC ssl docs

---------

Co-authored-by: Jihwan Eom <[email protected]>
1 parent 759d293 commit 21aeb8d

24 files changed: +593, -16 lines changed

docs/source/guide/explanation/algorithms/classification/multi_label_classification.rst

Lines changed: 49 additions & 4 deletions
@@ -52,11 +52,56 @@ In the table below the `mAP <https://en.wikipedia.org/w/index.php?title=Informat
5252
| EfficientNet-V2-S     | 91.91           | 77.28     | 71.52            | 80.24     |
5353
+-----------------------+-----------------+-----------+------------------+-----------+
5454

55-
.. ************************
56-
.. Semi-supervised Learning
57-
.. ************************
55+
************************
56+
Semi-supervised Learning
57+
************************
5858

59-
.. To be added soon
59+
Semi-SL (Semi-supervised Learning) is a type of machine learning algorithm that uses both labeled and unlabeled data to improve the performance of the model. This is particularly useful when labeled data is limited, expensive or time-consuming to obtain.
60+
61+
To utilize unlabeled data during training, we use `BarlowTwins loss <https://arxiv.org/abs/2103.03230>`_ as an auxiliary loss for the Semi-SL task. BarlowTwins enforces consistency across augmented versions of the same data (both labeled and unlabeled): each sample is first augmented with `AugMix <https://arxiv.org/abs/1912.02781>`_, then a strongly augmented sample is generated by applying a pre-defined `RandAugment <https://arxiv.org/abs/1909.13719>`_ strategy on top of the basic augmentation.
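The consistency objective described above can be sketched in a few lines. This is a minimal, dependency-free illustration of the loss from the Barlow Twins paper, not the ``BarlowTwinsLoss`` module added by this commit. The function names are hypothetical, and the default penalty of ``1.0 / 128.0`` only mirrors the ``off_diag_penality`` value that appears in the configs of this commit.

```python
from statistics import fmean, pstdev


def _standardize(batch):
    """Give every feature column zero mean and unit variance over the batch."""
    columns = list(zip(*batch))
    means = [fmean(col) for col in columns]
    # A constant column has zero deviation; divide by 1.0 instead.
    stds = [pstdev(col) or 1.0 for col in columns]
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in batch]


def barlow_twins_loss(z_a, z_b, off_diag_penalty=1.0 / 128.0):
    """Barlow Twins loss for two views given as lists of embedding rows.

    The cross-correlation matrix of the two standardized views is pushed
    toward the identity: diagonal entries toward 1 (the views agree),
    off-diagonal entries toward 0 (the features are not redundant).
    """
    n = len(z_a)
    a, b = _standardize(z_a), _standardize(z_b)
    dim = len(a[0])
    on_diag = off_diag = 0.0
    for i in range(dim):
        for j in range(dim):
            c_ij = sum(a[k][i] * b[k][j] for k in range(n)) / n
            if i == j:
                on_diag += (1.0 - c_ij) ** 2
            else:
                off_diag += c_ij ** 2
    return on_diag + off_diag_penalty * off_diag
```

In this sketch ``z_a`` would hold the embeddings of the weakly augmented images and ``z_b`` those of their strongly augmented counterparts. Two identical views with uncorrelated features give a loss of zero.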
62+
63+
.. _mlc_cls_semi_supervised_pipeline:
64+
65+
- ``BarlowTwins loss``: A specific implementation of Semi-SL that combines a consistency loss with strong data augmentations and the Sharpness-Aware Minimization (`SAM <https://arxiv.org/abs/2010.01412>`_) optimizer to improve the performance of the model.
66+
67+
- ``Adaptive auxiliary loss weighting``: A technique that weights the auxiliary loss so that its weighted value equals a predefined fraction of the EMA-smoothed main loss value. This keeps the contributions of the two losses aligned across different training phases.
68+
69+
- ``Exponential Moving Average (EMA)``: A technique for maintaining a moving average of the model's parameters, which can improve the generalization performance of the model.
70+
71+
- ``Additional techniques``: We also use several techniques that apply to supervised learning (No Bias Decay, augmentations, early stopping, etc.).
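The adaptive auxiliary loss weighting from the list above can be illustrated with a short sketch. It assumes a plain exponential moving average over the main loss and treats the computed weight as a constant. The class name, the ``momentum`` parameter and the ``combine`` method are hypothetical and do not come from the OTX code. The default fraction of 0.1 merely echoes the ``unlabeled_coef=0.1`` value in the configs of this commit; whether that field plays exactly this role is an assumption.

```python
class AdaptiveAuxWeight:
    """Scale an auxiliary loss to a fixed fraction of the smoothed main loss."""

    def __init__(self, fraction=0.1, momentum=0.9):
        self.fraction = fraction  # target ratio: weighted aux / EMA(main)
        self.momentum = momentum  # EMA smoothing factor (assumed value)
        self.main_ema = None      # running average of the main loss

    def combine(self, main_loss, aux_loss):
        """Return (total_loss, aux_weight) for one training step."""
        if self.main_ema is None:
            self.main_ema = main_loss
        else:
            self.main_ema = (
                self.momentum * self.main_ema + (1.0 - self.momentum) * main_loss
            )
        # Choose the weight so that weight * aux_loss == fraction * EMA(main).
        weight = self.fraction * self.main_ema / max(aux_loss, 1e-12)
        return main_loss + weight * aux_loss, weight


balancer = AdaptiveAuxWeight(fraction=0.1, momentum=0.9)
total_1, weight_1 = balancer.combine(main_loss=2.0, aux_loss=50.0)
total_2, weight_2 = balancer.combine(main_loss=1.0, aux_loss=10.0)
```

Because the weight is recomputed at every step, the auxiliary term cannot dominate when its raw scale is much larger than the main loss, and it does not vanish when its raw scale is much smaller.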
72+
73+
Please refer to the :doc:`tutorial <../../../tutorials/advanced/semi_sl>` on how to train with semi-supervised learning.
74+
Training time depends on the number of images and can be up to several times longer than conventional supervised learning.
75+
76+
The table below presents the mAP metric obtained with our pipeline on some public datasets.
77+
78+
+-----------------------+----------------------+---------+----------------+---------+----------------+---------+
79+
| Dataset               | AerialMaritime 3 cls |         | VOC 2007 3 cls |         | COCO 14 5 cls  |         |
80+
+=======================+======================+=========+================+=========+================+=========+
81+
|                       | SL                   | Semi-SL | SL             | Semi-SL | SL             | Semi-SL |
82+
+-----------------------+----------------------+---------+----------------+---------+----------------+---------+
83+
| MobileNet-V3-large-1x | 74.28                | 74.41   | 96.34          | 97.29   | 82.39          | 83.77   |
84+
+-----------------------+----------------------+---------+----------------+---------+----------------+---------+
85+
| EfficientNet-B0       | 79.59                | 80.91   | 97.75          | 98.59   | 83.24          | 84.19   |
86+
+-----------------------+----------------------+---------+----------------+---------+----------------+---------+
87+
| EfficientNet-V2-S     | 75.91                | 81.91   | 95.65          | 96.43   | 85.19          | 84.24   |
88+
+-----------------------+----------------------+---------+----------------+---------+----------------+---------+
89+
90+
AerialMaritime was sampled with 5 images per class. VOC was sampled with 10 images per class and COCO was sampled with 20 images per class.
91+
Additional information about the datasets can be found in the table below.
92+
93+
+-----------------------+----------------+----------------------+
94+
| Dataset               | Labeled images | Unlabeled images     |
95+
+=======================+================+======================+
96+
| AerialMaritime 3 cls  | 10             | 42                   |
97+
+-----------------------+----------------+----------------------+
98+
| VOC 2007 3 cls        | 30             | 798                  |
99+
+-----------------------+----------------+----------------------+
100+
| COCO 14 5 cls         | 95             | 10142                |
101+
+-----------------------+----------------+----------------------+
102+
103+
.. note::
104+
    These results can vary depending on the images selected for each class. Also, since the Semi-SL algorithm targets settings with few labeled images, some models may require larger datasets for better results.
60105

61106
.. ************************
62107
.. Self-supervised Learning

otx/algorithms/classification/configs/base/data/semisl/data_pipeline.py

Lines changed: 4 additions & 3 deletions
@@ -30,11 +30,12 @@
3030

3131
__train_pipeline = [
3232
*__common_pipeline,
33-
dict(type="PILImageToNDArray", keys=["img"]),
33+
dict(type="PostAug", keys=dict(img_strong=__strong_pipeline)),
34+
dict(type="PILImageToNDArray", keys=["img", "img_strong"]),
3435
dict(type="Normalize", **__img_norm_cfg),
35-
dict(type="ImageToTensor", keys=["img"]),
36+
dict(type="ImageToTensor", keys=["img", "img_strong"]),
3637
dict(type="ToTensor", keys=["gt_label"]),
37-
dict(type="Collect", keys=["img", "gt_label"]),
38+
dict(type="Collect", keys=["img", "img_strong", "gt_label"]),
3839
]
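With this change the labeled training pipeline emits a strongly augmented copy, `img_strong`, alongside `img`, and both are carried through to the collected sample. The data flow can be sketched as follows. The functions here are hypothetical stand-ins written for illustration; they are not the real `PostAug` and `Collect` transforms, whose internals are not shown in this diff.

```python
def post_aug(sample, extra_views):
    """Stand-in for a PostAug-like step: add one augmented copy per key."""
    out = dict(sample)
    for key, augment in extra_views.items():
        out[key] = augment(sample["img"])
    return out


def collect(sample, keys):
    """Stand-in for a Collect-like step: keep only the requested keys."""
    return {key: sample[key] for key in keys}


def toy_strong_aug(img):
    """Hypothetical strong augmentation: invert 8-bit pixel values."""
    return [255 - pixel for pixel in img]


sample = {"img": [0, 128, 255], "gt_label": [1, 0, 1], "filename": "a.jpg"}
sample = post_aug(sample, {"img_strong": toy_strong_aug})
batch_item = collect(sample, ["img", "img_strong", "gt_label"])
```

The original `img` is left untouched, so the model can compare the two views of the same labeled image, and keys that are not listed in the collect step are dropped.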
3940

4041
__unlabeled_pipeline = [
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
1+
"""EfficientNet-B0 config for semi-supervised multi-label classification."""
2+
3+
# pylint: disable=invalid-name
4+
5+
_base_ = ["../../../../../recipes/stages/classification/multilabel/semisl.yaml", "../../base/models/efficientnet.py"]
6+
7+
model = dict(
8+
task="classification",
9+
type="SemiSLMultilabelClassifier",
10+
backbone=dict(
11+
version="b0",
12+
),
13+
head=dict(
14+
type="SemiLinearMultilabelClsHead",
15+
use_dynamic_loss_weighting=True,
16+
unlabeled_coef=0.1,
17+
in_channels=-1,
18+
aux_mlp=dict(hid_channels=0, out_channels=1024),
19+
normalized=True,
20+
scale=7.0,
21+
loss=dict(type="AsymmetricAngularLossWithIgnore", gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
22+
aux_loss=dict(
23+
type="BarlowTwinsLoss",
24+
off_diag_penality=1.0 / 128.0,
25+
loss_weight=1.0,
26+
),
27+
),
28+
)
29+
30+
fp16 = dict(loss_scale=512.0)

otx/algorithms/classification/configs/efficientnet_b0_cls_incr/template.yaml

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@ entrypoints:
1515
base: otx.algorithms.classification.tasks.ClassificationTrainTask
1616
openvino: otx.algorithms.classification.tasks.ClassificationOpenVINOTask
1717
nncf: otx.algorithms.classification.tasks.nncf.ClassificationNNCFTask
18-
base_model_path: ../../adapters/deep_object_reid/configs/efficientnet_b0/template_experimental.yaml
1918

2019
# Capabilities.
2120
capabilities:
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
1+
"""EfficientNet-V2 config for semi-supervised multi-label classification."""
2+
3+
# pylint: disable=invalid-name
4+
5+
_base_ = ["../../../../../recipes/stages/classification/multilabel/semisl.yaml", "../../base/models/efficientnet_v2.py"]
6+
7+
model = dict(
8+
task="classification",
9+
type="SemiSLMultilabelClassifier",
10+
backbone=dict(
11+
version="s_21k",
12+
),
13+
head=dict(
14+
type="SemiLinearMultilabelClsHead",
15+
use_dynamic_loss_weighting=True,
16+
unlabeled_coef=0.1,
17+
in_channels=-1,
18+
aux_mlp=dict(hid_channels=0, out_channels=1024),
19+
normalized=True,
20+
scale=7.0,
21+
loss=dict(type="AsymmetricAngularLossWithIgnore", gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
22+
aux_loss=dict(
23+
type="BarlowTwinsLoss",
24+
off_diag_penality=1.0 / 128.0,
25+
loss_weight=1.0,
26+
),
27+
),
28+
)
29+
30+
fp16 = dict(loss_scale=512.0)

otx/algorithms/classification/configs/efficientnet_v2_s_cls_incr/template.yaml

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@ entrypoints:
1515
base: otx.algorithms.classification.tasks.ClassificationTrainTask
1616
openvino: otx.algorithms.classification.tasks.ClassificationOpenVINOTask
1717
nncf: otx.algorithms.classification.tasks.nncf.ClassificationNNCFTask
18-
base_model_path: ../../adapters/deep_object_reid/configs/efficientnet_v2_s/template_experimental.yaml
1918

2019
# Capabilities.
2120
capabilities:

otx/algorithms/classification/configs/mobilenet_v3_large_075_cls_incr/template_experiment.yaml

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@ entrypoints:
1515
base: otx.algorithms.classification.tasks.ClassificationTrainTask
1616
openvino: otx.algorithms.classification.tasks.ClassificationOpenVINOTask
1717
nncf: otx.algorithms.classification.tasks.nncf.ClassificationNNCFTask
18-
base_model_path: ../../adapters/deep_object_reid/configs/mobilenet_v3_large_075/template_experimental.yaml
1918

2019
# Capabilities.
2120
capabilities:
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
1+
"""MobileNet-V3-large-1 config for semi-supervised multi-label classification."""
2+
3+
# pylint: disable=invalid-name
4+
5+
_base_ = ["../../../../../recipes/stages/classification/multilabel/semisl.yaml", "../../base/models/mobilenet_v3.py"]
6+
7+
model = dict(
8+
task="classification",
9+
type="SemiSLMultilabelClassifier",
10+
backbone=dict(mode="large"),
11+
head=dict(
12+
type="SemiNonLinearMultilabelClsHead",
13+
in_channels=960,
14+
hid_channels=1280,
15+
use_dynamic_loss_weighting=True,
16+
unlabeled_coef=0.1,
17+
aux_mlp=dict(hid_channels=0, out_channels=1024),
18+
normalized=True,
19+
scale=7.0,
20+
loss=dict(type="AsymmetricAngularLossWithIgnore", gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
21+
aux_loss=dict(
22+
type="BarlowTwinsLoss",
23+
off_diag_penality=1.0 / 128.0,
24+
loss_weight=1.0,
25+
),
26+
),
27+
)
28+
29+
fp16 = dict(loss_scale=512.0)

otx/algorithms/classification/configs/mobilenet_v3_large_1_cls_incr/template.yaml

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@ entrypoints:
1515
base: otx.algorithms.classification.tasks.ClassificationTrainTask
1616
openvino: otx.algorithms.classification.tasks.ClassificationOpenVINOTask
1717
nncf: otx.algorithms.classification.tasks.nncf.ClassificationNNCFTask
18-
base_model_path: ../../adapters/deep_object_reid/configs/mobilenet_v3_large_1/template_experimental.yaml
1918

2019
# Capabilities.
2120
capabilities:

otx/algorithms/classification/configs/mobilenet_v3_small_cls_incr/template_experiment.yaml

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@ entrypoints:
1515
base: otx.algorithms.classification.tasks.ClassificationTrainTask
1616
openvino: otx.algorithms.classification.tasks.ClassificationOpenVINOTask
1717
nncf: otx.algorithms.classification.tasks.nncf.ClassificationNNCFTask
18-
base_model_path: ../../adapters/deep_object_reid/configs/mobilenet_v3_small/template_experimental.yaml
1918

2019
# Capabilities.
2120
capabilities:
