5353# NOTE: Disable private-member-access (SLF001).
5454# This is a workaround so that we can apply the same transforms to the tiles as to the original dataset.
5555
56+ # NOTE: Datumaro subset names are not standardized yet; until they are, match each split against every known alias.
57+ TRAIN_SUBSET_NAMES = ("train" , "TRAINING" )
58+ VAL_SUBSET_NAMES = ("val" , "VALIDATION" )
59+
5660
5761class OTXTileTransform (Tile ):
5862 """OTX tile transform.
@@ -188,7 +192,7 @@ def create(
188192 Returns:
189193 OTXTileDataset: Tile dataset.
190194 """
191- if dataset .dm_subset [0 ].subset == "train" :
195+ if dataset .dm_subset [0 ].subset in TRAIN_SUBSET_NAMES :
192196 return OTXTileTrainDataset (dataset , tile_config )
193197
194198 if task == OTXTaskType .DETECTION :
@@ -230,12 +234,17 @@ def _get_item_impl(self, index: int) -> OTXDataEntity | None:
230234 """Get item implementation from the original dataset."""
231235 return self ._dataset ._get_item_impl (index )
232236
233- def _convert_entity (self , image : np .ndarray , dataset_item : DatasetItem ) -> OTXDataEntity :
237+ def _convert_entity (self , image : np .ndarray , dataset_item : DatasetItem , parent_idx : int ) -> OTXDataEntity :
234238 """Convert a tile dataset item to OTXDataEntity."""
235239 msg = "Method _convert_entity is not implemented."
236240 raise NotImplementedError (msg )
237241
238- def get_tiles (self , image : np .ndarray , item : DatasetItem ) -> tuple [list [OTXDataEntity ], list [dict ]]:
242+ def get_tiles (
243+ self ,
244+ image : np .ndarray ,
245+ item : DatasetItem ,
246+ parent_idx : int ,
247+ ) -> tuple [list [OTXDataEntity ], list [dict ]]:
239248 """Retrieves tiles from the given image and dataset item.
240249
241250 Args:
@@ -256,14 +265,14 @@ def get_tiles(self, image: np.ndarray, item: DatasetItem) -> tuple[list[OTXDataE
256265 with_full_img = self .tile_config .with_full_img ,
257266 )
258267
259- if item .subset == "val" :
268+ if item .subset in VAL_SUBSET_NAMES :
260269 # NOTE: keep only the validation tiles that contain annotations, so that evaluation does not run on empty tiles.
261270 tile_ds = tile_ds .filter ("/item/annotation" , filter_annotations = True , remove_empty = True )
262271
263272 tile_entities : list [OTXDataEntity ] = []
264273 tile_attrs : list [dict ] = []
265274 for tile in tile_ds :
266- tile_entity = self ._convert_entity (image , tile )
275+ tile_entity = self ._convert_entity (image , tile , parent_idx )
267276 # apply the same transforms as the original dataset
268277 transformed_tile = self ._apply_transforms (tile_entity )
269278 if transformed_tile is None :
@@ -346,7 +355,7 @@ def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[overr
346355 )
347356 labels = torch .as_tensor ([ann .label for ann in bbox_anns ])
348357
349- tile_entities , tile_attrs = self .get_tiles (img_data , item )
358+ tile_entities , tile_attrs = self .get_tiles (img_data , item , index )
350359
351360 return TileDetDataEntity (
352361 num_tiles = len (tile_entities ),
@@ -365,13 +374,13 @@ def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[overr
365374 ori_labels = labels ,
366375 )
367376
368- def _convert_entity (self , image : np .ndarray , dataset_item : DatasetItem ) -> DetDataEntity :
377+ def _convert_entity (self , image : np .ndarray , dataset_item : DatasetItem , parent_idx : int ) -> DetDataEntity :
369378 """Convert a tile datumaro dataset item to DetDataEntity."""
370379 x1 , y1 , w , h = dataset_item .attributes ["roi" ]
371380 tile_img = image [y1 : y1 + h , x1 : x1 + w ]
372381 tile_shape = tile_img .shape [:2 ]
373382 img_info = ImageInfo (
374- img_idx = dataset_item . attributes [ "id" ] ,
383+ img_idx = parent_idx ,
375384 img_shape = tile_shape ,
376385 ori_shape = tile_shape ,
377386 )
@@ -448,7 +457,7 @@ def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[o
448457 masks = np .stack (gt_masks , axis = 0 ) if gt_masks else np .zeros ((0 , * img_shape ), dtype = bool )
449458 labels = np .array (gt_labels , dtype = np .int64 )
450459
451- tile_entities , tile_attrs = self .get_tiles (img_data , item )
460+ tile_entities , tile_attrs = self .get_tiles (img_data , item , index )
452461
453462 return TileInstSegDataEntity (
454463 num_tiles = len (tile_entities ),
@@ -469,13 +478,13 @@ def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[o
469478 ori_polygons = gt_polygons ,
470479 )
471480
472- def _convert_entity (self , image : np .ndarray , dataset_item : DatasetItem ) -> InstanceSegDataEntity :
481+ def _convert_entity (self , image : np .ndarray , dataset_item : DatasetItem , parent_idx : int ) -> InstanceSegDataEntity :
473482 """Convert a tile dataset item to InstanceSegDataEntity."""
474483 x1 , y1 , w , h = dataset_item .attributes ["roi" ]
475484 tile_img = image [y1 : y1 + h , x1 : x1 + w ]
476485 tile_shape = tile_img .shape [:2 ]
477486 img_info = ImageInfo (
478- img_idx = dataset_item . attributes [ "id" ] ,
487+ img_idx = parent_idx ,
479488 img_shape = tile_shape ,
480489 ori_shape = tile_shape ,
481490 )
0 commit comments