Skip to content

Commit 0f87c86

Browse files
Add missing tile recipes and various tile recipe changes (#3942)
* add missing tile recipes * Fix tiling XAI out of range (#3943) - Fix tile merge XAI out of range * update xai tile merge * update rtdetr * update tile recipes * update rtdetr tile postprocess * update rtdetr recipes and tile recipes * update tile recipes * fix rtdetr unittest * update recipes * refactor tile unit test * address pr reviews * remove unnecessary files * update color channel * fix image channel passing * include tiling in cli integration test * remove transform_bbox --------- Co-authored-by: Vladislav Sovrasov <[email protected]>
1 parent 8f96f27 commit 0f87c86

24 files changed

+542
-244
lines changed

src/otx/algo/detection/base_models/detection_transformer.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,19 +98,24 @@ def export(
9898
if explain_mode:
9999
msg = "Explain mode is not supported for DETR models yet."
100100
raise NotImplementedError(msg)
101-
return self.postprocess(self._forward_features(batch_inputs), deploy_mode=True)
101+
102+
return self.postprocess(
103+
self._forward_features(batch_inputs),
104+
[meta["img_shape"] for meta in batch_img_metas],
105+
deploy_mode=True,
106+
)
102107

103108
def postprocess(
104109
self,
105110
outputs: dict[str, Tensor],
106-
original_size: tuple[int, int] | None = None,
111+
original_sizes: list[tuple[int, int]],
107112
deploy_mode: bool = False,
108113
) -> dict[str, Tensor] | tuple[list[Tensor], list[Tensor], list[Tensor]]:
109114
"""Post-processes the model outputs.
110115
111116
Args:
112117
outputs (dict[str, Tensor]): The model outputs.
113-
original_size (tuple[int, int], optional): The original size of the input images. Defaults to None.
118+
original_sizes (list[tuple[int, int]]): The original image sizes.
114119
deploy_mode (bool, optional): Whether to run in deploy mode. Defaults to False.
115120
116121
Returns:
@@ -120,9 +125,9 @@ def postprocess(
120125

121126
# convert bbox to xyxy and rescale back to original size (resize in OTX)
122127
bbox_pred = box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy")
123-
if not deploy_mode and original_size is not None:
124-
original_size_tensor = torch.tensor(original_size).to(bbox_pred.device)
125-
bbox_pred *= original_size_tensor.repeat(1, 2).unsqueeze(1)
128+
if not deploy_mode:
129+
original_size_tensor = torch.tensor(original_sizes).to(bbox_pred.device)
130+
bbox_pred *= original_size_tensor.flip(1).repeat(1, 2).unsqueeze(1)
126131

127132
# perform scores computation and gather topk results
128133
scores = nn.functional.sigmoid(logits)
@@ -136,7 +141,7 @@ def postprocess(
136141

137142
scores_list, boxes_list, labels_list = [], [], []
138143

139-
for sc, bb, ll in zip(scores, boxes, labels):
144+
for sc, bb, ll, original_size in zip(scores, boxes, labels, original_sizes):
140145
scores_list.append(sc)
141146
boxes_list.append(
142147
BoundingBoxes(bb, format="xyxy", canvas_size=original_size),

src/otx/algo/detection/rtdetr.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,14 @@ def _customize_inputs(
7575
# prepare bboxes for the model
7676
for bb, ll in zip(entity.bboxes, entity.labels):
7777
# convert to cxcywh if needed
78-
converted_bboxes = (
79-
box_convert(bb, in_fmt="xyxy", out_fmt="cxcywh") if bb.format == BoundingBoxFormat.XYXY else bb
80-
)
81-
# normalize the bboxes
82-
scaled_bboxes = converted_bboxes / torch.tensor(bb.canvas_size[::-1]).tile(2)[None].to(
83-
converted_bboxes.device,
84-
)
78+
if len(scaled_bboxes := bb):
79+
converted_bboxes = (
80+
box_convert(bb, in_fmt="xyxy", out_fmt="cxcywh") if bb.format == BoundingBoxFormat.XYXY else bb
81+
)
82+
# normalize the bboxes
83+
scaled_bboxes = converted_bboxes / torch.tensor(bb.canvas_size[::-1]).tile(2)[None].to(
84+
converted_bboxes.device,
85+
)
8586
targets.append({"boxes": scaled_bboxes, "labels": ll})
8687

8788
return {
@@ -109,7 +110,8 @@ def _customize_outputs(
109110
raise TypeError(msg)
110111
return losses
111112

112-
scores, bboxes, labels = self.model.postprocess(outputs, [img_info.img_shape for img_info in inputs.imgs_info])
113+
original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info]
114+
scores, bboxes, labels = self.model.postprocess(outputs, original_sizes)
113115

114116
return DetBatchPredEntity(
115117
batch_size=len(outputs),

src/otx/algo/detection/yolox.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from __future__ import annotations
77

8-
from typing import TYPE_CHECKING, Any
8+
from typing import TYPE_CHECKING, Any, Literal
99

1010
from otx.algo.common.losses import CrossEntropyLoss, L1Loss
1111
from otx.algo.detection.backbones import CSPDarknet
@@ -76,13 +76,16 @@ def _exporter(self) -> OTXModelExporter:
7676
raise ValueError(msg)
7777

7878
swap_rgb = not isinstance(self, YOLOXTINY) # only YOLOX-TINY uses RGB
79+
resize_mode: Literal["standard", "fit_to_window_letterbox"] = "fit_to_window_letterbox"
80+
if self.tile_config.enable_tiler:
81+
resize_mode = "standard"
7982

8083
return OTXNativeModelExporter(
8184
task_level_export_parameters=self._export_parameters,
8285
input_size=(1, 3, *self.input_size),
8386
mean=self.mean,
8487
std=self.std,
85-
resize_mode="fit_to_window_letterbox",
88+
resize_mode=resize_mode,
8689
pad_value=114,
8790
swap_rgb=swap_rgb,
8891
via_onnx=True,

src/otx/core/data/dataset/tile.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,9 @@ def __init__(self, dataset: OTXDataset, tile_config: TileConfig) -> None:
218218
dataset.mem_cache_handler,
219219
dataset.mem_cache_img_max_size,
220220
dataset.max_refetch,
221+
dataset.image_color_channel,
222+
dataset.stack_images,
223+
dataset.to_tv_image,
221224
)
222225
self.tile_config = tile_config
223226
self._dataset = dataset

src/otx/core/data/entity/tile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def unbind(self) -> list[tuple[TileAttrDictList, DetBatchDataEntity]]:
125125
labels=[[] for _ in range(self.batch_size)],
126126
),
127127
)
128-
return list(zip(batch_tile_attr_list, batch_data_entities))
128+
return list(zip(batch_tile_attr_list, batch_data_entities, strict=True))
129129

130130
@classmethod
131131
def collate_fn(cls, batch_entities: list[TileDetDataEntity]) -> TileBatchDetDataEntity:
@@ -218,7 +218,7 @@ def unbind(self) -> list[tuple[TileAttrDictList, InstanceSegBatchDataEntity]]:
218218
)
219219
for i in range(0, len(tiles), self.batch_size)
220220
]
221-
return list(zip(batch_tile_attr_list, batch_data_entities))
221+
return list(zip(batch_tile_attr_list, batch_data_entities, strict=True))
222222

223223
@classmethod
224224
def collate_fn(cls, batch_entities: list[TileInstSegDataEntity]) -> TileBatchInstSegDataEntity:

src/otx/core/utils/tile_merge.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class TileMerge(Generic[T_OTXDataEntity, T_OTXBatchPredEntity]):
2929
img_infos (list[ImageInfo]): Original image information before tiling.
3030
num_classes (int): Number of classes.
3131
tile_config (TileConfig): Tile configuration.
32-
explain_mode (bool): Whether or not tiles have explain features. Default: False.
32+
explain_mode (bool, optional): Whether or not tiles have explain features. Default: False.
3333
"""
3434

3535
def __init__(
@@ -119,8 +119,8 @@ def merge(
119119
img_ids = []
120120
explain_mode = self.explain_mode
121121

122-
for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs):
123-
batch_size = tile_preds.batch_size
122+
for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
123+
batch_size = len(tile_attrs)
124124
saliency_maps = tile_preds.saliency_map if explain_mode else [[] for _ in range(batch_size)]
125125
feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(batch_size)]
126126
for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_s_map, tile_f_vect in zip(
@@ -131,6 +131,7 @@ def merge(
131131
tile_preds.scores,
132132
saliency_maps,
133133
feature_vectors,
134+
strict=True,
134135
):
135136
offset_x, offset_y, _, _ = tile_attr["roi"]
136137
tile_bboxes[:, 0::2] += offset_x
@@ -156,7 +157,7 @@ def merge(
156157

157158
return [
158159
self._merge_entities(image_info, entities_to_merge[img_id], explain_mode)
159-
for img_id, image_info in zip(img_ids, self.img_infos)
160+
for img_id, image_info in zip(img_ids, self.img_infos, strict=True)
160161
]
161162

162163
def _merge_entities(
@@ -319,8 +320,8 @@ def merge(
319320
img_ids = []
320321
explain_mode = self.explain_mode
321322

322-
for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs):
323-
feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(tile_preds.batch_size)]
323+
for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
324+
feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(len(tile_attrs))]
324325
for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_masks, tile_f_vect in zip(
325326
tile_attrs,
326327
tile_preds.imgs_info,
@@ -329,6 +330,7 @@ def merge(
329330
tile_preds.scores,
330331
tile_preds.masks,
331332
feature_vectors,
333+
strict=True,
332334
):
333335
keep_indices = tile_masks.to_sparse().sum((1, 2)).to_dense() > 0
334336
keep_indices = keep_indices.nonzero(as_tuple=True)[0]
@@ -363,7 +365,7 @@ def merge(
363365

364366
return [
365367
self._merge_entities(image_info, entities_to_merge[img_id], explain_mode)
366-
for img_id, image_info in zip(img_ids, self.img_infos)
368+
for img_id, image_info in zip(img_ids, self.img_infos, strict=True)
367369
]
368370

369371
def _merge_entities(
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
task: DETECTION
2+
input_size:
3+
- 800
4+
- 800
5+
mem_cache_size: 1GB
6+
mem_cache_img_max_size: null
7+
image_color_channel: RGB
8+
stack_images: true
9+
data_format: coco_instances
10+
unannotated_items_ratio: 0.0
11+
tile_config:
12+
enable_tiler: true
13+
enable_adaptive_tiling: true
14+
train_subset:
15+
subset_name: train
16+
transform_lib_type: TORCHVISION
17+
batch_size: 1
18+
num_workers: 2
19+
to_tv_image: false
20+
transforms:
21+
- class_path: otx.core.data.transform_libs.torchvision.Resize
22+
init_args:
23+
scale: $(input_size)
24+
keep_ratio: false
25+
transform_bbox: true
26+
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
27+
init_args:
28+
prob: 0.5
29+
is_numpy_to_tvtensor: true
30+
- class_path: torchvision.transforms.v2.ToDtype
31+
init_args:
32+
dtype: ${as_torch_dtype:torch.float32}
33+
- class_path: torchvision.transforms.v2.Normalize
34+
init_args:
35+
mean: [0.0, 0.0, 0.0]
36+
std: [255.0, 255.0, 255.0]
37+
sampler:
38+
class_path: torch.utils.data.RandomSampler
39+
40+
val_subset:
41+
subset_name: val
42+
transform_lib_type: TORCHVISION
43+
batch_size: 1
44+
num_workers: 2
45+
to_tv_image: false
46+
transforms:
47+
- class_path: otx.core.data.transform_libs.torchvision.Resize
48+
init_args:
49+
scale: $(input_size)
50+
keep_ratio: false
51+
is_numpy_to_tvtensor: true
52+
- class_path: torchvision.transforms.v2.ToDtype
53+
init_args:
54+
dtype: ${as_torch_dtype:torch.float32}
55+
- class_path: torchvision.transforms.v2.Normalize
56+
init_args:
57+
mean: [0.0, 0.0, 0.0]
58+
std: [255.0, 255.0, 255.0]
59+
sampler:
60+
class_path: torch.utils.data.RandomSampler
61+
62+
test_subset:
63+
subset_name: test
64+
transform_lib_type: TORCHVISION
65+
batch_size: 1
66+
num_workers: 2
67+
to_tv_image: false
68+
transforms:
69+
- class_path: otx.core.data.transform_libs.torchvision.Resize
70+
init_args:
71+
scale: $(input_size)
72+
keep_ratio: false
73+
is_numpy_to_tvtensor: true
74+
- class_path: torchvision.transforms.v2.ToDtype
75+
init_args:
76+
dtype: ${as_torch_dtype:torch.float32}
77+
- class_path: torchvision.transforms.v2.Normalize
78+
init_args:
79+
mean: [0.0, 0.0, 0.0]
80+
std: [255.0, 255.0, 255.0]
81+
sampler:
82+
class_path: torch.utils.data.RandomSampler

src/otx/recipe/detection/atss_mobilenetv2_tile.yaml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,10 @@ engine:
2828

2929
callback_monitor: val/map_50
3030

31-
data: ../_base_/data/detection.yaml
31+
data: ../_base_/data/detection_tile.yaml
3232
overrides:
3333
gradient_clip_val: 35.0
3434
data:
35-
tile_config:
36-
enable_tiler: true
37-
enable_adaptive_tiling: true
38-
3935
train_subset:
4036
batch_size: 8
4137
sampler:
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
model:
2+
class_path: otx.algo.detection.atss.ResNeXt101ATSS
3+
init_args:
4+
label_info: 80
5+
6+
optimizer:
7+
class_path: torch.optim.SGD
8+
init_args:
9+
lr: 0.004
10+
momentum: 0.9
11+
weight_decay: 0.0001
12+
13+
scheduler:
14+
class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
15+
init_args:
16+
num_warmup_steps: 3
17+
main_scheduler_callable:
18+
class_path: lightning.pytorch.cli.ReduceLROnPlateau
19+
init_args:
20+
mode: max
21+
factor: 0.1
22+
patience: 4
23+
monitor: val/map_50
24+
25+
engine:
26+
task: DETECTION
27+
device: auto
28+
29+
callback_monitor: val/map_50
30+
31+
data: ../_base_/data/detection_tile.yaml
32+
overrides:
33+
gradient_clip_val: 35.0
34+
callbacks:
35+
- class_path: otx.algo.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling
36+
init_args:
37+
max_interval: 5
38+
decay: -0.025
39+
min_lrschedule_patience: 3
40+
41+
data:
42+
train_subset:
43+
batch_size: 4
44+
sampler:
45+
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler
46+
47+
val_subset:
48+
batch_size: 4
49+
50+
test_subset:
51+
batch_size: 4

src/otx/recipe/detection/rtdetr_101.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ overrides:
8888
init_args:
8989
scale: $(input_size)
9090
keep_ratio: false
91-
transform_bbox: true
9291
is_numpy_to_tvtensor: true
9392
- class_path: torchvision.transforms.v2.ToDtype
9493
init_args:
@@ -103,7 +102,6 @@ overrides:
103102
init_args:
104103
scale: $(input_size)
105104
keep_ratio: false
106-
transform_bbox: true
107105
is_numpy_to_tvtensor: true
108106
- class_path: torchvision.transforms.v2.ToDtype
109107
init_args:

0 commit comments

Comments
 (0)