Skip to content

Commit 98a9cac

Browse files
authored
Remove datumaro attribute id from tiling, add subset names (#3933)
* remove datumaro attribute id from tiling
* add subset names
1 parent c3749e3 commit 98a9cac

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

src/otx/core/data/dataset/tile.py

Lines changed: 20 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,10 @@
5353
# NOTE: Disable private-member-access (SLF001).
5454
# This is a workaround so we could apply the same transforms to tiles as the original dataset.
5555

56+
# NOTE: Datumaro subset name should be standardized.
57+
TRAIN_SUBSET_NAMES = ("train", "TRAINING")
58+
VAL_SUBSET_NAMES = ("val", "VALIDATION")
59+
5660

5761
class OTXTileTransform(Tile):
5862
"""OTX tile transform.
@@ -188,7 +192,7 @@ def create(
188192
Returns:
189193
OTXTileDataset: Tile dataset.
190194
"""
191-
if dataset.dm_subset[0].subset == "train":
195+
if dataset.dm_subset[0].subset in TRAIN_SUBSET_NAMES:
192196
return OTXTileTrainDataset(dataset, tile_config)
193197

194198
if task == OTXTaskType.DETECTION:
@@ -230,12 +234,17 @@ def _get_item_impl(self, index: int) -> OTXDataEntity | None:
230234
"""Get item implementation from the original dataset."""
231235
return self._dataset._get_item_impl(index)
232236

233-
def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem) -> OTXDataEntity:
237+
def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> OTXDataEntity:
234238
"""Convert a tile dataset item to OTXDataEntity."""
235239
msg = "Method _convert_entity is not implemented."
236240
raise NotImplementedError(msg)
237241

238-
def get_tiles(self, image: np.ndarray, item: DatasetItem) -> tuple[list[OTXDataEntity], list[dict]]:
242+
def get_tiles(
243+
self,
244+
image: np.ndarray,
245+
item: DatasetItem,
246+
parent_idx: int,
247+
) -> tuple[list[OTXDataEntity], list[dict]]:
239248
"""Retrieves tiles from the given image and dataset item.
240249
241250
Args:
@@ -256,14 +265,14 @@ def get_tiles(self, image: np.ndarray, item: DatasetItem) -> tuple[list[OTXDataE
256265
with_full_img=self.tile_config.with_full_img,
257266
)
258267

259-
if item.subset == "val":
268+
if item.subset in VAL_SUBSET_NAMES:
260269
# NOTE: filter validation tiles with annotations only to avoid evaluation on empty tiles.
261270
tile_ds = tile_ds.filter("/item/annotation", filter_annotations=True, remove_empty=True)
262271

263272
tile_entities: list[OTXDataEntity] = []
264273
tile_attrs: list[dict] = []
265274
for tile in tile_ds:
266-
tile_entity = self._convert_entity(image, tile)
275+
tile_entity = self._convert_entity(image, tile, parent_idx)
267276
# apply the same transforms as the original dataset
268277
transformed_tile = self._apply_transforms(tile_entity)
269278
if transformed_tile is None:
@@ -346,7 +355,7 @@ def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[overr
346355
)
347356
labels = torch.as_tensor([ann.label for ann in bbox_anns])
348357

349-
tile_entities, tile_attrs = self.get_tiles(img_data, item)
358+
tile_entities, tile_attrs = self.get_tiles(img_data, item, index)
350359

351360
return TileDetDataEntity(
352361
num_tiles=len(tile_entities),
@@ -365,13 +374,13 @@ def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[overr
365374
ori_labels=labels,
366375
)
367376

368-
def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem) -> DetDataEntity:
377+
def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> DetDataEntity:
369378
"""Convert a tile datumaro dataset item to DetDataEntity."""
370379
x1, y1, w, h = dataset_item.attributes["roi"]
371380
tile_img = image[y1 : y1 + h, x1 : x1 + w]
372381
tile_shape = tile_img.shape[:2]
373382
img_info = ImageInfo(
374-
img_idx=dataset_item.attributes["id"],
383+
img_idx=parent_idx,
375384
img_shape=tile_shape,
376385
ori_shape=tile_shape,
377386
)
@@ -448,7 +457,7 @@ def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[o
448457
masks = np.stack(gt_masks, axis=0) if gt_masks else np.zeros((0, *img_shape), dtype=bool)
449458
labels = np.array(gt_labels, dtype=np.int64)
450459

451-
tile_entities, tile_attrs = self.get_tiles(img_data, item)
460+
tile_entities, tile_attrs = self.get_tiles(img_data, item, index)
452461

453462
return TileInstSegDataEntity(
454463
num_tiles=len(tile_entities),
@@ -469,13 +478,13 @@ def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[o
469478
ori_polygons=gt_polygons,
470479
)
471480

472-
def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem) -> InstanceSegDataEntity:
481+
def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> InstanceSegDataEntity:
473482
"""Convert a tile dataset item to InstanceSegDataEntity."""
474483
x1, y1, w, h = dataset_item.attributes["roi"]
475484
tile_img = image[y1 : y1 + h, x1 : x1 + w]
476485
tile_shape = tile_img.shape[:2]
477486
img_info = ImageInfo(
478-
img_idx=dataset_item.attributes["id"],
487+
img_idx=parent_idx,
479488
img_shape=tile_shape,
480489
ori_shape=tile_shape,
481490
)

0 commit comments

Comments
 (0)