|
7 | 7 | import openvino as ov |
8 | 8 | import pytest |
9 | 9 | import torch |
| 10 | +from datumaro import LabelCategories |
| 11 | +from datumaro.components.annotation import GroupType |
10 | 12 | from lightning import Trainer |
11 | 13 | from lightning.pytorch.utilities.types import LRSchedulerConfig |
12 | 14 | from model_api.models.result import ClassificationResult |
13 | 15 | from pytest_mock import MockerFixture |
14 | 16 |
|
15 | 17 | from otx.core.data.entity.base import OTXBatchDataEntity |
16 | 18 | from otx.core.model.base import OTXModel, OVModel |
| 19 | +from otx.core.model.classification import OTXHlabelClsModel, OTXMulticlassClsModel |
| 20 | +from otx.core.model.segmentation import OTXSegmentationModel |
17 | 21 | from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler |
| 22 | +from otx.core.types.label import HLabelInfo, LabelInfo, SegLabelInfo |
18 | 23 | from tests.unit.core.utils.test_utils import get_dummy_ov_cls_model |
19 | 24 |
|
20 | 25 |
|
@@ -83,6 +88,108 @@ def test_smart_weight_loading(self, mocker) -> None: |
83 | 88 | prev_state_dict["model.head.bias"], |
84 | 89 | ) |
85 | 90 |
|
| 91 | + def test_label_info_dispatch(self, mocker): |
| 92 | + with mocker.patch.object(OTXModel, "_create_model", return_value=MockNNModule(3)): |
| 93 | + with pytest.raises(TypeError, match="invalid_label_info"): |
| 94 | + OTXModel(label_info="invalid_label_info") |
| 95 | + |
| 96 | + # Test with LabelInfo |
| 97 | + label_info = OTXModel( |
| 98 | + label_info=LabelInfo( |
| 99 | + ["label_1", "label_2"], |
| 100 | + label_ids=["1", "2"], |
| 101 | + label_groups=[["label_1", "label_2"]], |
| 102 | + ), |
| 103 | + ) |
| 104 | + assert isinstance(label_info.label_info, LabelInfo) |
| 105 | + |
| 106 | + # Test with SegLabelInfo |
| 107 | + seg_label_info = OTXModel(label_info=SegLabelInfo.from_num_classes(3)) |
| 108 | + assert isinstance(seg_label_info.label_info, SegLabelInfo) |
| 109 | + |
| 110 | + with mocker.patch.object(OTXMulticlassClsModel, "_create_model", return_value=MockNNModule(3)): |
| 111 | + # Test simple Classfication model loading checkpoint |
| 112 | + cls_model = OTXMulticlassClsModel( |
| 113 | + label_info=LabelInfo( |
| 114 | + ["label_1", "label_2"], |
| 115 | + label_ids=["1", "2"], |
| 116 | + label_groups=[["label_1", "label_2"]], |
| 117 | + ), |
| 118 | + input_size=(224, 224), |
| 119 | + ) |
| 120 | + label_info_dict = { |
| 121 | + "label_ids": ["1", "2"], |
| 122 | + "label_names": ["label_1", "label_2"], |
| 123 | + "label_groups": [["label_1", "label_2"]], |
| 124 | + } |
| 125 | + cls_model.load_state_dict_incrementally( |
| 126 | + {"state_dict": cls_model.state_dict(), "hyper_parameters": {"label_info": label_info_dict}}, |
| 127 | + ) |
| 128 | + assert isinstance(cls_model.label_info, LabelInfo) |
| 129 | + # test if ignore_index is not set |
| 130 | + label_info_dict["ignore_index"] = 255 |
| 131 | + with pytest.raises(TypeError, match=r"unexpected keyword argument.*ignore_index"): |
| 132 | + cls_model.load_state_dict_incrementally( |
| 133 | + {"state_dict": cls_model.state_dict(), "hyper_parameters": {"label_info": label_info_dict}}, |
| 134 | + ) |
| 135 | + |
| 136 | + with mocker.patch.object(OTXSegmentationModel, "_create_model", return_value=MockNNModule(3)): |
| 137 | + # test segmentation model loading checkpoint with SegLabelInfo |
| 138 | + segmentation_model = OTXSegmentationModel( |
| 139 | + label_info=SegLabelInfo.from_num_classes(3), |
| 140 | + input_size=(224, 224), |
| 141 | + model_name="segmentation_model", |
| 142 | + ) |
| 143 | + segmentation_model.load_state_dict_incrementally( |
| 144 | + {"state_dict": segmentation_model.state_dict(), "hyper_parameters": {"label_info": label_info_dict}}, |
| 145 | + ) |
| 146 | + assert isinstance(segmentation_model.label_info, SegLabelInfo) |
| 147 | + assert hasattr(segmentation_model.label_info, "ignore_index") |
| 148 | + assert segmentation_model.label_info.ignore_index == 255 |
| 149 | + |
| 150 | + # test hlabel classification model loading checkpoint with HLabelInfo |
| 151 | + labels = [ |
| 152 | + LabelCategories.Category(name="car", parent="vehicle"), |
| 153 | + LabelCategories.Category(name="truck", parent="vehicle"), |
| 154 | + LabelCategories.Category(name="plush toy", parent="plush toy"), |
| 155 | + LabelCategories.Category(name="No class"), |
| 156 | + ] |
| 157 | + label_groups = [ |
| 158 | + LabelCategories.LabelGroup( |
| 159 | + name="Detection labels___vehicle", |
| 160 | + labels=["car", "truck"], |
| 161 | + group_type=GroupType.EXCLUSIVE, |
| 162 | + ), |
| 163 | + LabelCategories.LabelGroup( |
| 164 | + name="Detection labels___plush toy", |
| 165 | + labels=["plush toy"], |
| 166 | + group_type=GroupType.EXCLUSIVE, |
| 167 | + ), |
| 168 | + LabelCategories.LabelGroup(name="No class", labels=["No class"], group_type=GroupType.RESTRICTED), |
| 169 | + ] |
| 170 | + dm_label_categories = LabelCategories(items=labels, label_groups=label_groups) |
| 171 | + hlabel_info = HLabelInfo.from_dm_label_groups(dm_label_categories) |
| 172 | + hlabel_dict_label_info = hlabel_info.as_dict(normalize_label_names=True) |
| 173 | + |
| 174 | + with mocker.patch.object(OTXHlabelClsModel, "_create_model", return_value=MockNNModule(3)): |
| 175 | + hlabel_model = OTXHlabelClsModel(hlabel_dict_label_info, input_size=(224, 224)) |
| 176 | + hlabel_model.load_state_dict_incrementally( |
| 177 | + {"state_dict": hlabel_model.state_dict(), "hyper_parameters": {"label_info": hlabel_dict_label_info}}, |
| 178 | + ) |
| 179 | + |
| 180 | + with pytest.raises(TypeError, match=r"unexpected keyword argument.*num_multiclass_heads"): |
| 181 | + segmentation_model.load_state_dict_incrementally( |
| 182 | + { |
| 183 | + "state_dict": segmentation_model.state_dict(), |
| 184 | + "hyper_parameters": {"label_info": hlabel_dict_label_info}, |
| 185 | + }, |
| 186 | + ) |
| 187 | + |
| 188 | + with pytest.raises(TypeError, match=r"unexpected keyword argument.*num_multiclass_heads"): |
| 189 | + cls_model.load_state_dict_incrementally( |
| 190 | + {"state_dict": cls_model.state_dict(), "hyper_parameters": {"label_info": hlabel_dict_label_info}}, |
| 191 | + ) |
| 192 | + |
86 | 193 | def test_lr_scheduler_step(self, mocker: MockerFixture) -> None: |
87 | 194 | mock_linear_warmup_scheduler = mocker.create_autospec(spec=LinearWarmupScheduler) |
88 | 195 | mock_main_scheduler = mocker.create_autospec(spec=torch.optim.lr_scheduler.LRScheduler) |
|
0 commit comments