Skip to content

Commit c7efcbc

Browse files
Support ImageFromBytes (#3948)
* add image_from_bytes
  Signed-off-by: Ashwin Vaidya <[email protected]>
* refactor code
  Signed-off-by: Ashwin Vaidya <[email protected]>
* allow empty anomalous masks
  Signed-off-by: Ashwin Vaidya <[email protected]>

Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 0f87c86 commit c7efcbc

File tree

2 files changed

+128
-42
lines changed

2 files changed

+128
-42
lines changed

src/otx/core/data/dataset/anomaly.py

Lines changed: 122 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,18 @@
55

66
from __future__ import annotations
77

8+
from enum import Enum
89
from pathlib import Path
910
from typing import Callable
1011

12+
import cv2
13+
import numpy as np
1114
import torch
1215
from anomalib.data.utils import masks_to_boxes
1316
from datumaro import Dataset as DmDataset
14-
from datumaro import Image
17+
from datumaro import DatasetItem, Image
18+
from datumaro.components.annotation import AnnotationType, Bbox, Ellipse, Polygon
19+
from datumaro.components.media import ImageFromBytes, ImageFromFile
1520
from torchvision import io
1621
from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Mask
1722

@@ -31,6 +36,13 @@
3136
from otx.core.types.task import OTXTaskType
3237

3338

39+
class AnomalyLabel(Enum):
    """Binary anomaly classes, each carrying its tensor encoding.

    Every member's ``value`` is a zero-dimensional ``torch`` tensor:
    ``0.0`` for ``NORMAL`` and ``1.0`` for ``ANOMALOUS``.
    """

    # Sample without any defect.
    NORMAL = torch.tensor(0.0)
    # Sample that contains a defect.
    ANOMALOUS = torch.tensor(1.0)
44+
45+
3446
class AnomalyDataset(OTXDataset):
3547
"""OTXDataset class for anomaly classification task."""
3648

@@ -58,6 +70,7 @@ def __init__(
5870
to_tv_image,
5971
)
6072
self.label_info = AnomalyLabelInfo()
73+
self._label_mapping = self._map_id_to_label()
6174

6275
def _get_item_impl(
6376
self,
@@ -67,12 +80,9 @@ def _get_item_impl(
6780
img = datumaro_item.media_as(Image)
6881
# returns image in RGB format if self.image_color_channel is RGB
6982
img_data, img_shape = self._get_img_data_and_shape(img)
70-
# Note: This assumes that the dataset is in MVTec format.
71-
# We can't use datumaro label id as it returns some number like 3 for good from which it is hard to infer
72-
# whether the image is Anomalous or Normal. Because it leads to other questions like what do numbers 0,1,2 mean?
73-
label: torch.LongTensor = (
74-
torch.tensor(0.0, dtype=torch.long) if "good" in datumaro_item.id else torch.tensor(1.0, dtype=torch.long)
75-
)
83+
84+
label = self._get_label(datumaro_item)
85+
7686
item: AnomalyClassificationDataItem | AnomalySegmentationDataItem | AnomalyDetectionDataItem
7787
if self.task_type == OTXTaskType.ANOMALY_CLASSIFICATION:
7888
item = AnomalyClassificationDataItem(
@@ -88,15 +98,6 @@ def _get_item_impl(
8898
elif self.task_type == OTXTaskType.ANOMALY_SEGMENTATION:
8999
# Note: this part of code is brittle. Ideally Datumaro should return masks
90100
# Another major problem with this is that it assumes that the dataset passed is in MVTec format
91-
mask_file_path = (
92-
Path("/".join(datumaro_item.media.path.split("/")[:-3]))
93-
/ "ground_truth"
94-
/ f"{('/'.join(datumaro_item.media.path.split('/')[-2:])).replace('.png','_mask.png')}"
95-
)
96-
mask = torch.zeros(1, img_shape[0], img_shape[1], dtype=torch.uint8)
97-
if mask_file_path.exists():
98-
# read and convert to binary mask
99-
mask = (io.read_image(str(mask_file_path), mode=io.ImageReadMode.GRAY) / 255).to(torch.uint8)
100101
item = AnomalySegmentationDataItem(
101102
image=img_data,
102103
img_info=ImageInfo(
@@ -106,20 +107,9 @@ def _get_item_impl(
106107
image_color_channel=self.image_color_channel,
107108
),
108109
label=label,
109-
mask=Mask(mask),
110+
mask=Mask(self._get_mask(datumaro_item, label, img_shape)),
110111
)
111112
elif self.task_type == OTXTaskType.ANOMALY_DETECTION:
112-
# Note: this part of code is brittle. Ideally Datumaro should return masks
113-
mask_file_path = (
114-
Path("/".join(datumaro_item.media.path.split("/")[:-3]))
115-
/ "ground_truth"
116-
/ f"{('/'.join(datumaro_item.media.path.split('/')[-2:])).replace('.png','_mask.png')}"
117-
)
118-
mask = torch.zeros(1, img_shape[0], img_shape[1], dtype=torch.uint8)
119-
if mask_file_path.exists():
120-
# read and convert to binary mask
121-
mask = (io.read_image(str(mask_file_path), mode=io.ImageReadMode.GRAY) / 255).to(torch.uint8)
122-
boxes, _ = masks_to_boxes(mask)
123113
item = AnomalyDetectionDataItem(
124114
image=img_data,
125115
img_info=ImageInfo(
@@ -129,9 +119,9 @@ def _get_item_impl(
129119
image_color_channel=self.image_color_channel,
130120
),
131121
label=label,
132-
boxes=BoundingBoxes(boxes[0], format=BoundingBoxFormat.XYXY, canvas_size=img_shape),
122+
boxes=self._get_boxes(datumaro_item, label, img_shape),
133123
# mask is used for pixel-level metric computation. We can't assume that this will always be available
134-
mask=Mask(mask),
124+
mask=Mask(self._get_mask(datumaro_item, label, img_shape)),
135125
)
136126
else:
137127
msg = f"Task {self.task_type} is not supported yet."
@@ -142,6 +132,108 @@ def _get_item_impl(
142132
# "AnomalyClassificationDataItem | AnomalySegmentationDataBatch | AnomalyDetectionDataBatch")
143133
return self._apply_transforms(item) # type: ignore[return-value]
144134

135+
def _get_mask(self, datumaro_item: DatasetItem, label: torch.Tensor, img_shape: tuple[int, int]) -> torch.Tensor:
    """Build the ground-truth anomaly mask for ``datumaro_item``.

    Args:
        datumaro_item: Item whose media type decides where the mask comes from.
        label: Label tensor of the item; it is compared against
            ``AnomalyLabel.ANOMALOUS.value``.
        img_shape: Shape of the image the mask has to match.

    Returns:
        A ``uint8`` tensor of shape ``(1, *img_shape)``. It is all zeros for
        non-anomalous items, for unsupported media types and when no usable
        annotation exists.

    For ``ImageFromFile`` media the mask is read from disk by
    ``_mask_image_from_file``. For ``ImageFromBytes`` media the mask is
    rasterised from the ``Polygon``/``Ellipse`` annotations; bounding boxes are
    only used as a fallback when the item carries no such shape.
    """
    empty_mask = torch.zeros(1, *img_shape).to(torch.uint8)
    if not bool(label == AnomalyLabel.ANOMALOUS.value):
        return empty_mask

    media = datumaro_item.media
    if isinstance(media, ImageFromFile):
        return self._mask_image_from_file(datumaro_item, img_shape)
    if not isinstance(media, ImageFromBytes):
        # Unknown media type: there is no mask source, so return an empty mask
        # rather than failing on an unbound local.
        return empty_mask

    annotations = datumaro_item.annotations
    shapes = [annotation for annotation in annotations if isinstance(annotation, (Ellipse, Polygon))]
    if shapes:
        canvas = np.zeros(img_shape, dtype=np.uint8)
        for shape in shapes:
            contour = np.asarray(shape.as_polygon(), dtype=np.int32).reshape((-1, 1, 2))
            # Draw one contour per call so that overlapping or nested shapes
            # are unioned instead of cutting holes into each other.
            canvas = cv2.drawContours(canvas, [contour], 0, (1, 1, 1), thickness=cv2.FILLED)
        return torch.from_numpy(canvas).to(torch.uint8).unsqueeze(0)

    # No polygon-like annotation is available: approximate the mask with the
    # union of the bounding boxes.
    mask = empty_mask
    for annotation in annotations:
        if isinstance(annotation, Bbox):
            mask = torch.maximum(mask, self._bbox_to_mask(annotation, img_shape))
    return mask
168+
169+
def _get_boxes(self, datumaro_item: DatasetItem, label: torch.Tensor, img_shape: tuple[int, int]) -> BoundingBoxes:
    """Return the ground-truth bounding boxes of ``datumaro_item`` in XYXY format.

    Args:
        datumaro_item: Item whose media type decides where the boxes come from.
        label: Label tensor of the item; it is compared against
            ``AnomalyLabel.ANOMALOUS.value``.
        img_shape: Image shape, passed on as ``canvas_size`` of the boxes.

    Returns:
        ``BoundingBoxes`` in ``BoundingBoxFormat.XYXY``. The result is an empty
        ``(0, 4)`` set for non-anomalous items, for unsupported media types and
        when an ``ImageFromBytes`` item has no ``Bbox`` annotation.

    For ``ImageFromFile`` media the boxes are derived from the mask on disk.
    For ``ImageFromBytes`` media they are taken from the ``Bbox`` annotations.
    """
    no_boxes = BoundingBoxes(torch.empty(0, 4), format=BoundingBoxFormat.XYXY, canvas_size=img_shape)
    if not bool(label == AnomalyLabel.ANOMALOUS.value):
        return no_boxes

    media = datumaro_item.media
    if isinstance(media, ImageFromFile):
        mask = self._mask_image_from_file(datumaro_item, img_shape)
        batch_boxes, _ = masks_to_boxes(mask)
        # The mask has a leading dimension of size one, so only the first
        # entry of the result is relevant.
        return BoundingBoxes(batch_boxes[0], format=BoundingBoxFormat.XYXY, canvas_size=img_shape)

    if isinstance(media, ImageFromBytes):
        # ``Bbox.get_bbox()`` yields (x, y, w, h); ``Bbox.points`` holds the
        # (x1, y1, x2, y2) corners, which is what the XYXY format expects.
        corners = [
            list(annotation.points) for annotation in datumaro_item.annotations if isinstance(annotation, Bbox)
        ]
        if corners:
            return BoundingBoxes(corners, format=BoundingBoxFormat.XYXY, canvas_size=img_shape)
    return no_boxes
188+
189+
def _bbox_to_mask(self, bbox: Bbox, img_shape: tuple[int, int]) -> torch.Tensor:
190+
mask = torch.zeros(1, *img_shape).to(torch.uint8)
191+
x1, y1, x2, y2 = bbox.get_bbox()
192+
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
193+
mask[:, y1:y2, x1:x2] = 1
194+
return mask
195+
196+
def _get_label(self, datumaro_item: DatasetItem) -> torch.Tensor:
    """Return the label tensor of ``datumaro_item``.

    Args:
        datumaro_item: Item whose media type decides how the label is derived.

    Returns:
        The ``value`` of the matching ``AnomalyLabel`` member.

    Raises:
        NotImplementedError: If the media is neither ``ImageFromFile`` nor
            ``ImageFromBytes``.
    """
    media = datumaro_item.media
    label: AnomalyLabel
    if isinstance(media, ImageFromFile):
        # Note: This assumes that the dataset is in MVTec format.
        # The datumaro label id is not used here because it is an arbitrary number (e.g. 3 for "good") from which
        # it is hard to infer whether the image is Anomalous or Normal.
        label = AnomalyLabel.NORMAL if "good" in datumaro_item.id else AnomalyLabel.ANOMALOUS
    elif isinstance(media, ImageFromBytes):
        # The label id of the first annotation is looked up in the mapping stored in ``self._label_mapping``.
        label = self._label_mapping[datumaro_item.annotations[0].label]
    else:
        msg = f"Media type {type(datumaro_item.media)} is not supported."
        raise NotImplementedError(msg)
    return label.value
210+
211+
def _map_id_to_label(self) -> dict[int, AnomalyLabel]:
    """Map every label id of the subset to an ``AnomalyLabel`` member.

    A label category is treated as ``NORMAL`` when any of its attributes
    contains the substring ``"normal"`` (case-insensitive); every other
    category is treated as ``ANOMALOUS``.

    Returns:
        Dictionary keyed by the first element of ``categories.find(name)``
        for each label category, with the matching ``AnomalyLabel`` as value.
    """
    categories = self.dm_subset.categories()[AnnotationType.label]
    id_label_mapping: dict[int, AnomalyLabel] = {}
    for label_item in categories.items:
        is_normal = any("normal" in attribute.lower() for attribute in label_item.attributes)
        label_id = categories.find(label_item.name)[0]
        id_label_mapping[label_id] = AnomalyLabel.NORMAL if is_normal else AnomalyLabel.ANOMALOUS
    return id_label_mapping
222+
223+
def _mask_image_from_file(self, datumaro_item: DatasetItem, img_shape: tuple[int, int]) -> torch.Tensor:
224+
"""Assumes MVTec format and returns mask from disk."""
225+
mask_file_path = (
226+
Path("/".join(datumaro_item.media.path.split("/")[:-3]))
227+
/ "ground_truth"
228+
/ f"{('/'.join(datumaro_item.media.path.split('/')[-2:])).replace('.png','_mask.png')}"
229+
)
230+
if mask_file_path.exists():
231+
return (io.read_image(str(mask_file_path), mode=io.ImageReadMode.GRAY) / 255).to(torch.uint8)
232+
233+
# Note: This is a workaround to handle the case where mask is not available otherwise the tests fail.
234+
# This is problematic because it assigns empty masks to an Anomalous image.
235+
return torch.zeros(1, *img_shape).to(torch.uint8)
236+
145237
@property
146238
def collate_fn(self) -> Callable:
147239
"""Collection function to collect SegDataEntity into SegBatchDataEntity in data loader."""

src/otx/core/model/anomaly.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -180,18 +180,12 @@ def _customize_inputs(
180180
inputs: AnomalyModelInputs,
181181
) -> dict[str, Any]:
182182
"""Customize inputs for the model."""
183-
return_dict = {}
184-
if isinstance(inputs, AnomalyClassificationDataBatch):
185-
return_dict = {"image": inputs.images, "label": torch.vstack(inputs.labels).squeeze()}
186-
if isinstance(inputs, AnomalySegmentationDataBatch):
187-
return_dict = {"image": inputs.images, "label": torch.vstack(inputs.labels).squeeze(), "mask": inputs.masks}
188-
if isinstance(inputs, AnomalyDetectionDataBatch):
189-
return_dict = {
190-
"image": inputs.images,
191-
"label": torch.vstack(inputs.labels).squeeze(),
192-
"mask": inputs.masks,
193-
"boxes": inputs.boxes,
194-
}
183+
return_dict = {"image": inputs.images, "label": torch.vstack(inputs.labels).squeeze()}
184+
if isinstance(inputs, AnomalySegmentationDataBatch) and inputs.masks is not None:
185+
return_dict["mask"] = inputs.masks
186+
if isinstance(inputs, AnomalyDetectionDataBatch) and inputs.masks is not None and inputs.boxes is not None:
187+
return_dict["mask"] = inputs.masks
188+
return_dict["boxes"] = inputs.boxes
195189

196190
if return_dict["label"].size() == torch.Size([]): # when last batch size is 1
197191
return_dict["label"] = return_dict["label"].unsqueeze(0)

0 commit comments

Comments
 (0)