88from abc import abstractmethod
99from collections .abc import Iterable
1010from contextlib import contextmanager
11- from typing import TYPE_CHECKING , Callable , Generic , Iterator , List , Union
11+ from typing import TYPE_CHECKING , Any , Callable , Generic , Iterator , List , Union
1212
1313import cv2
1414import numpy as np
@@ -92,6 +92,7 @@ def __init__(
9292 self .image_color_channel = image_color_channel
9393 self .stack_images = stack_images
9494 self .to_tv_image = to_tv_image
95+
9596 if self .dm_subset .categories ():
9697 self .label_info = LabelInfo .from_dm_label_groups (self .dm_subset .categories ()[AnnotationType .label ])
9798 else :
@@ -141,11 +142,31 @@ def __getitem__(self, index: int) -> T_OTXDataEntity:
141142 msg = f"Reach the maximum refetch number ({ self .max_refetch } )"
142143 raise RuntimeError (msg )
143144
144- def _get_img_data_and_shape (self , img : Image ) -> tuple [np .ndarray , tuple [int , int ]]:
145+ def _get_img_data_and_shape (
146+ self ,
147+ img : Image ,
148+ roi : dict [str , Any ] | None = None ,
149+ ) -> tuple [np .ndarray , tuple [int , int ], dict [str , Any ] | None ]:
150+ """Get image data and shape.
151+
152+ This method is used to get image data and shape from Datumaro image object.
153+ If ROI is provided, the image data is extracted from the ROI.
154+
155+ Args:
156+ img (Image): Image object from Datumaro.
157+ roi (dict[str, Any] | None, Optional): Region of interest.
158+ Represented by dict with coordinates and some meta information.
159+
160+ Returns:
161+ The image data, shape, and ROI meta information
162+ """
145163 key = img .path if isinstance (img , ImageFromFile ) else id (img )
164+ roi_meta = None
146165
147- if (img_data := self .mem_cache_handler .get (key = key )[0 ]) is not None :
148- return img_data , img_data .shape [:2 ]
166+ # check if the image is already in the cache
167+ img_data , roi_meta = self .mem_cache_handler .get (key = key )
168+ if img_data is not None :
169+ return img_data , img_data .shape [:2 ], roi_meta
149170
150171 with image_decode_context ():
151172 img_data = (
@@ -158,11 +179,28 @@ def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, in
158179 msg = "Cannot get image data"
159180 raise RuntimeError (msg )
160181
161- img_data = self ._cache_img (key = key , img_data = img_data .astype (np .uint8 ))
182+ if roi :
183+ # extract ROI from image
184+ shape = roi ["shape" ]
185+ h , w = img_data .shape [:2 ]
186+ x1 , y1 , x2 , y2 = (
187+ int (np .clip (np .trunc (shape ["x1" ] * w ), 0 , w )),
188+ int (np .clip (np .trunc (shape ["y1" ] * h ), 0 , h )),
189+ int (np .clip (np .ceil (shape ["x2" ] * w ), 0 , w )),
190+ int (np .clip (np .ceil (shape ["y2" ] * h ), 0 , h )),
191+ )
192+ if (x2 - x1 ) * (y2 - y1 ) <= 0 :
193+ msg = f"ROI has zero or negative area. ROI coordinates: { x1 } , { y1 } , { x2 } , { y2 } "
194+ raise ValueError (msg )
195+
196+ img_data = img_data [y1 :y2 , x1 :x2 ]
197+ roi_meta = {"x1" : x1 , "y1" : y1 , "x2" : x2 , "y2" : y2 , "orig_image_shape" : (h , w )}
198+
199+ img_data = self ._cache_img (key = key , img_data = img_data .astype (np .uint8 ), meta = roi_meta )
162200
163- return img_data , img_data .shape [:2 ]
201+ return img_data , img_data .shape [:2 ], roi_meta
164202
165- def _cache_img (self , key : str | int , img_data : np .ndarray ) -> np .ndarray :
203+ def _cache_img (self , key : str | int , img_data : np .ndarray , meta : dict [ str , Any ] | None = None ) -> np .ndarray :
166204 """Cache an image after resizing.
167205
168206 If there is available space in the memory pool, the input image is cached.
@@ -182,14 +220,14 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
182220 return img_data
183221
184222 if self .mem_cache_img_max_size is None :
185- self .mem_cache_handler .put (key = key , data = img_data , meta = None )
223+ self .mem_cache_handler .put (key = key , data = img_data , meta = meta )
186224 return img_data
187225
188226 height , width = img_data .shape [:2 ]
189227 max_height , max_width = self .mem_cache_img_max_size
190228
191229 if height <= max_height and width <= max_width :
192- self .mem_cache_handler .put (key = key , data = img_data , meta = None )
230+ self .mem_cache_handler .put (key = key , data = img_data , meta = meta )
193231 return img_data
194232
195233 # Preserve the image size ratio and fit to max_height or max_width
@@ -206,7 +244,7 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
206244 self .mem_cache_handler .put (
207245 key = key ,
208246 data = resized_img ,
209- meta = None ,
247+ meta = meta ,
210248 )
211249 return resized_img
212250
0 commit comments