44#
55
66import copy
7- import tempfile
87import uuid
98from itertools import product
109from multiprocessing import Pool
10+ from random import sample
1111from time import time
1212from typing import Callable , Dict , List , Tuple , Union
1313
@@ -61,21 +61,24 @@ class Tile:
6161 only works when `test_mode=False`, i.e., we never filter images
6262 during tests. Defaults to True.
6363 nproc (int, optional): Processes used for processing masks. Default: 4.
64+ sampling_ratio (float): Ratio for sampling entire tile dataset. Default: 1.0.(No sample)
65+ include_full_img (bool): Whether to include full-size image for inference or training. Default: False.
6466 """
6567
6668 def __init__ (
6769 self ,
6870 dataset ,
6971 pipeline ,
70- tmp_dir : tempfile .TemporaryDirectory ,
7172 tile_size : int = 400 ,
7273 overlap : float = 0.2 ,
7374 min_area_ratio : float = 0.9 ,
7475 iou_threshold : float = 0.45 ,
7576 max_per_img : int = 1500 ,
76- max_annotation : int = 5000 ,
77+ max_annotation : int = 2000 ,
7778 filter_empty_gt : bool = True ,
7879 nproc : int = 2 ,
80+ sampling_ratio : float = 1.0 ,
81+ include_full_img : bool = False ,
7982 ):
8083 self .min_area_ratio = min_area_ratio
8184 self .filter_empty_gt = filter_empty_gt
@@ -88,7 +91,6 @@ def __init__(
8891 self .num_images = len (dataset )
8992 self .num_classes = len (dataset .CLASSES )
9093 self .CLASSES = dataset .CLASSES # pylint: disable=invalid-name
91- self .tmp_folder = tmp_dir .name
9294 self .nproc = nproc
9395 self .img2fp32 = False
9496 for p in pipeline :
@@ -97,15 +99,21 @@ def __init__(
9799 break
98100
99101 self .dataset = dataset
100- self .tiles , self .cached_results = self .gen_tile_ann ()
102+ self .tiles_all , self .cached_results = self .gen_tile_ann (include_full_img )
103+ self .sample_num = max (int (len (self .tiles_all ) * sampling_ratio ), 1 )
104+ if sampling_ratio < 1.0 :
105+ self .tiles = sample (self .tiles_all , self .sample_num )
106+ else :
107+ self .tiles = self .tiles_all
101108
102109 @timeit
103- def gen_tile_ann (self ) -> Tuple [List [Dict ], List [Dict ]]:
110+ def gen_tile_ann (self , include_full_img ) -> Tuple [List [Dict ], List [Dict ]]:
104111 """Generate tile annotations and cache the original image-level annotations.
105112
106113 Returns:
107114 tiles: a list of tile annotations with some other useful information for data pipeline.
108115 cache_result: a list of original image-level annotations.
116+ include_full_img: whether to include full-size image for inference or training.
109117 """
110118 tiles = []
111119 cache_result = []
@@ -114,7 +122,8 @@ def gen_tile_ann(self) -> Tuple[List[Dict], List[Dict]]:
114122
115123 pbar = tqdm (total = len (self .dataset ) * 2 , desc = "Generating tile annotations..." )
116124 for idx , result in enumerate (cache_result ):
117- tiles .append (self .gen_single_img (result , dataset_idx = idx ))
125+ if include_full_img :
126+ tiles .append (self .gen_single_img (result , dataset_idx = idx ))
118127 pbar .update (1 )
119128
120129 for idx , result in enumerate (cache_result ):
@@ -165,19 +174,19 @@ def gen_tiles_single_img(self, result: Dict, dataset_idx: int) -> List[Dict]:
165174 height , width = img_shape [:2 ]
166175 _tile = self .prepare_result (result )
167176
168- num_patches_h = int (( height - self .tile_size ) / self .stride ) + 1
169- num_patches_w = int (( width - self .tile_size ) / self .stride ) + 1
177+ num_patches_h = ( height + self .stride - 1 ) // self .stride
178+ num_patches_w = ( width + self .stride - 1 ) // self .stride
170179 for (_ , _ ), (loc_i , loc_j ) in zip (
171180 product (range (num_patches_h ), range (num_patches_w )),
172181 product (
173- range (0 , height - self . tile_size + 1 , self .stride ),
174- range (0 , width - self . tile_size + 1 , self .stride ),
182+ range (0 , height , self .stride ),
183+ range (0 , width , self .stride ),
175184 ),
176185 ):
177186 x_1 = loc_j
178- x_2 = loc_j + self .tile_size
187+ x_2 = min ( loc_j + self .tile_size , width )
179188 y_1 = loc_i
180- y_2 = loc_i + self .tile_size
189+ y_2 = min ( loc_i + self .tile_size , height )
181190 tile = copy .deepcopy (_tile )
182191 tile ["original_shape_" ] = img_shape
183192 tile ["ori_shape" ] = (y_2 - y_1 , x_2 - x_1 , 3 )
@@ -191,6 +200,9 @@ def gen_tiles_single_img(self, result: Dict, dataset_idx: int) -> List[Dict]:
191200 if self .filter_empty_gt and len (tile ["gt_labels" ]) == 0 :
192201 continue
193202 tile_list .append (tile )
203+ if dataset_idx == 0 :
204+ print (f"image: { height } x{ width } ~ tile_size: { self .tile_size } " )
205+ print (f"{ num_patches_h } x{ num_patches_w } tiles -> { len (tile_list )} tiles after filtering" )
194206 return tile_list
195207
196208 def prepare_result (self , result : Dict ) -> Dict :
@@ -233,12 +245,11 @@ def tile_ann_assignment(
233245 gt_labels (np.ndarray): the original image-level labels
234246 """
235247 x_1 , y_1 = tile_box [0 ][:2 ]
236- overlap_ratio = self .tile_boxes_overlap (tile_box , gt_bboxes )
237- match_idx = np .where ((overlap_ratio [0 ] >= self .min_area_ratio ))[0 ]
248+ matched_indices = self .tile_boxes_overlap (tile_box , gt_bboxes )
238249
239- if len (match_idx ):
240- tile_lables = gt_labels [match_idx ][:]
241- tile_bboxes = gt_bboxes [match_idx ][:]
250+ if len (matched_indices ):
251+ tile_lables = gt_labels [matched_indices ][:]
252+ tile_bboxes = gt_bboxes [matched_indices ][:]
242253 tile_bboxes [:, 0 ] -= x_1
243254 tile_bboxes [:, 1 ] -= y_1
244255 tile_bboxes [:, 2 ] -= x_1
@@ -249,7 +260,7 @@ def tile_ann_assignment(
249260 tile_bboxes [:, 3 ] = np .minimum (self .tile_size , tile_bboxes [:, 3 ])
250261 tile_result ["gt_bboxes" ] = tile_bboxes
251262 tile_result ["gt_labels" ] = tile_lables
252- tile_result ["gt_masks" ] = gt_masks [match_idx ].crop (tile_box [0 ]) if gt_masks is not None else []
263+ tile_result ["gt_masks" ] = gt_masks [matched_indices ].crop (tile_box [0 ]) if gt_masks is not None else []
253264 else :
254265 tile_result .pop ("bbox_fields" )
255266 tile_result .pop ("mask_fields" )
@@ -270,18 +281,12 @@ def tile_boxes_overlap(self, tile_box: np.ndarray, boxes: np.ndarray) -> np.ndar
270281 boxes (np.ndarray): boxes in shape (N, 4).
271282
272283 Returns:
273- np.ndarray: overlapping ratio over boxes
284+ np.ndarray: matched indices.
274285 """
275- box_area = (boxes [:, 2 ] - boxes [:, 0 ]) * (boxes [:, 3 ] - boxes [:, 1 ])
276-
277- width_height = np .minimum (tile_box [:, None , 2 :], boxes [:, 2 :]) - np .maximum (tile_box [:, None , :2 ], boxes [:, :2 ])
278-
279- width_height = width_height .clip (min = 0 ) # [N,M,2]
280- inter = width_height .prod (2 )
281-
282- # handle empty boxes
283- tile_box_ratio = np .where (inter > 0 , inter / box_area , np .zeros (1 , dtype = inter .dtype ))
284- return tile_box_ratio
286+ x1 , y1 , x2 , y2 = tile_box [0 ]
287+ match_indices = (boxes [:, 0 ] > x1 ) & (boxes [:, 1 ] > y1 ) & (boxes [:, 2 ] < x2 ) & (boxes [:, 3 ] < y2 )
288+ match_indices = np .argwhere (match_indices == 1 ).flatten ()
289+ return match_indices
285290
286291 def multiclass_nms (
287292 self , boxes : np .ndarray , scores : np .ndarray , idxs : np .ndarray , iou_threshold : float , max_num : int
@@ -431,7 +436,7 @@ def merge(self, results: List[List]) -> Union[List[Tuple[np.ndarray, list]], Lis
431436
432437 merged_bbox_results : List [np .ndarray ] = [np .empty ((0 , 5 ), dtype = dtype ) for _ in range (self .num_images )]
433438 merged_mask_results : List [List ] = [[] for _ in range (self .num_images )]
434- merged_label_results : List [Union [List , np .ndarray ]] = [[] for _ in range (self .num_images )]
439+ merged_label_results : List [Union [List , np .ndarray ]] = [np . array ([]) for _ in range (self .num_images )]
435440
436441 for result , tile in zip (results , self .tiles ):
437442 tile_x1 , tile_y1 , _ , _ = tile ["tile_box" ]
@@ -477,3 +482,21 @@ def merge(self, results: List[List]) -> Union[List[Tuple[np.ndarray, list]], Lis
477482 if detection :
478483 return list (merged_bbox_results )
479484 return list (zip (merged_bbox_results , merged_mask_results ))
485+
486+ def get_ann_info (self , idx ):
487+ """Get annotation by index.
488+
489+ Args:
490+ idx (int): Index of data.
491+
492+ Returns:
493+ dict: Annotation info of specified index.
494+ """
495+ ann = {}
496+ if "gt_bboxes" in self .tiles [idx ]:
497+ ann ["bboxes" ] = self .tiles [idx ]["gt_bboxes" ]
498+ if "gt_masks" in self .tiles [idx ]:
499+ ann ["masks" ] = self .tiles [idx ]["gt_masks" ]
500+ if "gt_labels" in self .tiles [idx ]:
501+ ann ["labels" ] = self .tiles [idx ]["gt_labels" ]
502+ return ann
0 commit comments