Skip to content

Commit e3b73a2

Browse files
authored
Fix label info dispatching (#4443)
* fix loading checkpoints for the fine tuning * add test to check label dispatch
1 parent a686e0f commit e3b73a2

File tree

4 files changed

+114
-8
lines changed

4 files changed

+114
-8
lines changed

src/otx/core/model/base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import warnings
1414
from abc import abstractmethod
1515
from dataclasses import asdict
16-
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
16+
from typing import TYPE_CHECKING, Any, Callable, Literal
1717

1818
import numpy as np
1919
import openvino
@@ -410,8 +410,7 @@ def load_state_dict_incrementally(self, ckpt: dict[str, Any], *args, **kwargs) -
410410
msg = "Checkpoint should have `label_info`."
411411
raise ValueError(msg, ckpt_label_info)
412412

413-
if isinstance(ckpt_label_info, dict):
414-
ckpt_label_info = LabelInfo(**ckpt_label_info)
413+
ckpt_label_info = self._dispatch_label_info(ckpt_label_info)
415414

416415
if not hasattr(ckpt_label_info, "label_ids"):
417416
msg = "Loading checkpoint from OTX < 2.2.1, label_ids are assigned automatically"
@@ -840,7 +839,7 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
840839
return LabelInfo(**label_info)
841840
if isinstance(label_info, int):
842841
return LabelInfo.from_num_classes(num_classes=label_info)
843-
if isinstance(label_info, Sequence) and all(isinstance(name, str) for name in label_info):
842+
if isinstance(label_info, (list, tuple)) and all(isinstance(name, str) for name in label_info):
844843
return LabelInfo(
845844
label_names=label_info,
846845
label_groups=[label_info],

src/otx/core/model/segmentation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
3030
from otx.core.schedulers import LRSchedulerListCallable
3131
from otx.core.types.export import TaskLevelExportParameters
32-
from otx.core.types.label import LabelInfo, LabelInfoTypes, SegLabelInfo
32+
from otx.core.types.label import LabelInfoTypes, SegLabelInfo
3333
from otx.core.utils.tile_merge import SegmentationTileMerge
3434

3535
if TYPE_CHECKING:
@@ -192,7 +192,7 @@ def _convert_pred_entity_to_compute_metric(
192192
]
193193

194194
@staticmethod
195-
def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
195+
def _dispatch_label_info(label_info: LabelInfoTypes) -> SegLabelInfo:
196196
if isinstance(label_info, dict):
197197
if "label_ids" not in label_info:
198198
# NOTE: This is for backward compatibility

src/otx/core/types/label.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ class SegLabelInfo(LabelInfo):
420420
ignore_index: int = 255
421421

422422
@classmethod
423-
def from_num_classes(cls, num_classes: int) -> LabelInfo:
423+
def from_num_classes(cls, num_classes: int) -> SegLabelInfo:
424424
"""Create this object from the number of classes.
425425
426426
Args:
@@ -437,7 +437,7 @@ def from_num_classes(cls, num_classes: int) -> LabelInfo:
437437
label_names = ["background", "label_0"]
438438
return SegLabelInfo(label_names=label_names, label_groups=[label_names], label_ids=["0", "1"])
439439

440-
return super().from_num_classes(num_classes)
440+
return super().from_num_classes(num_classes) # type: ignore[return-value]
441441

442442

443443
@dataclass

tests/unit/core/model/test_base.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,19 @@
77
import openvino as ov
88
import pytest
99
import torch
10+
from datumaro import LabelCategories
11+
from datumaro.components.annotation import GroupType
1012
from lightning import Trainer
1113
from lightning.pytorch.utilities.types import LRSchedulerConfig
1214
from model_api.models.result import ClassificationResult
1315
from pytest_mock import MockerFixture
1416

1517
from otx.core.data.entity.base import OTXBatchDataEntity
1618
from otx.core.model.base import OTXModel, OVModel
19+
from otx.core.model.classification import OTXHlabelClsModel, OTXMulticlassClsModel
20+
from otx.core.model.segmentation import OTXSegmentationModel
1721
from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler
22+
from otx.core.types.label import HLabelInfo, LabelInfo, SegLabelInfo
1823
from tests.unit.core.utils.test_utils import get_dummy_ov_cls_model
1924

2025

@@ -83,6 +88,108 @@ def test_smart_weight_loading(self, mocker) -> None:
8388
prev_state_dict["model.head.bias"],
8489
)
8590

91+
def test_label_info_dispatch(self, mocker):
92+
with mocker.patch.object(OTXModel, "_create_model", return_value=MockNNModule(3)):
93+
with pytest.raises(TypeError, match="invalid_label_info"):
94+
OTXModel(label_info="invalid_label_info")
95+
96+
# Test with LabelInfo
97+
label_info = OTXModel(
98+
label_info=LabelInfo(
99+
["label_1", "label_2"],
100+
label_ids=["1", "2"],
101+
label_groups=[["label_1", "label_2"]],
102+
),
103+
)
104+
assert isinstance(label_info.label_info, LabelInfo)
105+
106+
# Test with SegLabelInfo
107+
seg_label_info = OTXModel(label_info=SegLabelInfo.from_num_classes(3))
108+
assert isinstance(seg_label_info.label_info, SegLabelInfo)
109+
110+
with mocker.patch.object(OTXMulticlassClsModel, "_create_model", return_value=MockNNModule(3)):
111+
# Test simple Classfication model loading checkpoint
112+
cls_model = OTXMulticlassClsModel(
113+
label_info=LabelInfo(
114+
["label_1", "label_2"],
115+
label_ids=["1", "2"],
116+
label_groups=[["label_1", "label_2"]],
117+
),
118+
input_size=(224, 224),
119+
)
120+
label_info_dict = {
121+
"label_ids": ["1", "2"],
122+
"label_names": ["label_1", "label_2"],
123+
"label_groups": [["label_1", "label_2"]],
124+
}
125+
cls_model.load_state_dict_incrementally(
126+
{"state_dict": cls_model.state_dict(), "hyper_parameters": {"label_info": label_info_dict}},
127+
)
128+
assert isinstance(cls_model.label_info, LabelInfo)
129+
# test if ignore_index is not set
130+
label_info_dict["ignore_index"] = 255
131+
with pytest.raises(TypeError, match=r"unexpected keyword argument.*ignore_index"):
132+
cls_model.load_state_dict_incrementally(
133+
{"state_dict": cls_model.state_dict(), "hyper_parameters": {"label_info": label_info_dict}},
134+
)
135+
136+
with mocker.patch.object(OTXSegmentationModel, "_create_model", return_value=MockNNModule(3)):
137+
# test segmentation model loading checkpoint with SegLabelInfo
138+
segmentation_model = OTXSegmentationModel(
139+
label_info=SegLabelInfo.from_num_classes(3),
140+
input_size=(224, 224),
141+
model_name="segmentation_model",
142+
)
143+
segmentation_model.load_state_dict_incrementally(
144+
{"state_dict": segmentation_model.state_dict(), "hyper_parameters": {"label_info": label_info_dict}},
145+
)
146+
assert isinstance(segmentation_model.label_info, SegLabelInfo)
147+
assert hasattr(segmentation_model.label_info, "ignore_index")
148+
assert segmentation_model.label_info.ignore_index == 255
149+
150+
# test hlabel classification model loading checkpoint with HLabelInfo
151+
labels = [
152+
LabelCategories.Category(name="car", parent="vehicle"),
153+
LabelCategories.Category(name="truck", parent="vehicle"),
154+
LabelCategories.Category(name="plush toy", parent="plush toy"),
155+
LabelCategories.Category(name="No class"),
156+
]
157+
label_groups = [
158+
LabelCategories.LabelGroup(
159+
name="Detection labels___vehicle",
160+
labels=["car", "truck"],
161+
group_type=GroupType.EXCLUSIVE,
162+
),
163+
LabelCategories.LabelGroup(
164+
name="Detection labels___plush toy",
165+
labels=["plush toy"],
166+
group_type=GroupType.EXCLUSIVE,
167+
),
168+
LabelCategories.LabelGroup(name="No class", labels=["No class"], group_type=GroupType.RESTRICTED),
169+
]
170+
dm_label_categories = LabelCategories(items=labels, label_groups=label_groups)
171+
hlabel_info = HLabelInfo.from_dm_label_groups(dm_label_categories)
172+
hlabel_dict_label_info = hlabel_info.as_dict(normalize_label_names=True)
173+
174+
with mocker.patch.object(OTXHlabelClsModel, "_create_model", return_value=MockNNModule(3)):
175+
hlabel_model = OTXHlabelClsModel(hlabel_dict_label_info, input_size=(224, 224))
176+
hlabel_model.load_state_dict_incrementally(
177+
{"state_dict": hlabel_model.state_dict(), "hyper_parameters": {"label_info": hlabel_dict_label_info}},
178+
)
179+
180+
with pytest.raises(TypeError, match=r"unexpected keyword argument.*num_multiclass_heads"):
181+
segmentation_model.load_state_dict_incrementally(
182+
{
183+
"state_dict": segmentation_model.state_dict(),
184+
"hyper_parameters": {"label_info": hlabel_dict_label_info},
185+
},
186+
)
187+
188+
with pytest.raises(TypeError, match=r"unexpected keyword argument.*num_multiclass_heads"):
189+
cls_model.load_state_dict_incrementally(
190+
{"state_dict": cls_model.state_dict(), "hyper_parameters": {"label_info": hlabel_dict_label_info}},
191+
)
192+
86193
def test_lr_scheduler_step(self, mocker: MockerFixture) -> None:
87194
mock_linear_warmup_scheduler = mocker.create_autospec(spec=LinearWarmupScheduler)
88195
mock_main_scheduler = mocker.create_autospec(spec=torch.optim.lr_scheduler.LRScheduler)

0 commit comments

Comments
 (0)