55
66from __future__ import annotations
77
8+ from enum import Enum
89from pathlib import Path
910from typing import Callable
1011
12+ import cv2
13+ import numpy as np
1114import torch
1215from anomalib .data .utils import masks_to_boxes
1316from datumaro import Dataset as DmDataset
14- from datumaro import Image
17+ from datumaro import DatasetItem , Image
18+ from datumaro .components .annotation import AnnotationType , Bbox , Ellipse , Polygon
19+ from datumaro .components .media import ImageFromBytes , ImageFromFile
1520from torchvision import io
1621from torchvision .tv_tensors import BoundingBoxes , BoundingBoxFormat , Mask
1722
3136from otx .core .types .task import OTXTaskType
3237
3338
39+ class AnomalyLabel (Enum ):
40+ """Anomaly label to tensor mapping."""
41+
42+ NORMAL = torch .tensor (0.0 )
43+ ANOMALOUS = torch .tensor (1.0 )
44+
45+
3446class AnomalyDataset (OTXDataset ):
3547 """OTXDataset class for anomaly classification task."""
3648
@@ -58,6 +70,7 @@ def __init__(
5870 to_tv_image ,
5971 )
6072 self .label_info = AnomalyLabelInfo ()
73+ self ._label_mapping = self ._map_id_to_label ()
6174
6275 def _get_item_impl (
6376 self ,
@@ -67,12 +80,9 @@ def _get_item_impl(
6780 img = datumaro_item .media_as (Image )
6881 # returns image in RGB format if self.image_color_channel is RGB
6982 img_data , img_shape = self ._get_img_data_and_shape (img )
70- # Note: This assumes that the dataset is in MVTec format.
71- # We can't use datumaro label id as it returns some number like 3 for good from which it is hard to infer
72- # whether the image is Anomalous or Normal. Because it leads to other questions like what do numbers 0,1,2 mean?
73- label : torch .LongTensor = (
74- torch .tensor (0.0 , dtype = torch .long ) if "good" in datumaro_item .id else torch .tensor (1.0 , dtype = torch .long )
75- )
83+
84+ label = self ._get_label (datumaro_item )
85+
7686 item : AnomalyClassificationDataItem | AnomalySegmentationDataItem | AnomalyDetectionDataItem
7787 if self .task_type == OTXTaskType .ANOMALY_CLASSIFICATION :
7888 item = AnomalyClassificationDataItem (
@@ -88,15 +98,6 @@ def _get_item_impl(
8898 elif self .task_type == OTXTaskType .ANOMALY_SEGMENTATION :
8999 # Note: this part of code is brittle. Ideally Datumaro should return masks
90100 # Another major problem with this is that it assumes that the dataset passed is in MVTec format
91- mask_file_path = (
92- Path ("/" .join (datumaro_item .media .path .split ("/" )[:- 3 ]))
93- / "ground_truth"
94- / f"{ ('/' .join (datumaro_item .media .path .split ('/' )[- 2 :])).replace ('.png' ,'_mask.png' )} "
95- )
96- mask = torch .zeros (1 , img_shape [0 ], img_shape [1 ], dtype = torch .uint8 )
97- if mask_file_path .exists ():
98- # read and convert to binary mask
99- mask = (io .read_image (str (mask_file_path ), mode = io .ImageReadMode .GRAY ) / 255 ).to (torch .uint8 )
100101 item = AnomalySegmentationDataItem (
101102 image = img_data ,
102103 img_info = ImageInfo (
@@ -106,20 +107,9 @@ def _get_item_impl(
106107 image_color_channel = self .image_color_channel ,
107108 ),
108109 label = label ,
109- mask = Mask (mask ),
110+ mask = Mask (self . _get_mask ( datumaro_item , label , img_shape ) ),
110111 )
111112 elif self .task_type == OTXTaskType .ANOMALY_DETECTION :
112- # Note: this part of code is brittle. Ideally Datumaro should return masks
113- mask_file_path = (
114- Path ("/" .join (datumaro_item .media .path .split ("/" )[:- 3 ]))
115- / "ground_truth"
116- / f"{ ('/' .join (datumaro_item .media .path .split ('/' )[- 2 :])).replace ('.png' ,'_mask.png' )} "
117- )
118- mask = torch .zeros (1 , img_shape [0 ], img_shape [1 ], dtype = torch .uint8 )
119- if mask_file_path .exists ():
120- # read and convert to binary mask
121- mask = (io .read_image (str (mask_file_path ), mode = io .ImageReadMode .GRAY ) / 255 ).to (torch .uint8 )
122- boxes , _ = masks_to_boxes (mask )
123113 item = AnomalyDetectionDataItem (
124114 image = img_data ,
125115 img_info = ImageInfo (
@@ -129,9 +119,9 @@ def _get_item_impl(
129119 image_color_channel = self .image_color_channel ,
130120 ),
131121 label = label ,
132- boxes = BoundingBoxes ( boxes [ 0 ], format = BoundingBoxFormat . XYXY , canvas_size = img_shape ),
122+ boxes = self . _get_boxes ( datumaro_item , label , img_shape ),
133123 # mask is used for pixel-level metric computation. We can't assume that this will always be available
134- mask = Mask (mask ),
124+ mask = Mask (self . _get_mask ( datumaro_item , label , img_shape ) ),
135125 )
136126 else :
137127 msg = f"Task { self .task_type } is not supported yet."
@@ -142,6 +132,108 @@ def _get_item_impl(
142132 # "AnomalyClassificationDataItem | AnomalySegmentationDataBatch | AnomalyDetectionDataBatch")
143133 return self ._apply_transforms (item ) # type: ignore[return-value]
144134
135+ def _get_mask (self , datumaro_item : DatasetItem , label : torch .Tensor , img_shape : tuple [int , int ]) -> torch .Tensor :
136+ """Get mask from datumaro_item.
137+
138+ Converts bounding boxes to mask if mask is not available.
139+ """
140+ if isinstance (datumaro_item .media , ImageFromFile ):
141+ if label == AnomalyLabel .ANOMALOUS .value :
142+ mask = self ._mask_image_from_file (datumaro_item , img_shape )
143+ else :
144+ mask = torch .zeros (1 , * img_shape ).to (torch .uint8 )
145+ elif isinstance (datumaro_item .media , ImageFromBytes ):
146+ mask = torch .zeros (1 , * img_shape ).to (torch .uint8 )
147+ if label == AnomalyLabel .ANOMALOUS .value :
148+ for annotation in datumaro_item .annotations :
149+ # There is only one mask
150+ if isinstance (annotation , (Ellipse , Polygon )):
151+ polygons = np .asarray (annotation .as_polygon (), dtype = np .int32 ).reshape ((- 1 , 1 , 2 ))
152+ mask = np .zeros (img_shape , dtype = np .uint8 )
153+ mask = cv2 .drawContours (
154+ mask ,
155+ [polygons ],
156+ 0 ,
157+ (1 , 1 , 1 ),
158+ thickness = cv2 .FILLED ,
159+ )
160+ mask = torch .from_numpy (mask ).to (torch .uint8 ).unsqueeze (0 )
161+ break
162+ # If there is no mask, create a mask from bbox
163+ if isinstance (annotation , Bbox ):
164+ bbox = annotation
165+ mask = self ._bbox_to_mask (bbox , img_shape )
166+ break
167+ return mask
168+
169+ def _get_boxes (self , datumaro_item : DatasetItem , label : torch .Tensor , img_shape : tuple [int , int ]) -> BoundingBoxes :
170+ """Get bounding boxes from datumaro item.
171+
172+ Uses masks if available to get bounding boxes.
173+ """
174+ boxes = BoundingBoxes (torch .empty (0 , 4 ), format = BoundingBoxFormat .XYXY , canvas_size = img_shape )
175+ if isinstance (datumaro_item .media , ImageFromFile ):
176+ if label == AnomalyLabel .ANOMALOUS .value :
177+ mask = self ._mask_image_from_file (datumaro_item , img_shape )
178+ boxes , _ = masks_to_boxes (mask )
179+ # Assumes only one bounding box is present
180+ boxes = BoundingBoxes (boxes [0 ], format = BoundingBoxFormat .XYXY , canvas_size = img_shape )
181+ elif isinstance (datumaro_item .media , ImageFromBytes ) and label == AnomalyLabel .ANOMALOUS .value :
182+ for annotation in datumaro_item .annotations :
183+ if isinstance (annotation , Bbox ):
184+ bbox = annotation
185+ boxes = BoundingBoxes (bbox .get_bbox (), format = BoundingBoxFormat .XYXY , canvas_size = img_shape )
186+ break
187+ return boxes
188+
189+ def _bbox_to_mask (self , bbox : Bbox , img_shape : tuple [int , int ]) -> torch .Tensor :
190+ mask = torch .zeros (1 , * img_shape ).to (torch .uint8 )
191+ x1 , y1 , x2 , y2 = bbox .get_bbox ()
192+ x1 , y1 , x2 , y2 = int (x1 ), int (y1 ), int (x2 ), int (y2 )
193+ mask [:, y1 :y2 , x1 :x2 ] = 1
194+ return mask
195+
196+ def _get_label (self , datumaro_item : DatasetItem ) -> torch .LongTensor :
197+ """Get label from datumaro item."""
198+ if isinstance (datumaro_item .media , ImageFromFile ):
199+ # Note: This assumes that the dataset is in MVTec format.
200+ # We can't use datumaro label id as it returns some number like 3 for good from which it is hard to infer
201+ # whether the image is Anomalous or Normal. Because it leads to other questions like what do numbers 0,1,2
202+ # mean?
203+ label : torch .LongTensor = AnomalyLabel .NORMAL if "good" in datumaro_item .id else AnomalyLabel .ANOMALOUS
204+ elif isinstance (datumaro_item .media , ImageFromBytes ):
205+ label = self ._label_mapping [datumaro_item .annotations [0 ].label ]
206+ else :
207+ msg = f"Media type { type (datumaro_item .media )} is not supported."
208+ raise NotImplementedError (msg )
209+ return label .value
210+
211+ def _map_id_to_label (self ) -> dict [int , torch .Tensor ]:
212+ """Map label id to label tensor."""
213+ id_label_mapping = {}
214+ categories = self .dm_subset .categories ()[AnnotationType .label ]
215+ for label_item in categories .items :
216+ if any ("normal" in attribute .lower () for attribute in label_item .attributes ):
217+ label = AnomalyLabel .NORMAL
218+ else :
219+ label = AnomalyLabel .ANOMALOUS
220+ id_label_mapping [categories .find (label_item .name )[0 ]] = label
221+ return id_label_mapping
222+
223+ def _mask_image_from_file (self , datumaro_item : DatasetItem , img_shape : tuple [int , int ]) -> torch .Tensor :
224+ """Assumes MVTec format and returns mask from disk."""
225+ mask_file_path = (
226+ Path ("/" .join (datumaro_item .media .path .split ("/" )[:- 3 ]))
227+ / "ground_truth"
228+ / f"{ ('/' .join (datumaro_item .media .path .split ('/' )[- 2 :])).replace ('.png' ,'_mask.png' )} "
229+ )
230+ if mask_file_path .exists ():
231+ return (io .read_image (str (mask_file_path ), mode = io .ImageReadMode .GRAY ) / 255 ).to (torch .uint8 )
232+
233+ # Note: This is a workaround to handle the case where mask is not available otherwise the tests fail.
234+ # This is problematic because it assigns empty masks to an Anomalous image.
235+ return torch .zeros (1 , * img_shape ).to (torch .uint8 )
236+
145237 @property
146238 def collate_fn (self ) -> Callable :
147239 """Collection function to collect SegDataEntity into SegBatchDataEntity in data loader."""
0 commit comments