Skip to content

Commit 156d1e8

Browse files
sungmanc, jaegukhyun, eugene123tw, sungmanc, Songki Choi
authored
Optimize counting train & inference speed and memory consumption (#2172)
* Add sampling tiling dataset method * Add unit test * update * update * update configs * update * Update * Fix configuration for deployment * Add experimental script * Fix smallthing in exp.sh * Change default labels type from list to np.array * Revert resolution, num_workers * Fix typo * Skip mask if its confidence is under threshold * Add prediction with user defined confidence threshold * Update exp.sh * support multi-batch in tile classifier * polygon sampling * exclude full image in training * Add * Refine adaptive tile params - Use size rather than area - 32 pixels as min detectable size - Default object_tile_ratio = 32 / 1024 = 0.03 * Tune params * Fix tile on the edge * Fix IR overlap ratio according to IR scale factor Signed-off-by: Songki Choi <[email protected]> * Set default object_tile_ratio = 0.06 (32/512) * Add tile_deployment * Keep aspect ratio for tiling * Fix tile patching * Fix adaptive tile logic for robustness * Delete useless file * Cleansing * Update exp.sh * Fix precommit * Refine adaptive tile params - Use size rather than area - 32 pixels as min detectable size - Default object_tile_ratio = 32 / 1024 = 0.03 * Tune params * Set default object_tile_ratio = 0.06 (32/512) * Keep aspect ratio for tiling * Fix tile patching * Fix adaptive tile logic for robustness * Remove useless line * Apply comments * Remove some comments * make black happy * Update default value of object tile ratio * Fix polygon append * Fix polygon append * Fix merge error, revert temp tox change * Fix precommit, remove exp.sh * Remove 20 points sampling * Fix mypy issue * Modify way to change subset type of _infer_model dataset & Add confidence threshold filter for openvino eval * Remove breakpoint * Update config param description, expose ellipse option * Fix unit test * Fix deploy patch bug (w,h) -> (h,w) * Fix iseg intg test * Revert tiling-ins-seg intg test * Fix for det tiling intg test --------- Signed-off-by: Songki Choi <[email protected]> Co-authored-by: 
jaegukhyun <[email protected]> Co-authored-by: Eugene Liu <[email protected]> Co-authored-by: sungmanc <[email protected]> Co-authored-by: Songki Choi <[email protected]>
1 parent c63d3ac commit 156d1e8

File tree

37 files changed

+726
-324
lines changed

37 files changed

+726
-324
lines changed

otx/algorithms/common/configs/training_base.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,13 @@ class BasePostprocessing(ParameterGroup):
220220
affects_outcome_of=ModelLifecycle.INFERENCE,
221221
)
222222

223+
use_ellipse_shapes = configurable_boolean(
224+
default_value=False,
225+
header="Use ellipse shapes",
226+
description="Use direct ellipse shape in inference instead of polygon from mask",
227+
affects_outcome_of=ModelLifecycle.INFERENCE,
228+
)
229+
223230
@attrs
224231
class BaseNNCFOptimization(ParameterGroup):
225232
"""BaseNNCFOptimization for OTX Algorithms."""
@@ -350,7 +357,7 @@ class BaseTilingParameters(ParameterGroup):
350357
description="Tile Image Size",
351358
default_value=400,
352359
min_value=100,
353-
max_value=1024,
360+
max_value=4096,
354361
affects_outcome_of=ModelLifecycle.NONE,
355362
)
356363

@@ -368,7 +375,7 @@ class BaseTilingParameters(ParameterGroup):
368375
description="Max object per image",
369376
default_value=1500,
370377
min_value=1,
371-
max_value=10000,
378+
max_value=5000,
372379
affects_outcome_of=ModelLifecycle.NONE,
373380
)
374381

@@ -388,4 +395,26 @@ class BaseTilingParameters(ParameterGroup):
388395
affects_outcome_of=ModelLifecycle.NONE,
389396
)
390397

398+
tile_sampling_ratio = configurable_float(
399+
header="Sampling Ratio for entire tiling",
400+
description="Since tiling train and validation to all tile from large image, "
401+
"usually it takes lots of time than normal training."
402+
"The tile_sampling_ratio is ratio for sampling entire tile dataset."
403+
"Sampling tile dataset would save lots of time for training and validation time."
404+
"Note that sampling will be applied to training and validation dataset, not test dataset.",
405+
default_value=1.0,
406+
min_value=0.000001,
407+
max_value=1.0,
408+
affects_outcome_of=ModelLifecycle.NONE,
409+
)
410+
411+
object_tile_ratio = configurable_float(
412+
header="Object tile ratio",
413+
description="The desired ratio of min object size and tile size.",
414+
default_value=0.03,
415+
min_value=0.00,
416+
max_value=1.00,
417+
affects_outcome_of=ModelLifecycle.NONE,
418+
)
419+
391420
tiling_parameters = add_parameter_group(BaseTilingParameters)

otx/algorithms/detection/adapters/mmdet/datasets/dataset.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# See the License for the specific language governing permissions
1515
# and limitations under the License.
1616

17-
import tempfile
1817
from collections import OrderedDict
1918
from copy import copy
2019
from typing import Any, Dict, List, Sequence, Tuple, Union
@@ -32,6 +31,7 @@
3231
from otx.api.entities.dataset_item import DatasetItemEntity
3332
from otx.api.entities.datasets import DatasetEntity
3433
from otx.api.entities.label import Domain, LabelEntity
34+
from otx.api.entities.subset import Subset
3535
from otx.api.utils.shape_factory import ShapeFactory
3636

3737
from .tiling import Tile
@@ -270,6 +270,7 @@ def evaluate( # pylint: disable=too-many-branches
270270
if metric not in allowed_metrics:
271271
raise KeyError(f"metric {metric} is not supported")
272272
annotations = [self.get_ann_info(i) for i in range(len(self))]
273+
assert len(annotations) == len(results), "annotation length does not match prediction results"
273274
iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
274275
if metric == "mAP":
275276
assert isinstance(iou_thrs, list)
@@ -302,7 +303,7 @@ def evaluate( # pylint: disable=too-many-branches
302303

303304
# pylint: disable=too-many-arguments
304305
@DATASETS.register_module()
305-
class ImageTilingDataset:
306+
class ImageTilingDataset(OTXDetDataset):
306307
"""A wrapper of tiling dataset.
307308
308309
Suitable for training small object dataset. This wrapper composed of `Tile`
@@ -326,6 +327,8 @@ class ImageTilingDataset:
326327
after NMS, only top max_per_img will be kept. Defaults to 200.
327328
max_annotation (int, optional): Limit the number of ground truth by
328329
randomly selecting at most 5000 of them, to avoid running out of RAM. Defaults to 5000.
330+
sampling_ratio (float): Ratio for sampling entire tile dataset.
331+
include_full_img (bool): Whether to include full image in the dataset.
329332
"""
330333

331334
def __init__(
@@ -340,28 +343,29 @@ def __init__(
340343
max_annotation=5000,
341344
filter_empty_gt=True,
342345
test_mode=False,
346+
sampling_ratio=1.0,
347+
include_full_img=False,
343348
):
344349
self.dataset = build_dataset(dataset)
345350
self.CLASSES = self.dataset.CLASSES
346-
self.tmp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
347351

348352
self.tile_dataset = Tile(
349353
self.dataset,
350354
pipeline,
351-
tmp_dir=self.tmp_dir,
352355
tile_size=tile_size,
353356
overlap=overlap_ratio,
354357
min_area_ratio=min_area_ratio,
355358
iou_threshold=iou_threshold,
356359
max_per_img=max_per_img,
357360
max_annotation=max_annotation,
358-
filter_empty_gt=False if test_mode else filter_empty_gt,
361+
filter_empty_gt=filter_empty_gt if self.dataset.otx_dataset[0].subset != Subset.TESTING else False,
362+
sampling_ratio=sampling_ratio if self.dataset.otx_dataset[0].subset != Subset.TESTING else 1.0,
363+
include_full_img=include_full_img if self.dataset.otx_dataset[0].subset != Subset.TESTING else True,
359364
)
360365
self.flag = np.zeros(len(self), dtype=np.uint8)
361366
self.pipeline = Compose(pipeline)
362367
self.test_mode = test_mode
363368
self.num_samples = len(self.dataset) # number of original samples
364-
self.merged_results: Union[List[Tuple[np.ndarray, list]], List[np.ndarray]] = []
365369

366370
def __len__(self) -> int:
367371
"""Get the length of the dataset."""
@@ -379,18 +383,16 @@ def __getitem__(self, idx: int) -> Dict:
379383
"""
380384
return self.pipeline(self.tile_dataset[idx])
381385

382-
def evaluate(self, results, **kwargs) -> Dict[str, float]:
383-
"""Evaluation on Tile dataset.
386+
def get_ann_info(self, idx):
387+
"""Get annotation information of a tile.
384388
385389
Args:
386-
results (list[list | tuple]): Testing results of the dataset.
387-
**kwargs: Addition keyword arguments.
390+
idx (int): Index of data.
388391
389392
Returns:
390-
dict[str, float]: evaluation metric.
393+
dict: Annotation information of a tile.
391394
"""
392-
self.merged_results = self.tile_dataset.merge(results)
393-
return self.dataset.evaluate(self.merged_results, **kwargs)
395+
return self.tile_dataset.get_ann_info(idx)
394396

395397
def merge(self, results) -> Union[List[Tuple[np.ndarray, list]], List[np.ndarray]]:
396398
"""Merge tile-level results to image-level results.
@@ -401,10 +403,4 @@ def merge(self, results) -> Union[List[Tuple[np.ndarray, list]], List[np.ndarray
401403
Returns:
402404
merged_results (list[list | tuple]): Merged results of the dataset.
403405
"""
404-
self.merged_results = self.tile_dataset.merge(results)
405-
return self.merged_results
406-
407-
def __del__(self):
408-
"""Delete the temporary directory when the object is deleted."""
409-
if getattr(self, "tmp_dir", False):
410-
self.tmp_dir.cleanup()
406+
return self.tile_dataset.merge(results)

otx/algorithms/detection/adapters/mmdet/datasets/tiling.py

Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
#
55

66
import copy
7-
import tempfile
87
import uuid
98
from itertools import product
109
from multiprocessing import Pool
10+
from random import sample
1111
from time import time
1212
from typing import Callable, Dict, List, Tuple, Union
1313

@@ -61,21 +61,24 @@ class Tile:
6161
only works when `test_mode=False`, i.e., we never filter images
6262
during tests. Defaults to True.
6363
nproc (int, optional): Processes used for processing masks. Default: 4.
64+
sampling_ratio (float): Ratio for sampling entire tile dataset. Default: 1.0 (no sampling).
65+
include_full_img (bool): Whether to include full-size image for inference or training. Default: False.
6466
"""
6567

6668
def __init__(
6769
self,
6870
dataset,
6971
pipeline,
70-
tmp_dir: tempfile.TemporaryDirectory,
7172
tile_size: int = 400,
7273
overlap: float = 0.2,
7374
min_area_ratio: float = 0.9,
7475
iou_threshold: float = 0.45,
7576
max_per_img: int = 1500,
76-
max_annotation: int = 5000,
77+
max_annotation: int = 2000,
7778
filter_empty_gt: bool = True,
7879
nproc: int = 2,
80+
sampling_ratio: float = 1.0,
81+
include_full_img: bool = False,
7982
):
8083
self.min_area_ratio = min_area_ratio
8184
self.filter_empty_gt = filter_empty_gt
@@ -88,7 +91,6 @@ def __init__(
8891
self.num_images = len(dataset)
8992
self.num_classes = len(dataset.CLASSES)
9093
self.CLASSES = dataset.CLASSES # pylint: disable=invalid-name
91-
self.tmp_folder = tmp_dir.name
9294
self.nproc = nproc
9395
self.img2fp32 = False
9496
for p in pipeline:
@@ -97,15 +99,21 @@ def __init__(
9799
break
98100

99101
self.dataset = dataset
100-
self.tiles, self.cached_results = self.gen_tile_ann()
102+
self.tiles_all, self.cached_results = self.gen_tile_ann(include_full_img)
103+
self.sample_num = max(int(len(self.tiles_all) * sampling_ratio), 1)
104+
if sampling_ratio < 1.0:
105+
self.tiles = sample(self.tiles_all, self.sample_num)
106+
else:
107+
self.tiles = self.tiles_all
101108

102109
@timeit
103-
def gen_tile_ann(self) -> Tuple[List[Dict], List[Dict]]:
110+
def gen_tile_ann(self, include_full_img) -> Tuple[List[Dict], List[Dict]]:
104111
"""Generate tile annotations and cache the original image-level annotations.
105112
106113
Returns:
107114
tiles: a list of tile annotations with some other useful information for data pipeline.
108115
cache_result: a list of original image-level annotations.
116+
include_full_img: whether to include full-size image for inference or training.
109117
"""
110118
tiles = []
111119
cache_result = []
@@ -114,7 +122,8 @@ def gen_tile_ann(self) -> Tuple[List[Dict], List[Dict]]:
114122

115123
pbar = tqdm(total=len(self.dataset) * 2, desc="Generating tile annotations...")
116124
for idx, result in enumerate(cache_result):
117-
tiles.append(self.gen_single_img(result, dataset_idx=idx))
125+
if include_full_img:
126+
tiles.append(self.gen_single_img(result, dataset_idx=idx))
118127
pbar.update(1)
119128

120129
for idx, result in enumerate(cache_result):
@@ -165,19 +174,19 @@ def gen_tiles_single_img(self, result: Dict, dataset_idx: int) -> List[Dict]:
165174
height, width = img_shape[:2]
166175
_tile = self.prepare_result(result)
167176

168-
num_patches_h = int((height - self.tile_size) / self.stride) + 1
169-
num_patches_w = int((width - self.tile_size) / self.stride) + 1
177+
num_patches_h = (height + self.stride - 1) // self.stride
178+
num_patches_w = (width + self.stride - 1) // self.stride
170179
for (_, _), (loc_i, loc_j) in zip(
171180
product(range(num_patches_h), range(num_patches_w)),
172181
product(
173-
range(0, height - self.tile_size + 1, self.stride),
174-
range(0, width - self.tile_size + 1, self.stride),
182+
range(0, height, self.stride),
183+
range(0, width, self.stride),
175184
),
176185
):
177186
x_1 = loc_j
178-
x_2 = loc_j + self.tile_size
187+
x_2 = min(loc_j + self.tile_size, width)
179188
y_1 = loc_i
180-
y_2 = loc_i + self.tile_size
189+
y_2 = min(loc_i + self.tile_size, height)
181190
tile = copy.deepcopy(_tile)
182191
tile["original_shape_"] = img_shape
183192
tile["ori_shape"] = (y_2 - y_1, x_2 - x_1, 3)
@@ -191,6 +200,9 @@ def gen_tiles_single_img(self, result: Dict, dataset_idx: int) -> List[Dict]:
191200
if self.filter_empty_gt and len(tile["gt_labels"]) == 0:
192201
continue
193202
tile_list.append(tile)
203+
if dataset_idx == 0:
204+
print(f"image: {height}x{width} ~ tile_size: {self.tile_size}")
205+
print(f"{num_patches_h}x{num_patches_w} tiles -> {len(tile_list)} tiles after filtering")
194206
return tile_list
195207

196208
def prepare_result(self, result: Dict) -> Dict:
@@ -233,12 +245,11 @@ def tile_ann_assignment(
233245
gt_labels (np.ndarray): the original image-level labels
234246
"""
235247
x_1, y_1 = tile_box[0][:2]
236-
overlap_ratio = self.tile_boxes_overlap(tile_box, gt_bboxes)
237-
match_idx = np.where((overlap_ratio[0] >= self.min_area_ratio))[0]
248+
matched_indices = self.tile_boxes_overlap(tile_box, gt_bboxes)
238249

239-
if len(match_idx):
240-
tile_lables = gt_labels[match_idx][:]
241-
tile_bboxes = gt_bboxes[match_idx][:]
250+
if len(matched_indices):
251+
tile_lables = gt_labels[matched_indices][:]
252+
tile_bboxes = gt_bboxes[matched_indices][:]
242253
tile_bboxes[:, 0] -= x_1
243254
tile_bboxes[:, 1] -= y_1
244255
tile_bboxes[:, 2] -= x_1
@@ -249,7 +260,7 @@ def tile_ann_assignment(
249260
tile_bboxes[:, 3] = np.minimum(self.tile_size, tile_bboxes[:, 3])
250261
tile_result["gt_bboxes"] = tile_bboxes
251262
tile_result["gt_labels"] = tile_lables
252-
tile_result["gt_masks"] = gt_masks[match_idx].crop(tile_box[0]) if gt_masks is not None else []
263+
tile_result["gt_masks"] = gt_masks[matched_indices].crop(tile_box[0]) if gt_masks is not None else []
253264
else:
254265
tile_result.pop("bbox_fields")
255266
tile_result.pop("mask_fields")
@@ -270,18 +281,12 @@ def tile_boxes_overlap(self, tile_box: np.ndarray, boxes: np.ndarray) -> np.ndar
270281
boxes (np.ndarray): boxes in shape (N, 4).
271282
272283
Returns:
273-
np.ndarray: overlapping ratio over boxes
284+
np.ndarray: matched indices.
274285
"""
275-
box_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
276-
277-
width_height = np.minimum(tile_box[:, None, 2:], boxes[:, 2:]) - np.maximum(tile_box[:, None, :2], boxes[:, :2])
278-
279-
width_height = width_height.clip(min=0) # [N,M,2]
280-
inter = width_height.prod(2)
281-
282-
# handle empty boxes
283-
tile_box_ratio = np.where(inter > 0, inter / box_area, np.zeros(1, dtype=inter.dtype))
284-
return tile_box_ratio
286+
x1, y1, x2, y2 = tile_box[0]
287+
match_indices = (boxes[:, 0] > x1) & (boxes[:, 1] > y1) & (boxes[:, 2] < x2) & (boxes[:, 3] < y2)
288+
match_indices = np.argwhere(match_indices == 1).flatten()
289+
return match_indices
285290

286291
def multiclass_nms(
287292
self, boxes: np.ndarray, scores: np.ndarray, idxs: np.ndarray, iou_threshold: float, max_num: int
@@ -431,7 +436,7 @@ def merge(self, results: List[List]) -> Union[List[Tuple[np.ndarray, list]], Lis
431436

432437
merged_bbox_results: List[np.ndarray] = [np.empty((0, 5), dtype=dtype) for _ in range(self.num_images)]
433438
merged_mask_results: List[List] = [[] for _ in range(self.num_images)]
434-
merged_label_results: List[Union[List, np.ndarray]] = [[] for _ in range(self.num_images)]
439+
merged_label_results: List[Union[List, np.ndarray]] = [np.array([]) for _ in range(self.num_images)]
435440

436441
for result, tile in zip(results, self.tiles):
437442
tile_x1, tile_y1, _, _ = tile["tile_box"]
@@ -477,3 +482,21 @@ def merge(self, results: List[List]) -> Union[List[Tuple[np.ndarray, list]], Lis
477482
if detection:
478483
return list(merged_bbox_results)
479484
return list(zip(merged_bbox_results, merged_mask_results))
485+
486+
def get_ann_info(self, idx):
487+
"""Get annotation by index.
488+
489+
Args:
490+
idx (int): Index of data.
491+
492+
Returns:
493+
dict: Annotation info of specified index.
494+
"""
495+
ann = {}
496+
if "gt_bboxes" in self.tiles[idx]:
497+
ann["bboxes"] = self.tiles[idx]["gt_bboxes"]
498+
if "gt_masks" in self.tiles[idx]:
499+
ann["masks"] = self.tiles[idx]["gt_masks"]
500+
if "gt_labels" in self.tiles[idx]:
501+
ann["labels"] = self.tiles[idx]["gt_labels"]
502+
return ann

otx/algorithms/detection/adapters/mmdet/hooks/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44
#
55

66
from .det_class_probability_map_hook import DetClassProbabilityMapHook
7+
from .tile_sampling_hook import TileSamplingHook
78

8-
__all__ = ["DetClassProbabilityMapHook"]
9+
__all__ = ["DetClassProbabilityMapHook", "TileSamplingHook"]

0 commit comments

Comments
 (0)