Skip to content

Commit 93f1a55

Browse files
authored
Include full image with anno in case there's no tile in tile dataset (#3964)
* Include the full image with annotations in case there's no tile in the dataset * Update test
1 parent aa31dca commit 93f1a55

File tree

2 files changed

+37
-11
lines changed

2 files changed

+37
-11
lines changed

src/otx/core/data/dataset/tile.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,24 +148,24 @@ def _extract_rois(self, image: Image) -> list[BboxIntCoords]:
148148
tile_h, tile_w = self._tile_size
149149
h_ovl, w_ovl = self._overlap
150150

151-
rois: list[BboxIntCoords] = []
151+
rois: set[BboxIntCoords] = set()
152152
cols = range(0, img_w, int(tile_w * (1 - w_ovl)))
153153
rows = range(0, img_h, int(tile_h * (1 - h_ovl)))
154154

155155
if self.with_full_img:
156-
rois += [x1y1x2y2_to_xywh(0, 0, img_w, img_h)]
156+
rois.add(x1y1x2y2_to_xywh(0, 0, img_w, img_h))
157157
for offset_x, offset_y in product(cols, rows):
158158
x2 = min(offset_x + tile_w, img_w)
159159
y2 = min(offset_y + tile_h, img_h)
160160
c_x, c_y, w, h = x1y1x2y2_to_cxcywh(offset_x, offset_y, x2, y2)
161161
x1, y1, x2, y2 = cxcywh_to_x1y1x2y2(c_x, c_y, w, h)
162162
x1, y1, x2, y2 = clip_x1y1x2y2(x1, y1, x2, y2, img_w, img_h)
163163
x1, y1, x2, y2 = (int(v) for v in [x1, y1, x2, y2])
164-
rois += [x1y1x2y2_to_xywh(x1, y1, x2, y2)]
164+
rois.add(x1y1x2y2_to_xywh(x1, y1, x2, y2))
165165

166166
log.info(f"image: {img_h}x{img_w} ~ tile_size: {self._tile_size}")
167167
log.info(f"{len(rows)}x{len(cols)} tiles -> {len(rois)} tiles")
168-
return rois
168+
return list(rois)
169169

170170

171171
class OTXTileDatasetFactory:
@@ -242,6 +242,23 @@ def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_i
242242
msg = "Method _convert_entity is not implemented."
243243
raise NotImplementedError(msg)
244244

245+
def transform_item(
246+
self,
247+
item: DatasetItem,
248+
tile_size: tuple[int, int],
249+
overlap: tuple[float, float],
250+
with_full_img: bool,
251+
) -> DmDataset:
252+
"""Transform a dataset item to tile dataset which contains multiple tiles."""
253+
tile_ds = DmDataset.from_iterable([item])
254+
return tile_ds.transform(
255+
OTXTileTransform,
256+
tile_size=tile_size,
257+
overlap=overlap,
258+
threshold_drop_ann=0.5,
259+
with_full_img=with_full_img,
260+
)
261+
245262
def get_tiles(
246263
self,
247264
image: np.ndarray,
@@ -259,18 +276,24 @@ def get_tiles(
259276
- tile_entities (list[OTXDataEntity]): List of tile entities.
260277
- tile_attrs (list[dict]): List of tile attributes.
261278
"""
262-
tile_ds = DmDataset.from_iterable([item])
263-
tile_ds = tile_ds.transform(
264-
OTXTileTransform,
279+
tile_ds = self.transform_item(
280+
item,
265281
tile_size=self.tile_config.tile_size,
266282
overlap=(self.tile_config.overlap, self.tile_config.overlap),
267-
threshold_drop_ann=0.5,
268283
with_full_img=self.tile_config.with_full_img,
269284
)
270285

271286
if item.subset in VAL_SUBSET_NAMES:
272287
# NOTE: filter validation tiles with annotations only to avoid evaluation on empty tiles.
273288
tile_ds = tile_ds.filter("/item/annotation", filter_annotations=True, remove_empty=True)
289+
# If the tile dataset is empty, the objects are too big to fit in any tile; in this case, include the full image.
290+
if len(tile_ds) == 0:
291+
tile_ds = self.transform_item(
292+
item,
293+
tile_size=self.tile_config.tile_size,
294+
overlap=(self.tile_config.overlap, self.tile_config.overlap),
295+
with_full_img=True,
296+
)
274297

275298
tile_entities: list[OTXDataEntity] = []
276299
tile_attrs: list[dict] = []

tests/unit/core/utils/test_tile.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
def test_tile_transform_consistency(mocker):
18-
# Test that the tiler and tile transform are consistent
18+
# Test that OV tiler and PyTorch tile transform are consistent
1919
rng = np.random.default_rng()
2020
rnd_tile_size = rng.integers(low=100, high=500)
2121
rnd_tile_overlap = rng.random()
@@ -39,5 +39,8 @@ def test_tile_transform_consistency(mocker):
3939
tile_transform.with_full_img = True
4040

4141
dm_rois = [xywh_to_x1y1x2y2(*roi) for roi in tile_transform._extract_rois(dm_image)]
42-
# 0 index in tiler is the full image so we skip it
43-
assert np.allclose(dm_rois, tiler._tile(np_image))
42+
ov_tiler_rois = tiler._tile(np_image)
43+
44+
assert len(dm_rois) == len(ov_tiler_rois)
45+
for dm_roi in dm_rois:
46+
assert list(dm_roi) in ov_tiler_rois

0 commit comments

Comments
 (0)