Skip to content

Commit 53a7d9a

Browse files
authored
Fix get_item for Chained Tasks in Classification (#3931)
* Fix Task Chain * Add multi-label case as well * Add multi-label case as well2 * Add H-label case
1 parent 706f99b commit 53a7d9a

File tree

3 files changed

+159
-4
lines changed

3 files changed

+159
-4
lines changed

src/otx/core/data/dataset/classification.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,16 @@ def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None:
3434
img = item.media_as(Image)
3535
img_data, img_shape = self._get_img_data_and_shape(img)
3636

37-
label_anns = [ann for ann in item.annotations if isinstance(ann, Label)]
37+
label_anns = []
38+
for ann in item.annotations:
39+
if isinstance(ann, Label):
40+
label_anns.append(ann)
41+
else:
42+
# If the annotation is not Label, it should be converted to Label.
43+
# For Chained Task: Detection (Bbox) -> Classification (Label)
44+
label = Label(label=ann.label)
45+
if label not in label_anns:
46+
label_anns.append(label)
3847
if len(label_anns) > 1:
3948
msg = f"Multi-class Classification can't use the multi-label, currently len(labels) = {len(label_anns)}"
4049
raise ValueError(msg)
@@ -71,7 +80,16 @@ def _get_item_impl(self, index: int) -> MultilabelClsDataEntity | None:
7180
ignored_labels: list[int] = [] # This should be assigned form item
7281
img_data, img_shape = self._get_img_data_and_shape(img)
7382

74-
label_anns = [ann for ann in item.annotations if isinstance(ann, Label)]
83+
label_anns = []
84+
for ann in item.annotations:
85+
if isinstance(ann, Label):
86+
label_anns.append(ann)
87+
else:
88+
# If the annotation is not Label, it should be converted to Label.
89+
# For Chained Task: Detection (Bbox) -> Classification (Label)
90+
label = Label(label=ann.label)
91+
if label not in label_anns:
92+
label_anns.append(label)
7593
labels = torch.as_tensor([ann.label for ann in label_anns])
7694

7795
entity = MultilabelClsDataEntity(
@@ -179,7 +197,16 @@ def _get_item_impl(self, index: int) -> HlabelClsDataEntity | None:
179197
ignored_labels: list[int] = [] # This should be assigned form item
180198
img_data, img_shape = self._get_img_data_and_shape(img)
181199

182-
label_anns = [ann for ann in item.annotations if isinstance(ann, Label)]
200+
label_anns = []
201+
for ann in item.annotations:
202+
if isinstance(ann, Label):
203+
label_anns.append(ann)
204+
else:
205+
# If the annotation is not Label, it should be converted to Label.
206+
# For Chained Task: Detection (Bbox) -> Classification (Label)
207+
label = Label(label=ann.label)
208+
if label not in label_anns:
209+
label_anns.append(label)
183210
hlabel_labels = self._convert_label_to_hlabel_format(label_anns, ignored_labels)
184211

185212
entity = HlabelClsDataEntity(

tests/unit/core/data/conftest.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,37 @@ def fxt_dm_item(request, tmpdir) -> DatasetItem:
9696
)
9797

9898

99+
@pytest.fixture(params=["bytes", "file"])
100+
def fxt_dm_item_bbox_only(request, tmpdir) -> DatasetItem:
101+
np_img = np.zeros(shape=(10, 10, 3), dtype=np.uint8)
102+
np_img[:, :, 0] = 0 # Set 0 for B channel
103+
np_img[:, :, 1] = 1 # Set 1 for G channel
104+
np_img[:, :, 2] = 2 # Set 2 for R channel
105+
106+
if request.param == "bytes":
107+
_, np_bytes = cv2.imencode(".png", np_img)
108+
media = Image.from_bytes(np_bytes.tobytes())
109+
media.path = ""
110+
elif request.param == "file":
111+
fname = str(uuid.uuid4())
112+
fpath = str(Path(tmpdir) / f"{fname}.png")
113+
cv2.imwrite(fpath, np_img)
114+
media = Image.from_file(fpath)
115+
else:
116+
raise ValueError(request.param)
117+
118+
return DatasetItem(
119+
id="item",
120+
subset="train",
121+
media=media,
122+
annotations=[
123+
Bbox(x=0, y=0, w=1, h=1, label=0),
124+
Bbox(x=1, y=0, w=1, h=1, label=0),
125+
Bbox(x=1, y=1, w=1, h=1, label=0),
126+
],
127+
)
128+
129+
99130
@pytest.fixture()
100131
def fxt_mock_dm_subset(mocker: MockerFixture, fxt_dm_item: DatasetItem) -> MagicMock:
101132
mock_dm_subset = mocker.MagicMock(spec=DmDataset)
@@ -105,6 +136,15 @@ def fxt_mock_dm_subset(mocker: MockerFixture, fxt_dm_item: DatasetItem) -> Magic
105136
return mock_dm_subset
106137

107138

139+
@pytest.fixture()
140+
def fxt_mock_det_dm_subset(mocker: MockerFixture, fxt_dm_item_bbox_only: DatasetItem) -> MagicMock:
141+
mock_dm_subset = mocker.MagicMock(spec=DmDataset)
142+
mock_dm_subset.__getitem__.return_value = fxt_dm_item_bbox_only
143+
mock_dm_subset.__len__.return_value = 1
144+
mock_dm_subset.categories().__getitem__.return_value = LabelCategories.from_iterable(_LABEL_NAMES)
145+
return mock_dm_subset
146+
147+
108148
@pytest.fixture(
109149
params=[
110150
(OTXHlabelClsDataset, HlabelClsDataEntity, {}),

tests/unit/core/data/dataset/test_classification.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,65 @@
55

66
from unittest.mock import MagicMock
77

8-
from otx.core.data.dataset.classification import OTXHlabelClsDataset
8+
from otx.core.data.dataset.classification import (
9+
HLabelInfo,
10+
OTXHlabelClsDataset,
11+
OTXMulticlassClsDataset,
12+
OTXMultilabelClsDataset,
13+
)
14+
from otx.core.data.entity.classification import HlabelClsDataEntity, MulticlassClsDataEntity, MultilabelClsDataEntity
15+
16+
17+
class TestOTXMulticlassClsDataset:
18+
def test_get_item(
19+
self,
20+
fxt_mock_dm_subset,
21+
) -> None:
22+
dataset = OTXMulticlassClsDataset(
23+
dm_subset=fxt_mock_dm_subset,
24+
transforms=[lambda x: x],
25+
mem_cache_img_max_size=None,
26+
max_refetch=3,
27+
)
28+
assert isinstance(dataset[0], MulticlassClsDataEntity)
29+
30+
def test_get_item_from_bbox_dataset(
31+
self,
32+
fxt_mock_det_dm_subset,
33+
) -> None:
34+
dataset = OTXMulticlassClsDataset(
35+
dm_subset=fxt_mock_det_dm_subset,
36+
transforms=[lambda x: x],
37+
mem_cache_img_max_size=None,
38+
max_refetch=3,
39+
)
40+
assert isinstance(dataset[0], MulticlassClsDataEntity)
41+
42+
43+
class TestOTXMultilabelClsDataset:
44+
def test_get_item(
45+
self,
46+
fxt_mock_dm_subset,
47+
) -> None:
48+
dataset = OTXMultilabelClsDataset(
49+
dm_subset=fxt_mock_dm_subset,
50+
transforms=[lambda x: x],
51+
mem_cache_img_max_size=None,
52+
max_refetch=3,
53+
)
54+
assert isinstance(dataset[0], MultilabelClsDataEntity)
55+
56+
def test_get_item_from_bbox_dataset(
57+
self,
58+
fxt_mock_det_dm_subset,
59+
) -> None:
60+
dataset = OTXMultilabelClsDataset(
61+
dm_subset=fxt_mock_det_dm_subset,
62+
transforms=[lambda x: x],
63+
mem_cache_img_max_size=None,
64+
max_refetch=3,
65+
)
66+
assert isinstance(dataset[0], MultilabelClsDataEntity)
967

1068

1169
class TestOTXHlabelClsDataset:
@@ -20,3 +78,33 @@ def test_add_ancestors(self, fxt_hlabel_dataset_subset):
2078
# Added the ancestor
2179
adjusted_anns = hlabel_dataset.dm_subset.get(id=0, subset="train").annotations
2280
assert len(adjusted_anns) == 2
81+
82+
def test_get_item(
83+
self,
84+
mocker,
85+
fxt_mock_dm_subset,
86+
fxt_mock_hlabelinfo,
87+
) -> None:
88+
mocker.patch.object(HLabelInfo, "from_dm_label_groups", return_value=fxt_mock_hlabelinfo)
89+
dataset = OTXHlabelClsDataset(
90+
dm_subset=fxt_mock_dm_subset,
91+
transforms=[lambda x: x],
92+
mem_cache_img_max_size=None,
93+
max_refetch=3,
94+
)
95+
assert isinstance(dataset[0], HlabelClsDataEntity)
96+
97+
def test_get_item_from_bbox_dataset(
98+
self,
99+
mocker,
100+
fxt_mock_det_dm_subset,
101+
fxt_mock_hlabelinfo,
102+
) -> None:
103+
mocker.patch.object(HLabelInfo, "from_dm_label_groups", return_value=fxt_mock_hlabelinfo)
104+
dataset = OTXHlabelClsDataset(
105+
dm_subset=fxt_mock_det_dm_subset,
106+
transforms=[lambda x: x],
107+
mem_cache_img_max_size=None,
108+
max_refetch=3,
109+
)
110+
assert isinstance(dataset[0], HlabelClsDataEntity)

0 commit comments

Comments
 (0)