diff --git a/library/pyproject.toml b/library/pyproject.toml index a2d497e13b..dc9526aeea 100644 --- a/library/pyproject.toml +++ b/library/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dependencies = [ - "datumaro==1.10.0", + "datumaro[experimental] @ git+https://github.com/open-edge-platform/datumaro.git@develop", "omegaconf==2.3.0", "rich==14.0.0", "jsonargparse==4.35.0", @@ -37,7 +37,6 @@ dependencies = [ "docstring_parser==0.16", # CLI help-formatter "rich_argparse==1.7.0", # CLI help-formatter "einops==0.8.1", - "decord==0.6.0", "typeguard>=4.3,<4.5", # TODO(ashwinvaidya17): https://github.com/openvinotoolkit/anomalib/issues/2126 "setuptools<70", @@ -51,6 +50,8 @@ dependencies = [ "onnxconverter-common==1.14.0", "nncf==2.17.0", "anomalib[core]==1.1.3", + "numpy<2.0.0", + "tensorboardX>=1.8", ] [project.optional-dependencies] diff --git a/library/src/otx/backend/native/callbacks/gpu_mem_monitor.py b/library/src/otx/backend/native/callbacks/gpu_mem_monitor.py index 4d7d638810..dcea0d5b36 100644 --- a/library/src/otx/backend/native/callbacks/gpu_mem_monitor.py +++ b/library/src/otx/backend/native/callbacks/gpu_mem_monitor.py @@ -29,7 +29,7 @@ def _get_and_log_device_stats( batch_size (int): batch size. """ device = trainer.strategy.root_device - if device.type in ["cpu", "xpu"]: + if device.type in ["cpu", "xpu", "mps"]: return device_stats = trainer.accelerator.get_device_stats(device) diff --git a/library/src/otx/backend/native/models/__init__.py b/library/src/otx/backend/native/models/__init__.py index 94632e335c..dd8c51f704 100644 --- a/library/src/otx/backend/native/models/__init__.py +++ b/library/src/otx/backend/native/models/__init__.py @@ -3,6 +3,11 @@ """Module for OTX custom models.""" +import multiprocessing + +if multiprocessing.get_start_method(allow_none=True) is None: + multiprocessing.set_start_method("forkserver") + from .anomaly import Padim, Stfpm, Uflow from .classification import ( EfficientNet, diff --git a/library/src/otx/backend/native/models/detection/base.py b/library/src/otx/backend/native/models/detection/base.py index b23b08886a..80cd657b13 100644 --- a/library/src/otx/backend/native/models/detection/base.py +++ b/library/src/otx/backend/native/models/detection/base.py @@ -33,6 +33,7 @@ from otx.types.task import OTXTaskType if TYPE_CHECKING: + from datumaro.experimental.fields import TileInfo from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable from otx.backend.native.models.detection.detectors import SingleStageDetector @@ -262,21 +263,21 @@ def forward_tiles(self, inputs: OTXTileBatchDataEntity) -> OTXPredBatch: DetBatchPredEntity: Merged detection prediction. """ tile_preds: list[OTXPredBatch] = [] - tile_attrs: list[list[dict[str, int | str]]] = [] + tile_infos: list[list[TileInfo]] = [] merger = DetectionTileMerge( inputs.imgs_info, self.num_classes, self.tile_config, self.explain_mode, ) - for batch_tile_attrs, batch_tile_input in inputs.unbind(): + for batch_tile_infos, batch_tile_input in inputs.unbind(): output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input) if isinstance(output, OTXBatchLossEntity): msg = "Loss output is not supported for tile merging" raise TypeError(msg) tile_preds.append(output) - tile_attrs.append(batch_tile_attrs) - pred_entities = merger.merge(tile_preds, tile_attrs) + tile_infos.append(batch_tile_infos) + pred_entities = merger.merge(tile_preds, tile_infos) pred_entity = OTXPredBatch( batch_size=inputs.batch_size, diff --git a/library/src/otx/backend/native/models/detection/ssd.py b/library/src/otx/backend/native/models/detection/ssd.py index 6aa112c02b..b0343dc7d6 100644 --- a/library/src/otx/backend/native/models/detection/ssd.py +++ b/library/src/otx/backend/native/models/detection/ssd.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal import numpy as np -from datumaro.components.annotation import Bbox +from datumaro.experimental.dataset import Dataset as DmDataset from otx.backend.native.exporter.base import OTXModelExporter from otx.backend.native.exporter.native import OTXNativeModelExporter @@ -30,6 +30,7 @@ from otx.backend.native.models.utils.support_otx_v1 import OTXv1Helper from otx.backend.native.models.utils.utils import load_checkpoint from otx.config.data import TileConfig +from otx.data.entity.sample import DetectionSample from otx.metrics.fmeasure import MeanAveragePrecisionFMeasureCallable if TYPE_CHECKING: @@ -231,7 +232,7 @@ def _get_new_anchors(self, dataset: OTXDataset, anchor_generator: SSDAnchorGener return self._get_anchor_boxes(wh_stats, group_as) @staticmethod - def _get_sizes_from_dataset_entity(dataset: OTXDataset, target_wh: list[int]) -> list[tuple[int, int]]: + def _get_sizes_from_dataset_entity(dataset: OTXDataset, target_wh: list[int]) -> np.ndarray: """Function to get width and height size of items in OTXDataset. Args: @@ -240,20 +241,34 @@ def _get_sizes_from_dataset_entity(dataset: OTXDataset, target_wh: list[int]) -> Return list[tuple[int, int]]: tuples with width and height of each instance """ - wh_stats: list[tuple[int, int]] = [] + wh_stats = np.empty((0, 2), dtype=np.float32) + if not isinstance(dataset.dm_subset, DmDataset): + exc_str = "The variable dataset.dm_subset must be an instance of DmDataset" + raise TypeError(exc_str) + for item in dataset.dm_subset: - for ann in item.annotations: - if isinstance(ann, Bbox): - x1, y1, x2, y2 = ann.points - x1 = x1 / item.media.size[1] * target_wh[0] - y1 = y1 / item.media.size[0] * target_wh[1] - x2 = x2 / item.media.size[1] * target_wh[0] - y2 = y2 / item.media.size[0] * target_wh[1] - wh_stats.append((x2 - x1, y2 - y1)) + if not isinstance(item, DetectionSample): + exc_str = "The variable item must be an instance of DetectionSample" + raise TypeError(exc_str) + + if item.img_info is None: + exc_str = "The image info must not be None" + raise RuntimeError(exc_str) + + height, width = item.img_info.img_shape + x1 = item.bboxes[:, 0] + y1 = item.bboxes[:, 1] + x2 = item.bboxes[:, 2] + y2 = item.bboxes[:, 3] + + w = (x2 - x1) / width * target_wh[0] + h = (y2 - y1) / height * target_wh[1] + + wh_stats = np.concatenate((wh_stats, np.stack((w, h), axis=1)), axis=0) return wh_stats @staticmethod - def _get_anchor_boxes(wh_stats: list[tuple[int, int]], group_as: list[int]) -> tuple: + def _get_anchor_boxes(wh_stats: np.ndarray, group_as: list[int]) -> tuple: """Get new anchor box widths & heights using KMeans.""" from sklearn.cluster import KMeans diff --git a/library/src/otx/backend/native/models/instance_segmentation/base.py b/library/src/otx/backend/native/models/instance_segmentation/base.py index 26c37a8a35..3f94e3626c 100644 --- a/library/src/otx/backend/native/models/instance_segmentation/base.py +++ b/library/src/otx/backend/native/models/instance_segmentation/base.py @@ -40,6 +40,7 @@ from otx.types.task import OTXTaskType if TYPE_CHECKING: + from datumaro.experimental.fields import TileInfo from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable from torch import nn @@ -208,21 +209,21 @@ def forward_tiles(self, inputs: OTXTileBatchDataEntity) -> OTXPredBatch: TorchPredBatch: Merged instance segmentation prediction. """ tile_preds: list[OTXPredBatch] = [] - tile_attrs: list[list[dict[str, int | str]]] = [] + tile_infos: list[list[TileInfo]] = [] merger = InstanceSegTileMerge( inputs.imgs_info, self.num_classes, self.tile_config, self.explain_mode, ) - for batch_tile_attrs, batch_tile_input in inputs.unbind(): + for batch_tile_infos, batch_tile_input in inputs.unbind(): output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input) if isinstance(output, OTXBatchLossEntity): msg = "Loss output is not supported for tile merging" raise TypeError(msg) tile_preds.append(output) - tile_attrs.append(batch_tile_attrs) - pred_entities = merger.merge(tile_preds, tile_attrs) + tile_infos.append(batch_tile_infos) + pred_entities = merger.merge(tile_preds, tile_infos) pred_entity = OTXPredBatch( batch_size=inputs.batch_size, @@ -458,7 +459,7 @@ def _convert_pred_entity_to_compute_metric( rles = ( [encode_rle(mask) for mask in masks.data] - if len(masks) + if masks is not None else polygon_to_rle(polygons, *imgs_info.ori_shape) # type: ignore[union-attr,arg-type] ) target_info.append( diff --git a/library/src/otx/backend/native/models/instance_segmentation/heads/roi_head_tv.py b/library/src/otx/backend/native/models/instance_segmentation/heads/roi_head_tv.py index ea4c345049..8786813188 100644 --- a/library/src/otx/backend/native/models/instance_segmentation/heads/roi_head_tv.py +++ b/library/src/otx/backend/native/models/instance_segmentation/heads/roi_head_tv.py @@ -15,13 +15,13 @@ from otx.data.utils.structures.mask import mask_target if TYPE_CHECKING: - from datumaro import Polygon + import numpy as np def maskrcnn_loss( mask_logits: Tensor, proposals: list[Tensor], - gt_masks: list[list[Tensor]] | list[list[Polygon]], + gt_masks: list[list[Tensor]] | list[np.ndarray], gt_labels: list[Tensor], mask_matched_idxs: list[Tensor], image_shapes: list[tuple[int, int]], @@ -31,7 +31,7 @@ def maskrcnn_loss( Args: mask_logits (Tensor): the mask predictions. proposals (list[Tensor]): the region proposals. - gt_masks (list[list[Tensor]] | list[list[Polygon]]): the ground truth masks. + gt_masks (list[list[Tensor]] | list[np.ndarray]): the ground truth masks as ragged arrays. gt_labels (list[Tensor]): the ground truth labels. mask_matched_idxs (list[Tensor]): the matched indices. image_shapes (list[tuple[int, int]]): the image shapes. @@ -142,7 +142,9 @@ def forward( raise ValueError(msg) gt_masks = ( - [t["masks"] for t in targets] if len(targets[0]["masks"]) else [t["polygons"] for t in targets] + [t["masks"] for t in targets] + if targets[0]["masks"] is not None + else [t["polygons"] for t in targets] ) gt_labels = [t["labels"] for t in targets] rcnn_loss_mask = maskrcnn_loss( diff --git a/library/src/otx/backend/native/models/instance_segmentation/heads/rtmdet_inst_head.py b/library/src/otx/backend/native/models/instance_segmentation/heads/rtmdet_inst_head.py index cf5a8cf2cf..81f6d22210 100644 --- a/library/src/otx/backend/native/models/instance_segmentation/heads/rtmdet_inst_head.py +++ b/library/src/otx/backend/native/models/instance_segmentation/heads/rtmdet_inst_head.py @@ -18,7 +18,6 @@ import numpy as np import torch import torch.nn.functional -from datumaro import Polygon from torch import Tensor, nn from otx.backend.native.models.common.utils.nms import batched_nms, multiclass_nms @@ -644,7 +643,7 @@ def prepare_loss_inputs(self, x: tuple[Tensor], entity: OTXDataBatch) -> dict: ) # Convert polygon masks to bitmap masks - if isinstance(batch_gt_instances[0].masks[0], Polygon): + if isinstance(batch_gt_instances[0].masks, np.ndarray): for gt_instances, img_meta in zip(batch_gt_instances, batch_img_metas): ndarray_masks = polygon_to_bitmap(gt_instances.masks, *img_meta["img_shape"]) if len(ndarray_masks) == 0: diff --git a/library/src/otx/backend/native/models/instance_segmentation/rotated_det.py b/library/src/otx/backend/native/models/instance_segmentation/rotated_det.py index 10cf1d65c1..f50022963c 100644 --- a/library/src/otx/backend/native/models/instance_segmentation/rotated_det.py +++ b/library/src/otx/backend/native/models/instance_segmentation/rotated_det.py @@ -4,13 +4,27 @@ """Rotated Detection Prediction Mixin.""" import cv2 +import numpy as np import torch -from datumaro import Polygon from torchvision import tv_tensors from otx.data.entity.torch.torch import OTXPredBatch +def get_polygon_area(points: np.ndarray) -> float: + """Calculate polygon area using the shoelace formula. + + Args: + points: Array of polygon vertices with shape (N, 2) + + Returns: + float: Area of the polygon + """ + x = points[:, 0] + y = points[:, 1] + return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))) + + def convert_masks_to_rotated_predictions(preds: OTXPredBatch) -> OTXPredBatch: """Convert masks to rotated bounding boxes and polygons. @@ -58,8 +72,10 @@ def convert_masks_to_rotated_predictions(preds: OTXPredBatch) -> OTXPredBatch: for contour, hierarchy in zip(contours, hierarchies[0]): if hierarchy[3] != -1 or len(contour) <= 2: continue - rbox_points = Polygon(cv2.boxPoints(cv2.minAreaRect(contour)).reshape(-1)) - rbox_polygons.append((rbox_points, rbox_points.get_area())) + # Get rotated bounding box points and convert to ragged array format + box_points = cv2.boxPoints(cv2.minAreaRect(contour)).astype(np.float32) + area = get_polygon_area(box_points) + rbox_polygons.append((box_points, area)) if rbox_polygons: rbox_polygons.sort(key=lambda x: x[1], reverse=True) diff --git a/library/src/otx/backend/native/models/instance_segmentation/utils/utils.py b/library/src/otx/backend/native/models/instance_segmentation/utils/utils.py index 3487e5662a..c3b8f6e3c9 100644 --- a/library/src/otx/backend/native/models/instance_segmentation/utils/utils.py +++ b/library/src/otx/backend/native/models/instance_segmentation/utils/utils.py @@ -53,7 +53,7 @@ def unpack_inst_seg_entity(entity: OTXDataBatch) -> tuple: } batch_img_metas.append(metainfo) - gt_masks = mask if len(mask) else polygon + gt_masks = mask if mask is not None else polygon batch_gt_instances.append( InstanceData( diff --git a/library/src/otx/backend/native/models/segmentation/base.py b/library/src/otx/backend/native/models/segmentation/base.py index 60d020b2e3..bf57369e08 100644 --- a/library/src/otx/backend/native/models/segmentation/base.py +++ b/library/src/otx/backend/native/models/segmentation/base.py @@ -31,6 +31,7 @@ from otx.types.task import OTXTaskType if TYPE_CHECKING: + from datumaro.experimental.fields import TileInfo from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable from torch import Tensor @@ -223,15 +224,15 @@ def forward_tiles(self, inputs: OTXTileBatchDataEntity) -> OTXPredBatch: raise NotImplementedError(msg) tile_preds: list[OTXPredBatch] = [] - tile_attrs: list[list[dict[str, int | str]]] = [] + tile_infos: list[list[TileInfo]] = [] merger = SegmentationTileMerge( inputs.imgs_info, self.num_classes, self.tile_config, self.explain_mode, ) - for batch_tile_attrs, batch_tile_input in inputs.unbind(): - tile_size = batch_tile_attrs[0]["tile_size"] + for batch_tile_infos, batch_tile_input in inputs.unbind(): + tile_size = (batch_tile_infos[0].height, batch_tile_infos[0].width) output = self.model( inputs=batch_tile_input.images, img_metas=batch_tile_input.imgs_info, @@ -245,8 +246,8 @@ def forward_tiles(self, inputs: OTXTileBatchDataEntity) -> OTXPredBatch: msg = "Loss output is not supported for tile merging" raise TypeError(msg) tile_preds.append(output) - tile_attrs.append(batch_tile_attrs) - pred_entities = merger.merge(tile_preds, tile_attrs) + tile_infos.append(batch_tile_infos) + pred_entities = merger.merge(tile_preds, tile_infos) pred_entity = OTXPredBatch( batch_size=inputs.batch_size, diff --git a/library/src/otx/backend/native/tools/tile_merge.py b/library/src/otx/backend/native/tools/tile_merge.py index 1ee91356dc..63de4659d1 100644 --- a/library/src/otx/backend/native/tools/tile_merge.py +++ b/library/src/otx/backend/native/tools/tile_merge.py @@ -7,7 +7,7 @@ from abc import abstractmethod from collections import defaultdict -from typing import Callable +from typing import TYPE_CHECKING, Callable import cv2 import numpy as np @@ -20,6 +20,9 @@ from otx.config.data import TileConfig from otx.data.entity import ImageInfo, OTXPredBatch, OTXPredItem +if TYPE_CHECKING: + from datumaro.experimental.fields import TileInfo + # Maximum number of elements 2**31 -1 MAX_ELEMENTS: int = np.iinfo(np.int32).max @@ -75,7 +78,6 @@ def __init__( self.tile_size = tile_config.tile_size self.iou_threshold = tile_config.iou_threshold self.max_num_instances = tile_config.max_num_instances - self.with_full_img = tile_config.with_full_img self.explain_mode = explain_mode @abstractmethod @@ -137,7 +139,7 @@ class DetectionTileMerge(TileMerge): def merge( self, batch_tile_preds: list[OTXPredBatch], - batch_tile_attrs: list[list[dict]], + batch_tile_attrs: list[list[TileInfo]], ) -> list[OTXPredItem]: """Merge batch tile predictions to a list of full-size prediction data entities. @@ -165,15 +167,16 @@ def merge( tile_f_vect = tile_preds.feature_vector[i] if tile_preds.feature_vector is not None else None tile_bboxes = tile_preds.bboxes[i] if tile_preds.bboxes[i].numel() > 0 else None - offset_x, offset_y, _, _ = tile_attr["roi"] + offset_x = tile_attr.x + offset_y = tile_attr.y if tile_bboxes is not None: tile_bboxes[:, 0::2] += offset_x tile_bboxes[:, 1::2] += offset_y - tile_id = tile_attr["tile_id"] + tile_id = tile_attr.source_sample_idx if tile_id not in img_ids: img_ids.append(tile_id) - tile_img_info.padding = tile_attr["roi"] # type: ignore[union-attr] + tile_img_info.padding = [tile_attr.x, tile_attr.y, tile_attr.width, tile_attr.height] # type: ignore[union-attr] det_pred_entity = OTXPredItem( image=torch.empty(3, *tile_img_info.ori_shape), # type: ignore[union-attr] @@ -285,10 +288,7 @@ def _merge_saliency_maps( image_map_w = int(image_w * ratio[1]) merged_map = np.zeros((num_classes, image_map_h, image_map_w)) - # Note: Skip the first saliency map as it is the full image value. - saliency_maps, start_idx = (saliency_maps[1:], 1) if self.with_full_img else (saliency_maps, 0) - - for i, saliency_map in enumerate(saliency_maps, start_idx): + for i, saliency_map in enumerate(saliency_maps): for class_idx in range(num_classes): cls_map = saliency_map[class_idx] @@ -314,11 +314,6 @@ def _merge_saliency_maps( merged_map[class_idx][y_1 + hi, x_1 + wi] = map_pixel for class_idx in range(num_classes): - if self.with_full_img: - image_map_cls = image_saliency_map[class_idx] - image_map_cls = cv2.resize(image_map_cls, (image_map_w, image_map_h)) - merged_map[class_idx] += 0.5 * image_map_cls - merged_map[class_idx] = _non_linear_normalization(merged_map[class_idx]) return merged_map.astype(np.uint8) @@ -339,26 +334,21 @@ def _non_linear_normalization(saliency_map: np.ndarray) -> np.ndarray: class InstanceSegTileMerge(TileMerge): """Instance segmentation tile merge.""" - def merge( - self, - batch_tile_preds: list[OTXPredBatch], - batch_tile_attrs: list[list[dict]], - ) -> list[OTXPredItem]: + def merge(self, batch_tile_preds: list[OTXPredBatch], batch_tile_infos: list[list[TileInfo]]) -> list[OTXPredItem]: """Merge inst-seg tile predictions to one single prediction. Args: batch_tile_preds (list): instance-seg tile predictions. - batch_tile_attrs (list): instance-seg tile attributes. """ entities_to_merge = defaultdict(list) img_ids = [] explain_mode = self.explain_mode - for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True): - feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(len(tile_attrs))] - for i in range(len(tile_attrs)): - tile_attr = tile_attrs[i] + for tile_preds, tile_infos in zip(batch_tile_preds, batch_tile_infos, strict=True): + feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(len(tile_infos))] + for i in range(len(tile_infos)): + tile_info = tile_infos[i] tile_img_info = tile_preds.imgs_info[i] if tile_preds.imgs_info is not None else None tile_bboxes = tile_preds.bboxes[i] if tile_preds.bboxes is not None else None tile_labels = tile_preds.labels[i] if tile_preds.labels is not None else None @@ -376,14 +366,15 @@ def merge( _scores = tile_scores[keep_indices] _masks = tile_masks[keep_indices] - offset_x, offset_y, _, _ = tile_attr["roi"] + offset_x = tile_info.x + offset_y = tile_info.y _bboxes[:, 0::2] += offset_x _bboxes[:, 1::2] += offset_y - tile_id = tile_attr["tile_id"] + tile_id = tile_info.source_sample_idx if tile_id not in img_ids: img_ids.append(tile_id) - tile_img_info.padding = tile_attr["roi"] # type: ignore[union-attr] + tile_img_info.padding = [tile_info.x, tile_info.y, tile_info.width, tile_info.height] # type: ignore[union-attr] inst_seg_pred_entity = OTXPredItem( image=torch.empty(3, *tile_img_info.ori_shape), # type: ignore[union-attr] @@ -508,7 +499,7 @@ def __init__( def merge( self, batch_tile_preds: list[OTXPredBatch], - batch_tile_attrs: list[list[dict]], + batch_tile_attrs: list[list[TileInfo]], ) -> list[OTXPredItem]: """Merge batch tile predictions to a list of full-size prediction data entities. @@ -548,10 +539,10 @@ def merge( msg = f"Image information is not provided : {tile_preds.imgs_info}." raise ValueError(msg) - tile_id = tile_attr["tile_id"] + tile_id = tile_attr.source_sample_idx if tile_id not in img_ids: img_ids.append(tile_id) - tile_img_info.padding = tile_attr["roi"] + tile_img_info.padding = (tile_attr.x, tile_attr.y, tile_attr.width, tile_attr.height) seg_pred_entity = OTXPredItem( image=torch.empty((3, *tile_img_info.ori_shape)), img_info=tile_img_info, diff --git a/library/src/otx/backend/native/utils/utils.py b/library/src/otx/backend/native/utils/utils.py index 593d1f261f..4edbf960d4 100644 --- a/library/src/otx/backend/native/utils/utils.py +++ b/library/src/otx/backend/native/utils/utils.py @@ -80,8 +80,8 @@ def mock_modules_for_chkpt() -> Iterator[None]: setattr(sys.modules["otx.types.task"], "OTXTrainType", OTXTrainType) # noqa: B010 sys.modules["otx.core"] = types.ModuleType("otx.core") - sys.modules["otx.core.config"] = otx.config - sys.modules["otx.core.config.data"] = otx.config.data + sys.modules["otx.core.config"] = otx.config # type: ignore[attr-defined] + sys.modules["otx.core.config.data"] = otx.config.data # type: ignore[attr-defined] sys.modules["otx.core.types"] = otx.types sys.modules["otx.core.types.task"] = otx.types.task sys.modules["otx.core.types.label"] = otx.types.label diff --git a/library/src/otx/config/data.py b/library/src/otx/config/data.py index 7c4192fd24..d69a8495d3 100644 --- a/library/src/otx/config/data.py +++ b/library/src/otx/config/data.py @@ -78,7 +78,6 @@ class TileConfig: max_num_instances: int = 1500 object_tile_ratio: float = 0.03 sampling_ratio: float = 1.0 - with_full_img: bool = False def clone(self) -> TileConfig: """Return a deep copied one of this instance.""" diff --git a/library/src/otx/data/dataset/__init__.py b/library/src/otx/data/dataset/__init__.py index 3150763ca5..05cc07ea3d 100644 --- a/library/src/otx/data/dataset/__init__.py +++ b/library/src/otx/data/dataset/__init__.py @@ -9,7 +9,7 @@ from .instance_segmentation import OTXInstanceSegDataset from .keypoint_detection import OTXKeypointDetectionDataset from .segmentation import OTXSegmentationDataset -from .tile import OTXTileDatasetFactory +from .tile_new import OTXTileDatasetFactory __all__ = [ "OTXAnomalyDataset", diff --git a/library/src/otx/data/dataset/anomaly_new.py b/library/src/otx/data/dataset/anomaly_new.py new file mode 100644 index 0000000000..8af1188c30 --- /dev/null +++ b/library/src/otx/data/dataset/anomaly_new.py @@ -0,0 +1,36 @@ +# Copyright (C) 2023-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Module for OTXSegmentationDataset.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from datumaro.experimental.categories import LabelCategories, LabelSemantic + +from otx.data.dataset.base_new import OTXDataset +from otx.data.entity.sample import AnomalySample +from otx.types.label import AnomalyLabelInfo +from otx.types.task import OTXTaskType + +if TYPE_CHECKING: + from datumaro.experimental import Dataset + + +class OTXAnomalyDataset(OTXDataset): + """OTXDataset class for anomaly task.""" + + def __init__(self, task_type: OTXTaskType, dm_subset: Dataset, **kwargs) -> None: + self.task_type = task_type + sample_type = AnomalySample + categories = { + "label": LabelCategories( + labels=["normal", "anomalous"], + label_semantics={LabelSemantic.NORMAL: "normal", LabelSemantic.ANOMALOUS: "anomalous"}, + ) + } + dm_subset = dm_subset.convert_to_schema(sample_type, target_categories=categories) + super().__init__(dm_subset=dm_subset, sample_type=sample_type, **kwargs) + + self.label_info = AnomalyLabelInfo() diff --git a/library/src/otx/data/dataset/base.py b/library/src/otx/data/dataset/base.py index 501114f4fc..4f7146583b 100644 --- a/library/src/otx/data/dataset/base.py +++ b/library/src/otx/data/dataset/base.py @@ -6,13 +6,14 @@ from __future__ import annotations from abc import abstractmethod +from collections import defaultdict from collections.abc import Iterable from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Union import cv2 import numpy as np -from datumaro.components.annotation import AnnotationType +from datumaro.components.annotation import AnnotationType, LabelCategories from datumaro.util.image import IMAGE_BACKEND, IMAGE_COLOR_CHANNEL, ImageBackend from datumaro.util.image import ImageColorChannel as DatumaroImageColorChannel from torch.utils.data import Dataset @@ -196,3 +197,23 @@ def _get_item_impl(self, idx: int) -> OTXDataItem | None: def collate_fn(self) -> Callable: """Collection function to collect KeypointDetDataEntity into KeypointDetBatchDataEntity in data loader.""" return OTXDataItem.collate_fn + + def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int | str, list[int]]: + """Get a dictionary mapping class labels (string or int) to lists of samples. + + Args: + use_string_label (bool): If True, use string class labels as keys. + If False, use integer indices as keys. + """ + stats: dict[int | str, list[int]] = defaultdict(list) + for item_idx, item in enumerate(self.dm_subset): + for ann in item.annotations: + if use_string_label: + labels = self.dm_subset.categories().get(AnnotationType.label, LabelCategories()) + stats[labels.items[ann.label].name].append(item_idx) + else: + stats[ann.label].append(item_idx) + # Remove duplicates in label stats idx: O(n) + for k in stats: + stats[k] = list(dict.fromkeys(stats[k])) + return stats diff --git a/library/src/otx/data/dataset/base_new.py b/library/src/otx/data/dataset/base_new.py new file mode 100644 index 0000000000..30eb8a439f --- /dev/null +++ b/library/src/otx/data/dataset/base_new.py @@ -0,0 +1,154 @@ +# Copyright (C) 2023-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Base class for OTXDataset using new Datumaro experimental Dataset.""" + +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING, Callable, Iterable, List, Union + +import numpy as np +import torch +from torch.utils.data import Dataset as TorchDataset + +from otx import LabelInfo, NullLabelInfo + +if TYPE_CHECKING: + from datumaro.experimental import Dataset + +from otx.data.entity.sample import OTXSample +from otx.data.entity.torch.torch import OTXDataBatch +from otx.data.transform_libs.torchvision import Compose +from otx.types.image import ImageColorChannel + +Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]] + + +def _default_collate_fn(items: list[OTXSample]) -> OTXDataBatch: + """Collate OTXSample items into an OTXDataBatch. + + Args: + items: List of OTXSample items to batch + Returns: + Batched OTXSample items with stacked tensors + """ + # Convert images to float32 tensors before stacking + image_tensors = [] + for item in items: + img = item.image + if isinstance(img, torch.Tensor): + # Convert to float32 if not already + if img.dtype != torch.float32: + img = img.float() + else: + # Convert numpy array to float32 tensor + img = torch.from_numpy(img).float() + image_tensors.append(img) + + # Try to stack images if they have the same shape + if len(image_tensors) > 0 and all(t.shape == image_tensors[0].shape for t in image_tensors): + images = torch.stack(image_tensors) + else: + images = image_tensors + + return OTXDataBatch( + batch_size=len(items), + images=images, + labels=[item.label for item in items] if items[0].label is not None else None, + masks=[item.masks for item in items] if any(item.masks is not None for item in items) else None, + bboxes=[item.bboxes for item in items] if any(item.bboxes is not None for item in items) else None, + keypoints=[item.keypoints for item in items] if any(item.keypoints is not None for item in items) else None, + polygons=[item.polygons for item in items if item.polygons is not None] + if any(item.polygons is not None for item in items) + else None, + imgs_info=[item.img_info for item in items] if any(item.img_info is not None for item in items) else None, + ) + + +class OTXDataset(TorchDataset): + """Base OTXDataset using new Datumaro experimental Dataset. + + Defines basic logic for OTX datasets. + + Args: + transforms: Transforms to apply on images + image_color_channel: Color channel of images + stack_images: Whether or not to stack images in collate function in OTXBatchData entity. + sample_type: Type of sample to use for this dataset + """ + + def __init__( + self, + dm_subset: Dataset, + transforms: Transforms, + max_refetch: int = 1000, + image_color_channel: ImageColorChannel = ImageColorChannel.RGB, + stack_images: bool = True, + to_tv_image: bool = True, + data_format: str = "", + sample_type: type[OTXSample] = OTXSample, + ) -> None: + self.transforms = transforms + self.image_color_channel = image_color_channel + self.stack_images = stack_images + self.to_tv_image = to_tv_image + self.sample_type = sample_type + self.max_refetch = max_refetch + self.data_format = data_format + self.label_info: LabelInfo = NullLabelInfo() + self.dm_subset = dm_subset + + def __len__(self) -> int: + return len(self.dm_subset) + + def _sample_another_idx(self) -> int: + return np.random.randint(0, len(self)) + + def _apply_transforms(self, entity: OTXSample) -> OTXSample | None: + if isinstance(self.transforms, Compose): + return self.transforms(entity) + if isinstance(self.transforms, Iterable): + return self._iterable_transforms(entity) + if callable(self.transforms): + return self.transforms(entity) + return None + + def _iterable_transforms(self, item: OTXSample) -> OTXSample | None: + if not isinstance(self.transforms, list): + raise TypeError(item) + + results = item + for transform in self.transforms: + results = transform(results) + # MMCV transform can produce None. Please see + # https://github.com/open-mmlab/mmengine/blob/26f22ed283ae4ac3a24b756809e5961efe6f9da8/mmengine/dataset/base_dataset.py#L59-L66 + if results is None: + return None + + return results + + def __getitem__(self, index: int) -> OTXSample: + for _ in range(self.max_refetch): + results = self._get_item_impl(index) + + if results is not None: + return results + + index = self._sample_another_idx() + + msg = f"Reach the maximum refetch number ({self.max_refetch})" + raise RuntimeError(msg) + + def _get_item_impl(self, index: int) -> OTXSample | None: + dm_item = self.dm_subset[index] + return self._apply_transforms(dm_item) + + @property + def collate_fn(self) -> Callable: + """Collection function to collect samples into a batch in data loader.""" + return _default_collate_fn + + @abc.abstractmethod + def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]: + """Get a dictionary with class labels as keys and lists of corresponding sample indices as values.""" diff --git a/library/src/otx/data/dataset/classification_new.py b/library/src/otx/data/dataset/classification_new.py new file mode 100644 index 0000000000..f753e68183 --- /dev/null +++ b/library/src/otx/data/dataset/classification_new.py @@ -0,0 +1,211 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Module for OTXClassificationDatasets using new Datumaro experimental Dataset.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from datumaro import Label +from torch.nn import functional +from torchvision.transforms.v2.functional import to_dtype, to_image + +from otx import HLabelInfo, LabelInfo +from otx.data.dataset.base_new import OTXDataset +from otx.data.entity.sample import ( + ClassificationHierarchicalSample, + ClassificationMultiLabelSample, + ClassificationSample, +) + +if TYPE_CHECKING: + from datumaro.experimental import Dataset + + +class OTXMulticlassClsDataset(OTXDataset): + """OTXDataset class for multi-class classification task using new Datumaro experimental Dataset.""" + + def __init__(self, dm_subset: Dataset, **kwargs) -> None: + """Initialize OTXMulticlassClsDataset. + + Args: + **kwargs: Keyword arguments to pass to OTXDataset + """ + sample_type = ClassificationSample + dm_subset = dm_subset.convert_to_schema(sample_type) + super().__init__(dm_subset=dm_subset, sample_type=sample_type, **kwargs) + + labels = dm_subset.schema.attributes["label"].categories.labels + self.label_info = LabelInfo( + label_names=labels, + label_groups=[labels], + label_ids=[str(i) for i in range(len(labels))], + ) + + def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]: + """Get a dictionary mapping class labels (string or int) to lists of samples. + + Args: + use_string_label (bool): If True, use string class labels as keys. + If False, use integer indices as keys. + """ + idx_list_per_classes: dict[int, list[int]] = {} + for idx in range(len(self)): + item = self.dm_subset[idx] + label_id = item.label.item() + if use_string_label: + label_id = self.label_info.label_names[label_id] + if label_id not in idx_list_per_classes: + idx_list_per_classes[label_id] = [] + idx_list_per_classes[label_id].append(idx) + return idx_list_per_classes + + +class OTXMultilabelClsDataset(OTXDataset): + """OTXDataset class for multi-label classification task.""" + + def __init__(self, dm_subset: Dataset, **kwargs) -> None: + sample_type = ClassificationMultiLabelSample + dm_subset = dm_subset.convert_to_schema(sample_type) + super().__init__(dm_subset=dm_subset, sample_type=sample_type, **kwargs) + + labels = dm_subset.schema.attributes["label"].categories.labels + self.label_info = LabelInfo( + label_names=labels, + label_groups=[labels], + label_ids=[str(i) for i in range(len(labels))], + ) + self.num_classes = len(labels) + + def _get_item_impl(self, index: int) -> ClassificationMultiLabelSample | None: + item = self.dm_subset[index] + item.image = to_dtype(to_image(item.image), dtype=torch.float32) + item.label = self._convert_to_onehot(torch.as_tensor(list(item.label)), ignored_labels=[]) + return self._apply_transforms(item) + + def _convert_to_onehot(self, labels: torch.tensor, ignored_labels: list[int]) -> torch.tensor: + """Convert label to one-hot vector format.""" + # Torch's one_hot() expects the input to be of type long + # However, when labels are empty, they are of type float32 + onehot = functional.one_hot(labels.long(), self.num_classes).sum(0).clamp_max_(1) + if ignored_labels: + for ignore_label in ignored_labels: + onehot[ignore_label] = -1 + return onehot + + def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]: + """Get a dictionary mapping class labels (string or int) to lists of samples. + + Args: + use_string_label (bool): If True, use string class labels as keys. + If False, use integer indices as keys. + """ + idx_list_per_classes: dict[int, list[int]] = {} + for idx in range(len(self)): + item = self.dm_subset[idx] + labels = item.label.tolist() + if use_string_label: + labels = [self.label_info.label_names[label] for label in labels] + for label in labels: + if label not in idx_list_per_classes: + idx_list_per_classes[label] = [] + idx_list_per_classes[label].append(idx) + return idx_list_per_classes + + +class OTXHlabelClsDataset(OTXDataset): + """OTXDataset class for H-label classification task.""" + + def __init__(self, dm_subset: Dataset, **kwargs) -> None: + sample_type = ClassificationHierarchicalSample + dm_subset = dm_subset.convert_to_schema(sample_type) + super().__init__(dm_subset=dm_subset, sample_type=sample_type, **kwargs) + if self.data_format != "arrow": + raise ValueError("The data format should be arrow.") # noqa: EM101, TRY003 + self.dm_categories = self.dm_subset.schema.attributes["label"].categories + self.label_info = HLabelInfo.from_dm_label_groups_arrow(self.dm_categories) + + self.id_to_name_mapping = dict(zip(self.label_info.label_ids, self.label_info.label_names)) + self.id_to_name_mapping[""] = "" + + if self.label_info.num_multiclass_heads == 0: + msg = "The number of multiclass heads should be larger than 0." + raise ValueError(msg) + + def _get_item_impl(self, index: int) -> ClassificationHierarchicalSample | None: + item = self.dm_subset[index] + item.image = to_dtype(to_image(item.image), dtype=torch.float32) + item.label = torch.as_tensor(self._convert_label_to_hlabel_format([Label(label=idx) for idx in item.label], [])) + return self._apply_transforms(item) + + def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_labels: list[int]) -> list[int]: + """Convert format of the label to the h-label. + + It converts the label format to h-label format. + Total length of result is sum of number of hierarchy and number of multilabel classes. + + i.e. + Let's assume that we used the same dataset with example of the definition of HLabelData + and the original labels are ["Rigid", "Triangle", "Lion"]. + + Then, h-label format will be [0, 1, 1, 0]. + The first N-th indices represent the label index of multiclass heads (N=num_multiclass_heads), + others represent the multilabel labels. + + [Multiclass Heads] + 0-th index = 0 -> ["Rigid"(O), "Non-Rigid"(X)] <- First multiclass head + 1-st index = 1 -> ["Rectangle"(O), "Triangle"(X), "Circle"(X)] <- Second multiclass head + + [Multilabel Head] + 2, 3 indices = [1, 0] -> ["Lion"(O), "Panda"(X)] + """ + if not isinstance(self.label_info, HLabelInfo): + msg = f"The type of label_info should be HLabelInfo, got {type(self.label_info)}." + raise TypeError(msg) + + num_multiclass_heads = self.label_info.num_multiclass_heads + num_multilabel_classes = self.label_info.num_multilabel_classes + + class_indices = [0] * (num_multiclass_heads + num_multilabel_classes) + for i in range(num_multiclass_heads): + class_indices[i] = -1 + + for ann in label_anns: + if self.data_format == "arrow": + # skips unknown labels for instance, the empty one + if self.dm_categories.items[ann.label].name not in self.id_to_name_mapping: + continue + ann_name = self.id_to_name_mapping[self.dm_categories.items[ann.label].name] + else: + ann_name = self.dm_categories.items[ann.label].name + group_idx, in_group_idx = self.label_info.class_to_group_idx[ann_name] + + if group_idx < num_multiclass_heads: + class_indices[group_idx] = in_group_idx + elif ann.label not in ignored_labels: + class_indices[num_multiclass_heads + in_group_idx] = 1 + else: + class_indices[num_multiclass_heads + in_group_idx] = -1 + + return class_indices + + def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]: + """Get a dictionary mapping class labels (string or int) to lists of samples. + + Args: + use_string_label (bool): If True, use string class labels as keys. + If False, use integer indices as keys. + """ + idx_list_per_classes: dict[int, list[int]] = {} + for idx in range(len(self)): + item = self.dm_subset[idx] + labels = item.label.tolist() + if use_string_label: + labels = [self.label_info.label_names[label] for label in labels] + for label in labels: + if label not in idx_list_per_classes: + idx_list_per_classes[label] = [] + idx_list_per_classes[label].append(idx) + return idx_list_per_classes diff --git a/library/src/otx/data/dataset/detection_new.py b/library/src/otx/data/dataset/detection_new.py new file mode 100644 index 0000000000..0008bd0734 --- /dev/null +++ b/library/src/otx/data/dataset/detection_new.py @@ -0,0 +1,56 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Module for OTXDetectionDataset.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from otx.data.entity.sample import DetectionSample +from otx.types.label import LabelInfo + +from .base_new import OTXDataset + +if TYPE_CHECKING: + from datumaro.experimental import Dataset + + +class OTXDetectionDataset(OTXDataset): + """OTXDataset class for detection task using new Datumaro experimental Dataset.""" + + def __init__(self, dm_subset: Dataset, **kwargs) -> None: + """Initialize _OTXDetectionDataset. + + Args: + **kwargs: Keyword arguments to pass to OTXDataset + """ + sample_type = DetectionSample + dm_subset = dm_subset.convert_to_schema(sample_type) + super().__init__(dm_subset=dm_subset, sample_type=sample_type, **kwargs) + + labels = list(dm_subset.schema.attributes["label"].categories.labels) + self.label_info = LabelInfo( + label_names=labels, + label_groups=[labels], + label_ids=[str(i) for i in range(len(labels))], + ) + + def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]: + """Get a dictionary mapping class labels (string or int) to lists of samples. + + Args: + use_string_label (bool): If True, use string class labels as keys. + If False, use integer indices as keys. + """ + idx_list_per_classes: dict[int, list[int]] = {} + for idx in range(len(self)): + item = self.dm_subset[idx] + labels = item.label.tolist() + if use_string_label: + labels = [self.label_info.label_names[label] for label in labels] + for label in labels: + if label not in idx_list_per_classes: + idx_list_per_classes[label] = [] + idx_list_per_classes[label].append(idx) + return idx_list_per_classes diff --git a/library/src/otx/data/dataset/instance_segmentation.py b/library/src/otx/data/dataset/instance_segmentation.py index f982ad43eb..144e6aa792 100644 --- a/library/src/otx/data/dataset/instance_segmentation.py +++ b/library/src/otx/data/dataset/instance_segmentation.py @@ -21,6 +21,28 @@ from .base import OTXDataset, Transforms +def convert_datumaro_polygons_to_ragged_array(polygons: list[Polygon]) -> np.ndarray: + """Convert list of datumaro.Polygon to ragged array format. + + Args: + polygons: List of datumaro.Polygon objects + + Returns: + np.ndarray: Object array containing np.ndarray objects of shape (Npoly, 2) + """ + if not polygons: + return np.array([], dtype=object) + + ragged_polygons = np.empty(len(polygons), dtype=object) + for i, polygon in enumerate(polygons): + points = np.array(polygon.points, dtype=np.float32) + if len(points) % 2 != 0: + # Handle invalid polygon by creating a degenerate triangle + points = np.array([0, 0, 0, 0, 0, 0], dtype=np.float32) + ragged_polygons[i] = points.reshape(-1, 2) + return ragged_polygons + + class OTXInstanceSegDataset(OTXDataset): """OTXDataset class for instance segmentation. @@ -89,6 +111,11 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: labels = np.array(gt_labels, dtype=np.int64) + # Convert polygons to ragged array format + polygons = None + if gt_polygons: + polygons = convert_datumaro_polygons_to_ragged_array(gt_polygons) + entity = OTXDataItem( image=img_data, img_info=ImageInfo( @@ -106,7 +133,7 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: ), masks=tv_tensors.Mask(masks, dtype=torch.uint8), label=torch.as_tensor(labels, dtype=torch.long), - polygons=gt_polygons if len(gt_polygons) > 0 else None, + polygons=polygons, ) return self._apply_transforms(entity) # type: ignore[return-value] diff --git a/library/src/otx/data/dataset/instance_segmentation_new.py b/library/src/otx/data/dataset/instance_segmentation_new.py new file mode 100644 index 0000000000..d9fcb082a5 --- /dev/null +++ b/library/src/otx/data/dataset/instance_segmentation_new.py @@ -0,0 +1,50 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Module for OTXInstanceSegDataset.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from otx import LabelInfo +from otx.data.dataset.base_new import OTXDataset +from otx.data.entity.sample import InstanceSegmentationSample, InstanceSegmentationSampleWithMask + +if TYPE_CHECKING: + from datumaro.experimental import Dataset + + +class OTXInstanceSegDataset(OTXDataset): + """OTXDataset class for instance segmentation task.""" + + def __init__(self, dm_subset: Dataset, include_polygons: bool = True, **kwargs) -> None: + sample_type = InstanceSegmentationSample if include_polygons else InstanceSegmentationSampleWithMask + dm_subset = dm_subset.convert_to_schema(sample_type) + super().__init__(dm_subset=dm_subset, sample_type=sample_type, **kwargs) + + labels = list(dm_subset.schema.attributes["label"].categories.labels) + self.label_info = LabelInfo( + label_names=labels, + label_groups=[labels], + label_ids=[str(i) for i in range(len(labels))], + ) + + def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]: + """Get a dictionary mapping class labels (string or int) to lists of samples. + + Args: + use_string_label (bool): If True, use string class labels as keys. + If False, use integer indices as keys. + """ + idx_list_per_classes: dict[int, list[int]] = {} + for idx in range(len(self)): + item = self.dm_subset[idx] + labels = item.label.tolist() + if use_string_label: + labels = [self.label_info.label_names[label] for label in labels] + for label in labels: + if label not in idx_list_per_classes: + idx_list_per_classes[label] = [] + idx_list_per_classes[label].append(idx) + return idx_list_per_classes diff --git a/library/src/otx/data/dataset/keypoint_detection_new.py b/library/src/otx/data/dataset/keypoint_detection_new.py new file mode 100644 index 0000000000..ecef4cf2e5 --- /dev/null +++ b/library/src/otx/data/dataset/keypoint_detection_new.py @@ -0,0 +1,45 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Module for OTXKeypointDetectionDataset.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, List, Union + +import torch +from torchvision.transforms.v2.functional import to_dtype, to_image + +from otx.data.entity.sample import KeypointSample +from otx.data.transform_libs.torchvision import Compose +from otx.types.label import LabelInfo + +from .base_new import OTXDataset + +Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]] + +if TYPE_CHECKING: + from datumaro.experimental import Dataset + + +class OTXKeypointDetectionDataset(OTXDataset): + """OTXDataset class for keypoint detection task.""" + + def __init__(self, dm_subset: Dataset, **kwargs) -> None: + sample_type = KeypointSample + dm_subset = dm_subset.convert_to_schema(sample_type) + super().__init__(dm_subset=dm_subset, sample_type=sample_type, **kwargs) + labels = dm_subset.schema.attributes["label"].categories.labels + self.label_info = LabelInfo( + label_names=labels, + label_groups=[], + label_ids=[str(i) for i in range(len(labels))], + ) + + def _get_item_impl(self, index: int) -> KeypointSample | None: + item = self.dm_subset[index] + keypoints = item.keypoints + keypoints[:, 2] = torch.clamp(keypoints[:, 2], max=1) # OTX represents visibility as 0 or 1 + item.keypoints = keypoints + item.image = to_dtype(to_image(item.image), torch.float32) + return self._apply_transforms(item) # type: ignore[return-value] diff --git a/library/src/otx/data/dataset/segmentation_new.py b/library/src/otx/data/dataset/segmentation_new.py new file mode 100644 index 0000000000..a1f330044e --- /dev/null +++ b/library/src/otx/data/dataset/segmentation_new.py @@ -0,0 +1,31 @@ +# Copyright (C) 2023-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Module for OTXSegmentationDataset.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from otx import SegLabelInfo +from otx.data.dataset.base_new import OTXDataset +from otx.data.entity.sample import SegmentationSample + +if TYPE_CHECKING: + from datumaro.experimental import Dataset + + +class OTXSegmentationDataset(OTXDataset): + """OTXDataset class for segmentation task.""" + + def __init__(self, dm_subset: Dataset, **kwargs) -> None: + sample_type = SegmentationSample + dm_subset = dm_subset.convert_to_schema(sample_type) + super().__init__(dm_subset=dm_subset, sample_type=sample_type, **kwargs) + + labels = list(dm_subset.schema.attributes["masks"].categories.labels) + self.label_info = SegLabelInfo( + label_names=labels, + label_groups=[labels], + label_ids=[str(i) for i in range(len(labels))], + ) diff --git a/library/src/otx/data/dataset/tile.py b/library/src/otx/data/dataset/tile.py deleted file mode 100644 index 88726adf7d..0000000000 --- a/library/src/otx/data/dataset/tile.py +++ /dev/null @@ -1,691 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""OTX tile dataset.""" - -from __future__ import annotations - -import logging as log -import operator -import warnings -from collections import defaultdict -from copy import deepcopy -from itertools import product -from typing import TYPE_CHECKING, Callable - -import numpy as np -import shapely.geometry as sg -import torch -from datumaro import Dataset as DmDataset -from datumaro import DatasetItem, Image -from datumaro.components.annotation import AnnotationType, Bbox, Ellipse, ExtractedMask, Polygon -from datumaro.plugins.tiling import Tile -from datumaro.plugins.tiling.tile import _apply_offset -from datumaro.plugins.tiling.util import ( - clip_x1y1x2y2, - cxcywh_to_x1y1x2y2, - x1y1x2y2_to_cxcywh, - x1y1x2y2_to_xywh, -) -from torchvision import tv_tensors - -from otx.data.dataset.segmentation import _extract_class_mask -from otx.data.entity.base import ImageInfo -from otx.data.entity.tile import ( - TileBatchDetDataEntity, - TileBatchInstSegDataEntity, - TileBatchSegDataEntity, - TileDetDataEntity, - TileInstSegDataEntity, - TileSegDataEntity, -) -from otx.data.entity.torch import OTXDataItem -from otx.data.utils.structures.mask.mask_util import polygon_to_bitmap -from otx.types.task import OTXTaskType - -from .base import OTXDataset - -if TYPE_CHECKING: - from datumaro.components.media import BboxIntCoords - - from otx.config.data import TileConfig - from otx.data.dataset.detection import OTXDetectionDataset - from otx.data.dataset.instance_segmentation import OTXInstanceSegDataset - from otx.data.dataset.segmentation import OTXSegmentationDataset - -# ruff: noqa: SLF001 -# NOTE: Disable private-member-access (SLF001). -# This is a workaround so we could apply the same transforms to tiles as the original dataset. - -# NOTE: Datumaro subset name should be standardized. -TRAIN_SUBSET_NAMES = ("train", "TRAINING") -VAL_SUBSET_NAMES = ("val", "VALIDATION") - - -class OTXTileTransform(Tile): - """OTX tile transform. - - Different from the original Datumaro Tile transform, - OTXTileTransform takes tile_size and overlap as input instead of grid size - - Args: - extractor (DmDataset): Dataset subset to extract tiles from. - tile_size (tuple[int, int]): Tile size. - overlap (tuple[float, float]): Overlap ratio. - Overlap values are clipped between 0 and 0.9 to ensure the stride is not too small. - threshold_drop_ann (float): Threshold to drop annotations. - with_full_img (bool): Include full image in the tiles. - """ - - def __init__( - self, - extractor: DmDataset, - tile_size: tuple[int, int], - overlap: tuple[float, float], - threshold_drop_ann: float, - with_full_img: bool, - ) -> None: - # NOTE: clip overlap to [0, 0.9] - overlap = max(0, min(overlap[0], 0.9)), max(0, min(overlap[1], 0.9)) - super().__init__( - extractor, - (0, 0), - overlap=overlap, - threshold_drop_ann=threshold_drop_ann, - ) - self._tile_size = tile_size - self._tile_ann_func_map[AnnotationType.polygon] = OTXTileTransform._tile_polygon - self._tile_ann_func_map[AnnotationType.mask] = OTXTileTransform._tile_masks - self._tile_ann_func_map[AnnotationType.ellipse] = OTXTileTransform._tile_ellipse - self.with_full_img = with_full_img - - @staticmethod - def _tile_polygon( - ann: Polygon, - roi_box: sg.Polygon, - threshold_drop_ann: float = 0.8, - *args, # noqa: ARG004 - **kwargs, # noqa: ARG004 - ) -> Polygon | None: - polygon = sg.Polygon(ann.get_points()) - - # NOTE: polygon may be invalid, e.g. self-intersecting - if not roi_box.intersects(polygon) or not polygon.is_valid: - return None - - # NOTE: intersection may return a GeometryCollection or MultiPolygon - inter = polygon.intersection(roi_box) - if isinstance(inter, (sg.GeometryCollection, sg.MultiPolygon)): - shapes = [(geom, geom.area) for geom in list(inter.geoms) if geom.is_valid] - if not shapes: - return None - - inter, _ = max(shapes, key=operator.itemgetter(1)) - - if not isinstance(inter, sg.Polygon) and not inter.is_valid: - return None - - prop_area = inter.area / polygon.area - - if prop_area < threshold_drop_ann: - return None - - inter = _apply_offset(inter, roi_box) - - return ann.wrap( - points=[p for xy in inter.exterior.coords for p in xy], - attributes=deepcopy(ann.attributes), - ) - - @staticmethod - def _tile_masks( - ann: ExtractedMask, - roi_int: BboxIntCoords, - *args, # noqa: ARG004 - **kwargs, # noqa: ARG004 - ) -> ExtractedMask: - """Extracts a tile mask from the given annotation. - - Note: Original Datumaro _tile_masks does not work with ExtractedMask. - - Args: - ann (ExtractedMask): datumaro ExtractedMask annotation. - roi_int (BboxIntCoords): ROI coordinates. - - Returns: - ExtractedMask: ExtractedMask annotation. - """ - x, y, w, h = roi_int - return ann.wrap( - index_mask=ann.index_mask()[y : y + h, x : x + w], - attributes=deepcopy(ann.attributes), - ) - - @staticmethod - def _tile_ellipse( - ann: Ellipse, - roi_box: sg.Polygon, - threshold_drop_ann: float = 0.8, - *args, # noqa: ARG004 - **kwargs, # noqa: ARG004 - ) -> Polygon | None: - polygon = sg.Polygon(ann.get_points(num_points=10)) - - # NOTE: polygon may be invalid, e.g. self-intersecting - if not roi_box.intersects(polygon) or not polygon.is_valid: - return None - - # NOTE: intersection may return a GeometryCollection or MultiPolygon - inter = polygon.intersection(roi_box) - if isinstance(inter, (sg.GeometryCollection, sg.MultiPolygon)): - shapes = [(geom, geom.area) for geom in list(inter.geoms) if geom.is_valid] - if not shapes: - return None - - inter, _ = max(shapes, key=operator.itemgetter(1)) - - if not isinstance(inter, sg.Polygon) and not inter.is_valid: - return None - - prop_area = inter.area / polygon.area - - if prop_area < threshold_drop_ann: - return None - - inter = _apply_offset(inter, roi_box) - - return Polygon( - points=[p for xy in inter.exterior.coords for p in xy], - attributes=deepcopy(ann.attributes), - label=ann.label, - ) - - def _extract_rois(self, image: Image) -> list[BboxIntCoords]: - """Extracts Tile ROIs from the given image. - - Args: - image (Image): Full image. - - Returns: - list[BboxIntCoords]: list of ROIs. - """ - if image.size is None: - msg = "Image size is None" - raise ValueError(msg) - - img_h, img_w = image.size - tile_h, tile_w = self._tile_size - h_ovl, w_ovl = self._overlap - - rois: set[BboxIntCoords] = set() - cols = range(0, img_w, int(tile_w * (1 - w_ovl))) - rows = range(0, img_h, int(tile_h * (1 - h_ovl))) - - if self.with_full_img: - rois.add(x1y1x2y2_to_xywh(0, 0, img_w, img_h)) - for offset_x, offset_y in product(cols, rows): - x2 = min(offset_x + tile_w, img_w) - y2 = min(offset_y + tile_h, img_h) - c_x, c_y, w, h = x1y1x2y2_to_cxcywh(offset_x, offset_y, x2, y2) - x1, y1, x2, y2 = cxcywh_to_x1y1x2y2(c_x, c_y, w, h) - x1, y1, x2, y2 = clip_x1y1x2y2(x1, y1, x2, y2, img_w, img_h) - x1, y1, x2, y2 = (int(v) for v in [x1, y1, x2, y2]) - rois.add(x1y1x2y2_to_xywh(x1, y1, x2, y2)) - - log.info(f"image: {img_h}x{img_w} ~ tile_size: {self._tile_size}") - log.info(f"{len(rows)}x{len(cols)} tiles -> {len(rois)} tiles") - return list(rois) - - -class OTXTileDatasetFactory: - """OTX tile dataset factory.""" - - @classmethod - def create( - cls, - task: OTXTaskType, - dataset: OTXDataset, - tile_config: TileConfig, - ) -> OTXTileDataset: - """Create a tile dataset based on the task type and subset type. - - NOte: All task utilize the same OTXTileTrainDataset for training. - In testing, we use different tile dataset for different task - type due to different annotation format and data entity. - - Args: - task (OTXTaskType): OTX task type. - dataset (OTXDataset): OTX dataset. - tile_config (TilerConfig): Tile configuration. - - Returns: - OTXTileDataset: Tile dataset. - """ - if dataset.dm_subset[0].subset in TRAIN_SUBSET_NAMES: - return OTXTileTrainDataset(dataset, tile_config) - - if task == OTXTaskType.DETECTION: - return OTXTileDetTestDataset(dataset, tile_config) - if task in [OTXTaskType.ROTATED_DETECTION, OTXTaskType.INSTANCE_SEGMENTATION]: - return OTXTileInstSegTestDataset(dataset, tile_config) - if task == OTXTaskType.SEMANTIC_SEGMENTATION: - return OTXTileSemanticSegTestDataset(dataset, tile_config) - - msg = f"Unsupported task type: {task} for tiling" - raise NotImplementedError(msg) - - -class OTXTileDataset(OTXDataset): - """OTX tile dataset base class. - - Args: - dataset (OTXDataset): OTX dataset. - tile_config (TilerConfig): Tile configuration. - """ - - def __init__(self, dataset: OTXDataset, tile_config: TileConfig) -> None: - super().__init__( - dataset.dm_subset, - dataset.transforms, - dataset.max_refetch, - dataset.image_color_channel, - dataset.stack_images, - dataset.to_tv_image, - ) - self.tile_config = tile_config - self._dataset = dataset - - # LabelInfo differs from SegLabelInfo, thus we need to update it for semantic segmentation. - if self.label_info != dataset.label_info: - msg = ( - "Replace the label info to match the dataset's label info", - "as there is a mismatch between the dataset and the tile dataset.", - ) - log.warning(msg) - self.label_info = dataset.label_info - - def __len__(self) -> int: - return len(self._dataset) - - @property - def collate_fn(self) -> Callable: - """Collate function from the original dataset.""" - return self._dataset.collate_fn - - def _get_item_impl(self, index: int) -> OTXDataItem | None: - """Get item implementation from the original dataset.""" - return self._dataset._get_item_impl(index) - - def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> OTXDataItem: - """Convert a tile dataset item to OTXDataItem.""" - msg = "Method _convert_entity is not implemented." - raise NotImplementedError(msg) - - def transform_item( - self, - item: DatasetItem, - tile_size: tuple[int, int], - overlap: tuple[float, float], - with_full_img: bool, - ) -> DmDataset: - """Transform a dataset item to tile dataset which contains multiple tiles.""" - tile_ds = DmDataset.from_iterable([item]) - return tile_ds.transform( - OTXTileTransform, - tile_size=tile_size, - overlap=overlap, - threshold_drop_ann=0.5, - with_full_img=with_full_img, - ) - - def get_tiles( - self, - image: np.ndarray, - item: DatasetItem, - parent_idx: int, - ) -> tuple[list[OTXDataItem], list[dict]]: - """Retrieves tiles from the given image and dataset item. - - Args: - image (np.ndarray): The input image. - item (DatasetItem): The dataset item. - parent_idx (int): The parent index. This is to keep track of the original dataset item index for merging. - - Returns: - A tuple containing two lists: - - tile_entities (list[OTXDataItem]): List of tile entities. - - tile_attrs (list[dict]): List of tile attributes. - """ - tile_ds = self.transform_item( - item, - tile_size=self.tile_config.tile_size, - overlap=(self.tile_config.overlap, self.tile_config.overlap), - with_full_img=self.tile_config.with_full_img, - ) - - if item.subset in VAL_SUBSET_NAMES: - # NOTE: filter validation tiles with annotations only to avoid evaluation on empty tiles. - tile_ds = tile_ds.filter("/item/annotation", filter_annotations=True, remove_empty=True) - # if tile dataset is empty it means objects are too big to fit in any tile, in this case include full image - if len(tile_ds) == 0: - tile_ds = self.transform_item( - item, - tile_size=self.tile_config.tile_size, - overlap=(self.tile_config.overlap, self.tile_config.overlap), - with_full_img=True, - ) - - tile_entities: list[OTXDataItem] = [] - tile_attrs: list[dict] = [] - for tile in tile_ds: - tile_entity = self._convert_entity(image, tile, parent_idx) - # apply the same transforms as the original dataset - transformed_tile = self._apply_transforms(tile_entity) - if transformed_tile is None: - msg = "Transformed tile is None" - raise RuntimeError(msg) - tile.attributes.update({"tile_size": self.tile_config.tile_size}) - tile_entities.append(transformed_tile) - tile_attrs.append(tile.attributes) - return tile_entities, tile_attrs - - -class OTXTileTrainDataset(OTXTileDataset): - """OTX tile train dataset. - - Args: - dataset (OTXDataset): OTX dataset. - tile_config (TilerConfig): Tile configuration. - """ - - def __init__(self, dataset: OTXDataset, tile_config: TileConfig) -> None: - dm_dataset = dataset.dm_subset - dm_dataset = dm_dataset.transform( - OTXTileTransform, - tile_size=tile_config.tile_size, - overlap=(tile_config.overlap, tile_config.overlap), - threshold_drop_ann=0.5, - with_full_img=tile_config.with_full_img, - ) - dm_dataset = dm_dataset.filter("/item/annotation", filter_annotations=True, remove_empty=True) - # Include original dataset for training - dm_dataset.update(dataset.dm_subset) - dataset.dm_subset = dm_dataset - super().__init__(dataset, tile_config) - - -class OTXTileDetTestDataset(OTXTileDataset): - """OTX tile detection test dataset. - - OTXTileDetTestDataset wraps a list of tiles (DetDataEntity) into a single TileDetDataEntity for testing/predicting. - - Args: - dataset (OTXDetDataset): OTX detection dataset. - tile_config (TilerConfig): Tile configuration. - """ - - def __init__(self, dataset: OTXDetectionDataset, tile_config: TileConfig) -> None: - super().__init__(dataset, tile_config) - - @property - def collate_fn(self) -> Callable: - """Collate function for tile detection test dataset.""" - return TileBatchDetDataEntity.collate_fn - - def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[override] - """Get item implementation. - - Transform a single dataset item to multiple tiles using Datumaro tiling plugin, and - wrap tiles into a single TileDetDataEntity. - - Args: - index (int): Index of the dataset item. - - Returns: - TileDetDataEntity: tile detection data entity that wraps a list of detection data entities. - - Note: - Ignoring [override] check is necessary here since OTXDataset._get_item_impl exclusively permits - the return of OTXDataItem. Nevertheless, in instances involving tiling, it becomes - imperative to encapsulate tiles within a unified entity, namely TileDetDataEntity. - """ - item = self.dm_subset[index] - img = item.media_as(Image) - img_data, img_shape, _ = self._get_img_data_and_shape(img) - - bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)] - - bboxes = ( - np.stack([ann.points for ann in bbox_anns], axis=0).astype(np.float32) - if len(bbox_anns) > 0 - else np.zeros((0, 4), dtype=np.float32) - ) - labels = torch.as_tensor([ann.label for ann in bbox_anns]) - - tile_entities, tile_attrs = self.get_tiles(img_data, item, index) - - return TileDetDataEntity( - num_tiles=len(tile_entities), - entity_list=tile_entities, - tile_attr_list=tile_attrs, - ori_img_info=ImageInfo( - img_idx=index, - img_shape=img_shape, - ori_shape=img_shape, - ), - ori_bboxes=tv_tensors.BoundingBoxes( - bboxes, - format=tv_tensors.BoundingBoxFormat.XYXY, - canvas_size=img_shape, - ), - ori_labels=labels, - ) - - def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> OTXDataItem: # type: ignore[override] - """Convert a tile datumaro dataset item to TorchDataItem.""" - x1, y1, w, h = dataset_item.attributes["roi"] - tile_img = image[y1 : y1 + h, x1 : x1 + w] - tile_shape = tile_img.shape[:2] - img_info = ImageInfo( - img_idx=parent_idx, - img_shape=tile_shape, - ori_shape=tile_shape, - ) - return OTXDataItem( - image=tile_img, - img_info=img_info, - ) - - -class OTXTileInstSegTestDataset(OTXTileDataset): - """OTX tile inst-seg test dataset. - - OTXTileDetTestDataset wraps a list of tiles (TorchDataItem) into a single TileDetDataEntity - for testing/predicting. - - Args: - dataset (OTXInstanceSegDataset): OTX inst-seg dataset. - tile_config (TilerConfig): Tile configuration. - """ - - def __init__(self, dataset: OTXInstanceSegDataset, tile_config: TileConfig) -> None: - super().__init__(dataset, tile_config) - - @property - def collate_fn(self) -> Callable: - """Collate function for tile inst-seg test dataset.""" - return TileBatchInstSegDataEntity.collate_fn - - def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[override] - """Get item implementation. - - Transform a single dataset item to multiple tiles using Datumaro tiling plugin, and - wrap tiles into a single TileInstSegDataEntity. - - Args: - index (int): Index of the dataset item. - - Returns: - TileInstSegDataEntity: tile inst-seg data entity that wraps a list of inst-seg data entities. - - Note: - Ignoring [override] check is necessary here since OTXDataset._get_item_impl exclusively permits - the return of OTXDataItem. Nevertheless, in instances involving tiling, it becomes - imperative to encapsulate tiles within a unified entity, namely TileInstSegDataEntity. - """ - item = self.dm_subset[index] - img = item.media_as(Image) - img_data, img_shape, _ = self._get_img_data_and_shape(img) - - anno_collection: dict[str, list] = defaultdict(list) - for anno in item.annotations: - anno_collection[anno.__class__.__name__].append(anno) - - gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], [] - - # TODO(Eugene): https://jira.devtools.intel.com/browse/CVS-159363 - # Temporary solution to handle multiple annotation types. - # Ideally, we should pre-filter annotations during initialization of the dataset. - - if Polygon.__name__ in anno_collection: # Polygon for InstSeg has higher priority - for poly in anno_collection[Polygon.__name__]: - bbox = Bbox(*poly.get_bbox()).points - gt_bboxes.append(bbox) - gt_labels.append(poly.label) - - if self._dataset.include_polygons: - gt_polygons.append(poly) - else: - gt_masks.append(polygon_to_bitmap([poly], *img_shape)[0]) - elif Bbox.__name__ in anno_collection: - boxes = anno_collection[Bbox.__name__] - gt_bboxes = [ann.points for ann in boxes] - gt_labels = [ann.label for ann in boxes] - for box in boxes: - poly = Polygon(box.as_polygon()) - if self._dataset.include_polygons: - gt_polygons.append(poly) - else: - gt_masks.append(polygon_to_bitmap([poly], *img_shape)[0]) - elif Ellipse.__name__ in anno_collection: - for ellipse in anno_collection[Ellipse.__name__]: - bbox = Bbox(*ellipse.get_bbox()).points - gt_bboxes.append(bbox) - gt_labels.append(ellipse.label) - poly = Polygon(ellipse.as_polygon(num_points=10)) - if self._dataset.include_polygons: - gt_polygons.append(poly) - else: - gt_masks.append(polygon_to_bitmap([poly], *img_shape)[0]) - else: - warnings.warn(f"No valid annotations found for image {item.id}!", stacklevel=2) - - bboxes = np.stack(gt_bboxes, dtype=np.float32) if gt_bboxes else np.empty((0, 4), dtype=np.float32) - masks = np.stack(gt_masks, axis=0) if gt_masks else np.empty((0, *img_shape), dtype=bool) - labels = np.array(gt_labels, dtype=np.int64) - - tile_entities, tile_attrs = self.get_tiles(img_data, item, index) - - return TileInstSegDataEntity( - num_tiles=len(tile_entities), - entity_list=tile_entities, - tile_attr_list=tile_attrs, - ori_img_info=ImageInfo( - img_idx=index, - img_shape=img_shape, - ori_shape=img_shape, - ), - ori_bboxes=tv_tensors.BoundingBoxes( - bboxes, - format=tv_tensors.BoundingBoxFormat.XYXY, - canvas_size=img_shape, - ), - ori_labels=torch.as_tensor(labels), - ori_masks=tv_tensors.Mask(masks, dtype=torch.uint8), - ori_polygons=gt_polygons, - ) - - def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> OTXDataItem: # type: ignore[override] - """Convert a tile dataset item to TorchDataItem.""" - x1, y1, w, h = dataset_item.attributes["roi"] - tile_img = image[y1 : y1 + h, x1 : x1 + w] - tile_shape = tile_img.shape[:2] - img_info = ImageInfo( - img_idx=parent_idx, - img_shape=tile_shape, - ori_shape=tile_shape, - ) - return OTXDataItem( - image=tile_img, - img_info=img_info, - masks=tv_tensors.Mask(np.zeros((0, *tile_shape), dtype=bool)), - ) - - -class OTXTileSemanticSegTestDataset(OTXTileDataset): - """OTX tile semantic-seg test dataset. - - OTXTileSemanticSegTestDataset wraps a list of tiles (SegDataEntity) into a single TileSegDataEntity - for testing/predicting. - - Args: - dataset (OTXSegmentationDataset): OTX semantic-seg dataset. - tile_config (TilerConfig): Tile configuration. - """ - - def __init__(self, dataset: OTXSegmentationDataset, tile_config: TileConfig) -> None: - super().__init__(dataset, tile_config) - self.ignore_index = self._dataset.ignore_index - - @property - def collate_fn(self) -> Callable: - """Collate function for tile detection test dataset.""" - return TileBatchSegDataEntity.collate_fn - - def _get_item_impl(self, index: int) -> TileSegDataEntity: # type: ignore[override] - """Get item implementation. - - Transform a single dataset item to multiple tiles using Datumaro tiling plugin, and - wrap tiles into a single TileSegDataEntity. - - Args: - index (int): Index of the dataset item. - - Returns: - TileSegDataEntity: tile semantic-seg data entity that wraps a list of semantic-seg data entities. - """ - item = self.dm_subset[index] - img = item.media_as(Image) - img_data, img_shape, _ = self._get_img_data_and_shape(img) - - extracted_mask = _extract_class_mask(item=item, img_shape=img_shape, ignore_index=self.ignore_index) - masks = tv_tensors.Mask(extracted_mask[None]) - tile_entities, tile_attrs = self.get_tiles(img_data, item, index) - - return TileSegDataEntity( - num_tiles=len(tile_entities), - entity_list=tile_entities, - tile_attr_list=tile_attrs, - ori_img_info=ImageInfo( - img_idx=index, - img_shape=img_shape, - ori_shape=img_shape, - ), - ori_masks=masks, - ) - - def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> OTXDataItem: # type: ignore[override] - """Convert a tile datumaro dataset item to SegDataEntity.""" - x1, y1, w, h = dataset_item.attributes["roi"] - tile_img = image[y1 : y1 + h, x1 : x1 + w] - tile_shape = tile_img.shape[:2] - img_info = ImageInfo( - img_idx=parent_idx, - img_shape=tile_shape, - ori_shape=tile_shape, - ) - return OTXDataItem( - image=tile_img, - img_info=img_info, - masks=tv_tensors.Mask(np.zeros((0, *tile_shape), dtype=bool)), - ) diff --git a/library/src/otx/data/dataset/tile_new.py b/library/src/otx/data/dataset/tile_new.py new file mode 100644 index 0000000000..466a3d8bb2 --- /dev/null +++ b/library/src/otx/data/dataset/tile_new.py @@ -0,0 +1,320 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""OTX tile dataset.""" + +from __future__ import annotations + +import logging as log +from typing import TYPE_CHECKING, Callable + +from datumaro.experimental.fields import Subset +from datumaro.experimental.filtering.filter_registry import create_filtering_transform +from datumaro.experimental.tiling.tiler_registry import TilingConfig, create_tiling_transform + +from otx.data.entity.sample import OTXSample +from otx.data.entity.tile import ( + TileBatchDetDataEntity, + TileBatchInstSegDataEntity, + TileBatchSegDataEntity, + TileDetDataEntity, + TileInstSegDataEntity, + TileSegDataEntity, +) +from otx.types.task import OTXTaskType + +from .base_new import OTXDataset + +if TYPE_CHECKING: + from otx.config.data import TileConfig + from otx.data.dataset.detection import OTXDetectionDataset + from otx.data.dataset.instance_segmentation import OTXInstanceSegDataset + from otx.data.dataset.segmentation import OTXSegmentationDataset + +# ruff: noqa: SLF001 +# NOTE: Disable private-member-access (SLF001). +# This is a workaround so we could apply the same transforms to tiles as the original dataset. + + +class OTXTileDatasetFactory: + """OTX tile dataset factory.""" + + @classmethod + def create( + cls, + task: OTXTaskType, + dataset: OTXDataset, + tile_config: TileConfig, + ) -> OTXDataset: + """Create a tile dataset based on the task type and subset type. + + NOte: All task utilize the same OTXTileTrainDataset for training. + In testing, we use different tile dataset for different task + type due to different annotation format and data entity. + + Args: + task (OTXTaskType): OTX task type. + dataset (OTXDataset): OTX dataset. + tile_config (TilerConfig): Tile configuration. + + Returns: + OTXTileDataset: Tile dataset. + """ + subset = dataset.dm_subset[0].subset + if subset == Subset.TRAINING: + dm_dataset = dataset.dm_subset + dm_dataset = dm_dataset.transform( + create_tiling_transform( + TilingConfig( + tile_height=tile_config.tile_size[0], + tile_width=tile_config.tile_size[1], + overlap_x=tile_config.overlap, + overlap_y=tile_config.overlap, + ), + threshold_drop_ann=0.5, + ), + dtype=dm_dataset.dtype, + ) + dm_dataset = dm_dataset.transform(create_filtering_transform(), dtype=dm_dataset.dtype) + dataset.dm_subset = dm_dataset + return dataset + + if task == OTXTaskType.DETECTION: + return OTXTileDetTestDataset(dataset, tile_config, subset) + if task in [OTXTaskType.ROTATED_DETECTION, OTXTaskType.INSTANCE_SEGMENTATION]: + return OTXTileInstSegTestDataset(dataset, tile_config, subset) + if task == OTXTaskType.SEMANTIC_SEGMENTATION: + return OTXTileSemanticSegTestDataset(dataset, tile_config, subset) + + msg = f"Unsupported task type: {task} for tiling" + raise NotImplementedError(msg) + + +class OTXTileDataset(OTXDataset): + """OTX tile dataset base class. + + Args: + dataset (OTXDataset): OTX dataset. + tile_config (TilerConfig): Tile configuration. + """ + + def __init__(self, dataset: OTXDataset, tile_config: TileConfig, subset: Subset) -> None: + super().__init__( + dataset.dm_subset, + dataset.transforms, + dataset.max_refetch, + dataset.image_color_channel, + dataset.stack_images, + dataset.to_tv_image, + ) + self.tile_config = tile_config + self._dataset = dataset + self._subset = subset + + # LabelInfo differs from SegLabelInfo, thus we need to update it for semantic segmentation. + if self.label_info != dataset.label_info: + msg = ( + "Replace the label info to match the dataset's label info", + "as there is a mismatch between the dataset and the tile dataset.", + ) + log.warning(msg) + self.label_info = dataset.label_info + + def __len__(self) -> int: + return len(self._dataset) + + @property + def collate_fn(self) -> Callable: + """Collate function from the original dataset.""" + return self._dataset.collate_fn + + def _get_item_impl(self, index: int) -> OTXSample | None: + """Get item implementation from the original dataset.""" + return self._dataset._get_item_impl(index) + + def get_tiles( + self, + parent_idx: int, + ) -> list[OTXSample]: + """Retrieves tiles from the given image and dataset item. + + Args: + image (np.ndarray): The input image. + item (DatasetItem): The dataset item. + parent_idx (int): The parent index. This is to keep track of the original dataset item index for merging. + + Returns: + A tuple containing two lists: + - tile_entities (list[OTXSample]): List of tile entities. + """ + parent_slice_ds = self.dm_subset.slice(parent_idx, 1) + tile_ds = parent_slice_ds.transform( + create_tiling_transform( + TilingConfig( + tile_height=self.tile_config.tile_size[0], + tile_width=self.tile_config.tile_size[1], + overlap_x=self.tile_config.overlap, + overlap_y=self.tile_config.overlap, + ), + threshold_drop_ann=0.5, + ), + dtype=parent_slice_ds.dtype, + ) + + if self._subset == Subset.VALIDATION: + # NOTE: filter validation tiles with annotations only to avoid evaluation on empty tiles. + tile_ds = tile_ds.transform(create_filtering_transform(), dtype=parent_slice_ds.dtype) + + # if tile dataset is empty it means objects are too big to fit in any tile, in this case include full image + if len(tile_ds) == 0: + tile_ds = parent_slice_ds + + tile_entities: list[OTXSample] = [] + for tile in tile_ds: + # apply the same transforms as the original dataset + object.__setattr__(tile.tile, "source_sample_idx", parent_idx) + transformed_tile = self._apply_transforms(tile) + if transformed_tile is None: + msg = "Transformed tile is None" + raise RuntimeError(msg) + tile_entities.append(transformed_tile) + return tile_entities + + +class OTXTileDetTestDataset(OTXTileDataset): + """OTX tile detection test dataset. + + OTXTileDetTestDataset wraps a list of tiles (DetDataEntity) into a single TileDetDataEntity for testing/predicting. + + Args: + dataset (OTXDetDataset): OTX detection dataset. + tile_config (TilerConfig): Tile configuration. + """ + + def __init__(self, dataset: OTXDetectionDataset, tile_config: TileConfig, subset: Subset) -> None: + super().__init__(dataset, tile_config, subset) + + @property + def collate_fn(self) -> Callable: + """Collate function for tile detection test dataset.""" + return TileBatchDetDataEntity.collate_fn + + def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[override] + """Get item implementation. + + Transform a single dataset item to multiple tiles using Datumaro tiling plugin, and + wrap tiles into a single TileDetDataEntity. + + Args: + index (int): Index of the dataset item. + + Returns: + TileDetDataEntity: tile detection data entity that wraps a list of detection data entities. + + Note: + Ignoring [override] check is necessary here since OTXDataset._get_item_impl exclusively permits + the return of OTXSample. Nevertheless, in instances involving tiling, it becomes + imperative to encapsulate tiles within a unified entity, namely TileDetDataEntity. + """ + item = self.dm_subset[index] + tile_entities = self.get_tiles(index) + + return TileDetDataEntity( + num_tiles=len(tile_entities), + entity_list=tile_entities, + ori_img_info=item.img_info, + ori_bboxes=item.bboxes, + ori_labels=item.label, + ) + + +class OTXTileInstSegTestDataset(OTXTileDataset): + """OTX tile inst-seg test dataset. + + OTXTileDetTestDataset wraps a list of tiles (TorchDataItem) into a single TileDetDataEntity + for testing/predicting. + + Args: + dataset (OTXInstanceSegDataset): OTX inst-seg dataset. + tile_config (TilerConfig): Tile configuration. + """ + + def __init__(self, dataset: OTXInstanceSegDataset, tile_config: TileConfig, subset: Subset) -> None: + super().__init__(dataset, tile_config, subset) + + @property + def collate_fn(self) -> Callable: + """Collate function for tile inst-seg test dataset.""" + return TileBatchInstSegDataEntity.collate_fn + + def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[override] + """Get item implementation. + + Transform a single dataset item to multiple tiles using Datumaro tiling plugin, and + wrap tiles into a single TileInstSegDataEntity. + + Args: + index (int): Index of the dataset item. + + Returns: + TileInstSegDataEntity: tile inst-seg data entity that wraps a list of inst-seg data entities. + + Note: + Ignoring [override] check is necessary here since OTXDataset._get_item_impl exclusively permits + the return of OTXSample. Nevertheless, in instances involving tiling, it becomes + imperative to encapsulate tiles within a unified entity, namely TileInstSegDataEntity. + """ + item = self.dm_subset[index] + tile_entities = self.get_tiles(index) + + return TileInstSegDataEntity( + num_tiles=len(tile_entities), + entity_list=tile_entities, + ori_img_info=item.img_info, + ori_bboxes=item.bboxes, + ori_labels=item.label, + ori_masks=item.masks, + ori_polygons=item.polygons, + ) + + +class OTXTileSemanticSegTestDataset(OTXTileDataset): + """OTX tile semantic-seg test dataset. + + OTXTileSemanticSegTestDataset wraps a list of tiles (SegDataEntity) into a single TileSegDataEntity + for testing/predicting. + + Args: + dataset (OTXSegmentationDataset): OTX semantic-seg dataset. + tile_config (TilerConfig): Tile configuration. + """ + + def __init__(self, dataset: OTXSegmentationDataset, tile_config: TileConfig, subset: Subset) -> None: + super().__init__(dataset, tile_config, subset) + + @property + def collate_fn(self) -> Callable: + """Collate function for tile detection test dataset.""" + return TileBatchSegDataEntity.collate_fn + + def _get_item_impl(self, index: int) -> TileSegDataEntity: # type: ignore[override] + """Get item implementation. + + Transform a single dataset item to multiple tiles using Datumaro tiling plugin, and + wrap tiles into a single TileSegDataEntity. + + Args: + index (int): Index of the dataset item. + + Returns: + TileSegDataEntity: tile semantic-seg data entity that wraps a list of semantic-seg data entities. + """ + item = self.dm_subset[index] + tile_entities = self.get_tiles(index) + + return TileSegDataEntity( + num_tiles=len(tile_entities), + entity_list=tile_entities, + ori_img_info=item.img_info, + ori_masks=item.masks, + ) diff --git a/library/src/otx/data/entity/sample.py b/library/src/otx/data/entity/sample.py new file mode 100644 index 0000000000..4c3fdede63 --- /dev/null +++ b/library/src/otx/data/entity/sample.py @@ -0,0 +1,321 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Sample classes for OTX data entities.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +import polars as pl +import torch +import torch.utils._pytree as pytree +from datumaro.experimental.dataset import Sample +from datumaro.experimental.fields import ImageInfo as DmImageInfo +from datumaro.experimental.fields import ( + Subset, + bbox_field, + image_field, + image_info_field, + instance_mask_field, + keypoints_field, + label_field, + mask_field, + polygon_field, + subset_field, +) +from datumaro.experimental.schema import Semantic +from torchvision import tv_tensors + +from otx.data.entity.base import ImageInfo + +if TYPE_CHECKING: + from torchvision.tv_tensors import BoundingBoxes, Mask + + +def register_pytree_node(cls: type[Sample]) -> type[Sample]: + """Decorator to register an OTX data entity with PyTorch's PyTree. + + This decorator should be applied to every OTX data entity, as TorchVision V2 transforms + use the PyTree to flatten and unflatten the data entity during runtime. + + Example: + `MulticlassClsDataEntity` example :: + + @register_pytree_node + @dataclass + class MulticlassClsDataEntity(OTXDataEntity): + ... + """ + + def flatten_fn(obj: object) -> tuple[list[Any], list[str]]: + obj_dict = dict(obj.__dict__) + + missing_keys = set(obj.__class__.__annotations__.keys()) - set(obj_dict.keys()) + for key in missing_keys: + obj_dict[key] = getattr(obj, key) + + return (list(obj_dict.values()), list(obj_dict.keys())) + + def unflatten_fn(values: list[Any], context: list[str]) -> object: + return cls(**dict(zip(context, values))) + + pytree.register_pytree_node( + cls, + flatten_fn=flatten_fn, + unflatten_fn=unflatten_fn, + ) + return cls + + +@register_pytree_node +class OTXSample(Sample): + """Base class for OTX data samples.""" + + image: np.ndarray | torch.Tensor | tv_tensors.Image | Any + subset: Subset = subset_field() + + @property + def masks(self) -> Mask | None: + """Get masks for the sample.""" + return None + + @property + def bboxes(self) -> BoundingBoxes | None: + """Get bounding boxes for the sample.""" + return None + + @property + def keypoints(self) -> torch.Tensor | None: + """Get keypoints for the sample.""" + return None + + @property + def polygons(self) -> np.ndarray | None: + """Get polygons for the sample.""" + return None + + @property + def label(self) -> torch.Tensor | None: + """Optional label property that returns None by default.""" + return None + + @property + def img_info(self) -> ImageInfo: + """Get image information for the sample.""" + if self._img_info is None: + err_msg = "img_info is not set." + raise ValueError(err_msg) + return self._img_info + + @img_info.setter + def img_info(self, value: ImageInfo) -> None: + self._img_info = value + + +@register_pytree_node +class ClassificationSample(OTXSample): + """ClassificationSample is a base class for OTX classification items.""" + + subset: Subset = subset_field() + + image: np.ndarray | tv_tensors.Image = image_field(dtype=pl.UInt8) + label: torch.Tensor = label_field(pl.Int32()) + dm_image_info: DmImageInfo = image_info_field() + + def __post_init__(self) -> None: + shape = (self.dm_image_info.height, self.dm_image_info.width) + + self.img_info = ImageInfo( + img_idx=0, + img_shape=shape, + ori_shape=shape, + ) + + +@register_pytree_node +class ClassificationMultiLabelSample(OTXSample): + """ClassificationMultiLabelSample is a base class for OTX multi label classification items.""" + + image: np.ndarray | tv_tensors.Image = image_field(dtype=pl.UInt8) + label: np.ndarray | torch.Tensor = label_field(pl.Int32(), multi_label=True) + dm_image_info: DmImageInfo = image_info_field() + + def __post_init__(self) -> None: + shape = (self.dm_image_info.height, self.dm_image_info.width) + + self.img_info = ImageInfo( + img_idx=0, + img_shape=shape, + ori_shape=shape, + ) + + +@register_pytree_node +class ClassificationHierarchicalSample(OTXSample): + """ClassificationHierarchicalSample is a base class for OTX hierarchical classification items.""" + + image: np.ndarray | tv_tensors.Image = image_field(dtype=pl.UInt8) + label: np.ndarray | torch.Tensor = label_field(pl.Int32(), is_list=True) + dm_image_info: DmImageInfo = image_info_field() + + def __post_init__(self) -> None: + shape = (self.dm_image_info.height, self.dm_image_info.width) + + self.img_info = ImageInfo( + img_idx=0, + img_shape=shape, + ori_shape=shape, + ) + + +@register_pytree_node +class DetectionSample(OTXSample): + """DetectionSample is a base class for OTX detection items.""" + + subset: Subset = subset_field() + + image: tv_tensors.Image = image_field(dtype=pl.UInt8, channels_first=True) + label: torch.Tensor = label_field(pl.Int32(), is_list=True) + bboxes: np.ndarray | tv_tensors.BoundingBoxes = bbox_field(dtype=pl.Float32) + dm_image_info: DmImageInfo = image_info_field() + + def __post_init__(self) -> None: + shape = (self.dm_image_info.height, self.dm_image_info.width) + + # Convert bboxes to tv_tensors format + if isinstance(self.bboxes, np.ndarray): + self.bboxes = tv_tensors.BoundingBoxes( + self.bboxes, + format=tv_tensors.BoundingBoxFormat.XYXY, + canvas_size=shape, + dtype=torch.float32, + ) + + self.img_info = ImageInfo( + img_idx=0, + img_shape=shape, + ori_shape=shape, + ) + + +@register_pytree_node +class SegmentationSample(OTXSample): + """OTXDataItemSample is a base class for OTX data items.""" + + subset: Subset = subset_field() + image: tv_tensors.Image = image_field(dtype=pl.UInt8, channels_first=True) + masks: tv_tensors.Mask = mask_field(dtype=pl.UInt8, channels_first=True, has_channels_dim=True) + dm_image_info: DmImageInfo = image_info_field() + + def __post_init__(self) -> None: + shape = (self.dm_image_info.height, self.dm_image_info.width) + self.img_info = ImageInfo( + img_idx=0, + img_shape=shape, + ori_shape=shape, + ) + + +@register_pytree_node +class AnomalySample(OTXSample): + """ClassificationSample is a base class for OTX classification items.""" + + subset: Subset = subset_field() + image: np.ndarray | tv_tensors.Image = image_field(dtype=pl.UInt8) + label: torch.Tensor = label_field(pl.Int32()) + dm_image_info: DmImageInfo = image_info_field() + masks: np.ndarray | tv_tensors.Image | None = mask_field( + dtype=pl.UInt8, semantic=Semantic.Anomaly, channels_first=True, has_channels_dim=True + ) + + def __post_init__(self) -> None: + shape = (self.dm_image_info.height, self.dm_image_info.width) + + self.img_info = ImageInfo( + img_idx=0, + img_shape=shape, + ori_shape=shape, + ) + + +@register_pytree_node +class InstanceSegmentationSample(OTXSample): + """OTXSample for instance segmentation tasks.""" + + subset: Subset = subset_field() + image: tv_tensors.Image = image_field(dtype=pl.UInt8, channels_first=True) + bboxes: np.ndarray | tv_tensors.BoundingBoxes = bbox_field(dtype=pl.Float32) + label: torch.Tensor = label_field(is_list=True) + polygons: np.ndarray = polygon_field(dtype=pl.Float32) # Ragged array of (Npoly, 2) arrays + dm_image_info: DmImageInfo = image_info_field() + + def __post_init__(self) -> None: + shape = (self.dm_image_info.height, self.dm_image_info.width) + + # Convert bboxes to tv_tensors format + if isinstance(self.bboxes, np.ndarray): + self.bboxes = tv_tensors.BoundingBoxes( + self.bboxes, + format=tv_tensors.BoundingBoxFormat.XYXY, + canvas_size=shape, + dtype=torch.float32, + ) + + self.img_info = ImageInfo( + img_idx=0, + img_shape=shape, + ori_shape=shape, + ) + + +@register_pytree_node +class InstanceSegmentationSampleWithMask(OTXSample): + """OTXSample for instance segmentation tasks.""" + + subset: Subset = subset_field() + image: tv_tensors.Image = image_field(dtype=pl.UInt8, channels_first=True) + bboxes: np.ndarray | tv_tensors.BoundingBoxes = bbox_field(dtype=pl.Float32) + masks: tv_tensors.Mask = instance_mask_field(dtype=pl.UInt8) + label: torch.Tensor = label_field(is_list=True) + polygons: np.ndarray = polygon_field(dtype=pl.Float32) # Ragged array of (Npoly, 2) arrays + dm_image_info: DmImageInfo = image_info_field() + + def __post_init__(self) -> None: + shape = (self.dm_image_info.height, self.dm_image_info.width) + + # Convert bboxes to tv_tensors format + if isinstance(self.bboxes, np.ndarray): + self.bboxes = tv_tensors.BoundingBoxes( + self.bboxes, + format=tv_tensors.BoundingBoxFormat.XYXY, + canvas_size=shape, + dtype=torch.float32, + ) + + self.img_info = ImageInfo( + img_idx=0, + img_shape=shape, + ori_shape=shape, + ) + + +@register_pytree_node +class KeypointSample(OTXSample): + """KeypointSample is a base class for OTX keypoint detection items.""" + + subset: Subset = subset_field() + image: np.ndarray | tv_tensors.Image = image_field(dtype=pl.UInt8) + label: torch.Tensor = label_field(pl.Int32(), is_list=True) + keypoints: torch.Tensor = keypoints_field() + dm_image_info: DmImageInfo = image_info_field() + + def __post_init__(self) -> None: + shape = (self.dm_image_info.height, self.dm_image_info.width) + + self.img_info = ImageInfo( + img_idx=0, + img_shape=shape, + ori_shape=shape, + ) diff --git a/library/src/otx/data/entity/tile.py b/library/src/otx/data/entity/tile.py index cb8ce3f9dc..af29e7644b 100644 --- a/library/src/otx/data/entity/tile.py +++ b/library/src/otx/data/entity/tile.py @@ -11,7 +11,8 @@ import torch from torchvision import tv_tensors -from otx.data.entity.torch import OTXDataBatch, OTXDataItem +from otx.data.entity.sample import OTXSample +from otx.data.entity.torch import OTXDataBatch from otx.data.entity.utils import stack_batch from otx.types.task import OTXTaskType @@ -19,6 +20,7 @@ if TYPE_CHECKING: from datumaro import Polygon + from datumaro.experimental.fields import TileInfo from torch import LongTensor @@ -29,13 +31,11 @@ class TileDataEntity: Attributes: num_tiles (int): The number of tiles. entity_list (Sequence[OTXDataEntity]): A list of OTXDataEntity. - tile_attr_list (list[dict[str, int | str]]): The tile attributes including tile index and tile RoI information. ori_img_info (ImageInfo): The image information about the original image. """ num_tiles: int - entity_list: Sequence[OTXDataItem] - tile_attr_list: list[dict[str, int | str]] + entity_list: Sequence[OTXSample] ori_img_info: ImageInfo @property @@ -73,18 +73,16 @@ class OTXTileBatchDataEntity: batch_size (int): The size of the batch. batch_tiles (list[list[tv_tensors.Image]]): The batch of tile images. batch_tile_img_infos (list[list[ImageInfo]]): The batch of tiles image information. - batch_tile_attr_list (list[list[dict[str, int | str]]]): - The batch of tile attributes including tile index and tile RoI information. imgs_info (list[ImageInfo]): The image information about the original image. """ batch_size: int batch_tiles: list[list[tv_tensors.Image]] batch_tile_img_infos: list[list[ImageInfo]] - batch_tile_attr_list: list[TileAttrDictList] + batch_tile_tile_infos: list[list[TileInfo]] imgs_info: list[ImageInfo] - def unbind(self) -> list[tuple[TileAttrDictList, OTXDataBatch]]: + def unbind(self) -> list[tuple[list[TileInfo], OTXDataBatch]]: """Unbind batch data entity.""" raise NotImplementedError @@ -101,30 +99,29 @@ class TileBatchDetDataEntity(OTXTileBatchDataEntity): bboxes: list[tv_tensors.BoundingBoxes] labels: list[LongTensor] - def unbind(self) -> list[tuple[TileAttrDictList, OTXDataBatch]]: + def unbind(self) -> list[tuple[list[TileInfo], OTXDataBatch]]: """Unbind batch data entity for detection task.""" tiles = [tile for tiles in self.batch_tiles for tile in tiles] - tile_infos = [tile_info for tile_infos in self.batch_tile_img_infos for tile_info in tile_infos] - tile_attr_list = [tile_attr for tile_attrs in self.batch_tile_attr_list for tile_attr in tile_attrs] - - batch_tile_attr_list = [ - tile_attr_list[i : i + self.batch_size] for i in range(0, len(tile_attr_list), self.batch_size) - ] + tile_img_infos = [tile_info for tile_infos in self.batch_tile_img_infos for tile_info in tile_infos] + tile_tile_infos = [tile_info for tile_infos in self.batch_tile_tile_infos for tile_info in tile_infos] batch_data_entities = [] for i in range(0, len(tiles), self.batch_size): stacked_images, updated_img_info = stack_batch( tiles[i : i + self.batch_size], - tile_infos[i : i + self.batch_size], + tile_img_infos[i : i + self.batch_size], ) batch_data_entities.append( - OTXDataBatch( - batch_size=self.batch_size, - images=stacked_images, - imgs_info=updated_img_info, - ), + ( + tile_tile_infos[i : i + self.batch_size], + OTXDataBatch( + batch_size=self.batch_size, + images=stacked_images, + imgs_info=updated_img_info, + ), + ) ) - return list(zip(batch_tile_attr_list, batch_data_entities, strict=True)) + return batch_data_entities @classmethod def collate_fn(cls, batch_entities: list[TileDetDataEntity]) -> TileBatchDetDataEntity: @@ -135,8 +132,8 @@ def collate_fn(cls, batch_entities: list[TileDetDataEntity]) -> TileBatchDetData for tile_entity in batch_entities: for entity in tile_entity.entity_list: - if not isinstance(entity, OTXDataItem): - msg = "All entities should be OTXDataItem before collate_fn()" + if not isinstance(entity, OTXSample): + msg = "All entities should be OTXSample before collate_fn()" raise TypeError(msg) if entity.img_info is None: msg = "All entities should have img_info, but found None" @@ -146,10 +143,11 @@ def collate_fn(cls, batch_entities: list[TileDetDataEntity]) -> TileBatchDetData batch_size=batch_size, batch_tiles=[[entity.image for entity in tile_entity.entity_list] for tile_entity in batch_entities], batch_tile_img_infos=[ - [entity.img_info for entity in tile_entity.entity_list if isinstance(entity.img_info, ImageInfo)] - for tile_entity in batch_entities + [entity.img_info for entity in tile_entity.entity_list] for tile_entity in batch_entities + ], + batch_tile_tile_infos=[ + [entity.tile for entity in tile_entity.entity_list] for tile_entity in batch_entities ], - batch_tile_attr_list=[tile_entity.tile_attr_list for tile_entity in batch_entities], imgs_info=[tile_entity.ori_img_info for tile_entity in batch_entities], bboxes=[tile_entity.ori_bboxes for tile_entity in batch_entities], labels=[tile_entity.ori_labels for tile_entity in batch_entities], @@ -197,21 +195,21 @@ class TileBatchInstSegDataEntity(OTXTileBatchDataEntity): def unbind(self) -> list[tuple[TileAttrDictList, OTXDataBatch]]: """Unbind batch data entity for instance segmentation task.""" tiles = [tile for tiles in self.batch_tiles for tile in tiles] - tile_infos = [tile_info for tile_infos in self.batch_tile_img_infos for tile_info in tile_infos] - tile_attr_list = [tile_attr for tile_attrs in self.batch_tile_attr_list for tile_attr in tile_attrs] + tile_img_infos = [tile_info for tile_infos in self.batch_tile_img_infos for tile_info in tile_infos] + tile_tile_infos = [tile_info for tile_infos in self.batch_tile_tile_infos for tile_info in tile_infos] - batch_tile_attr_list = [ - tile_attr_list[i : i + self.batch_size] for i in range(0, len(tile_attr_list), self.batch_size) - ] batch_data_entities = [ - OTXDataBatch( - batch_size=self.batch_size, - images=tiles[i : i + self.batch_size], - imgs_info=tile_infos[i : i + self.batch_size], + ( + tile_tile_infos[i : i + self.batch_size], + OTXDataBatch( + batch_size=self.batch_size, + images=tiles[i : i + self.batch_size], + imgs_info=tile_img_infos[i : i + self.batch_size], + ), ) for i in range(0, len(tiles), self.batch_size) ] - return list(zip(batch_tile_attr_list, batch_data_entities, strict=True)) + return list(batch_data_entities) @classmethod def collate_fn(cls, batch_entities: list[TileInstSegDataEntity]) -> TileBatchInstSegDataEntity: @@ -222,8 +220,8 @@ def collate_fn(cls, batch_entities: list[TileInstSegDataEntity]) -> TileBatchIns for tile_entity in batch_entities: for entity in tile_entity.entity_list: - if not isinstance(entity, OTXDataItem): - msg = "All entities should be OTXDataItem before collate_fn()" + if not isinstance(entity, OTXSample): + msg = "All entities should be OTXSample before collate_fn()" raise TypeError(msg) if entity.img_info is None: msg = "All entities should have img_info, but found None" @@ -236,7 +234,9 @@ def collate_fn(cls, batch_entities: list[TileInstSegDataEntity]) -> TileBatchIns [entity.img_info for entity in tile_entity.entity_list if isinstance(entity.img_info, ImageInfo)] for tile_entity in batch_entities ], - batch_tile_attr_list=[tile_entity.tile_attr_list for tile_entity in batch_entities], + batch_tile_tile_infos=[ + [entity.tile for entity in tile_entity.entity_list] for tile_entity in batch_entities + ], imgs_info=[tile_entity.ori_img_info for tile_entity in batch_entities], bboxes=[tile_entity.ori_bboxes for tile_entity in batch_entities], labels=[tile_entity.ori_labels for tile_entity in batch_entities], @@ -274,22 +274,22 @@ class TileBatchSegDataEntity(OTXTileBatchDataEntity): def unbind(self) -> list[tuple[list[dict[str, int | str]], OTXDataBatch]]: """Unbind batch data entity for semantic segmentation task.""" tiles = [tile for tiles in self.batch_tiles for tile in tiles] - tile_infos = [tile_info for tile_infos in self.batch_tile_img_infos for tile_info in tile_infos] - tile_attr_list = [tile_attr for tile_attrs in self.batch_tile_attr_list for tile_attr in tile_attrs] + tile_img_infos = [tile_info for tile_infos in self.batch_tile_img_infos for tile_info in tile_infos] + tile_tile_infos = [tile_info for tile_infos in self.batch_tile_tile_infos for tile_info in tile_infos] - batch_tile_attr_list = [ - tile_attr_list[i : i + self.batch_size] for i in range(0, len(tile_attr_list), self.batch_size) - ] batch_data_entities = [ - OTXDataBatch( - batch_size=self.batch_size, - images=tv_tensors.wrap(torch.stack(tiles[i : i + self.batch_size]), like=tiles[0]), - imgs_info=tile_infos[i : i + self.batch_size], - masks=[torch.empty((1, 1, 1)) for _ in range(self.batch_size)], + ( + tile_tile_infos[i : i + self.batch_size], + OTXDataBatch( + batch_size=self.batch_size, + images=tv_tensors.wrap(torch.stack(tiles[i : i + self.batch_size]), like=tiles[0]), + imgs_info=tile_img_infos[i : i + self.batch_size], + masks=[torch.empty((1, 1, 1)) for _ in range(self.batch_size)], + ), ) for i in range(0, len(tiles), self.batch_size) ] - return list(zip(batch_tile_attr_list, batch_data_entities)) + return list(batch_data_entities) @classmethod def collate_fn(cls, batch_entities: list[TileSegDataEntity]) -> TileBatchSegDataEntity: @@ -300,8 +300,8 @@ def collate_fn(cls, batch_entities: list[TileSegDataEntity]) -> TileBatchSegData for tile_entity in batch_entities: for entity in tile_entity.entity_list: - if not isinstance(entity, OTXDataItem): - msg = "All entities should be OTXDataItem before collate_fn()" + if not isinstance(entity, OTXSample): + msg = "All entities should be OTXSample before collate_fn()" raise TypeError(msg) if entity.img_info is None: msg = "All entities should have img_info, but found None" @@ -311,10 +311,11 @@ def collate_fn(cls, batch_entities: list[TileSegDataEntity]) -> TileBatchSegData batch_size=batch_size, batch_tiles=[[entity.image for entity in tile_entity.entity_list] for tile_entity in batch_entities], batch_tile_img_infos=[ - [entity.img_info for entity in tile_entity.entity_list if isinstance(entity.img_info, ImageInfo)] - for tile_entity in batch_entities + [entity.img_info for entity in tile_entity.entity_list] for tile_entity in batch_entities + ], + batch_tile_tile_infos=[ + [entity.tile for entity in tile_entity.entity_list] for tile_entity in batch_entities ], - batch_tile_attr_list=[tile_entity.tile_attr_list for tile_entity in batch_entities], imgs_info=[tile_entity.ori_img_info for tile_entity in batch_entities], masks=[tile_entity.ori_masks for tile_entity in batch_entities], ) diff --git a/library/src/otx/data/entity/torch/torch.py b/library/src/otx/data/entity/torch/torch.py index ff85c15c51..6613507c29 100644 --- a/library/src/otx/data/entity/torch/torch.py +++ b/library/src/otx/data/entity/torch/torch.py @@ -22,7 +22,6 @@ if TYPE_CHECKING: import numpy as np - from datumaro import Polygon from torchvision.tv_tensors import BoundingBoxes, Mask from otx.data.entity.base import ImageInfo @@ -41,7 +40,7 @@ class OTXDataItem(ValidateItemMixin, Mapping): masks (Mask | None): The masks, optional. bboxes (BoundingBoxes | None): The bounding boxes, optional. keypoints (torch.Tensor | None): The keypoints, optional. - polygons (list[Polygon] | None): The polygons, optional. + polygons (np.ndarray | None): The polygons, optional. img_info (ImageInfo | None): Additional image information, optional. """ @@ -50,7 +49,7 @@ class OTXDataItem(ValidateItemMixin, Mapping): masks: Mask | None = None bboxes: BoundingBoxes | None = None keypoints: torch.Tensor | None = None - polygons: list[Polygon] | None = None + polygons: np.ndarray | None = None img_info: ImageInfo | None = None # TODO(ashwinvaidya17): revisit and try to remove this @staticmethod @@ -125,7 +124,7 @@ class OTXDataBatch(ValidateBatchMixin): masks: list[Mask] | None = None bboxes: list[BoundingBoxes] | None = None keypoints: list[torch.Tensor] | None = None - polygons: list[list[Polygon]] | None = None + polygons: list[np.ndarray] | None = None imgs_info: Sequence[ImageInfo | None] | None = None # TODO(ashwinvaidya17): revisit def pin_memory(self) -> OTXDataBatch: diff --git a/library/src/otx/data/entity/torch/validations.py b/library/src/otx/data/entity/torch/validations.py index 9b9d89cb50..75c30eae62 100644 --- a/library/src/otx/data/entity/torch/validations.py +++ b/library/src/otx/data/entity/torch/validations.py @@ -9,7 +9,6 @@ import numpy as np import torch -from datumaro import Polygon from torchvision.tv_tensors import BoundingBoxes, Mask from otx.data.entity.base import ImageInfo @@ -154,15 +153,15 @@ def _keypoints_validator(keypoints: torch.Tensor) -> torch.Tensor: return keypoints @staticmethod - def _polygons_validator(polygons: list[Polygon]) -> list[Polygon]: + def _polygons_validator(polygons: np.ndarray) -> np.ndarray: """Validate the polygons.""" if len(polygons) == 0: return polygons - if not isinstance(polygons, list): - msg = f"Polygons must be a list of datumaro.Polygon. Got {type(polygons)}" + if not isinstance(polygons, np.ndarray): + msg = f"Polygons must be a np.ndarray of np.ndarray. Got {type(polygons)}" raise TypeError(msg) - if not isinstance(polygons[0], Polygon): - msg = f"Polygons must be a list of datumaro.Polygon. Got {type(polygons[0])}" + if not isinstance(polygons[0], np.ndarray): + msg = f"Polygons must be a np.ndarray of np.ndarray. Got {type(polygons[0])}" raise TypeError(msg) return polygons @@ -388,20 +387,20 @@ def _batch_size_validator(batch_size: int) -> int: return batch_size @staticmethod - def _polygons_validator(polygons_batch: list[list[Polygon] | None]) -> list[list[Polygon] | None]: + def _polygons_validator(polygons_batch: list[np.ndarray | None]) -> list[np.ndarray | None]: """Validate the polygons batch.""" if all(polygon is None for polygon in polygons_batch): return [] if not isinstance(polygons_batch, list): msg = "Polygons batch must be a list" raise TypeError(msg) - if not isinstance(polygons_batch[0], list): - msg = "Polygons batch must be a list of list" + if not isinstance(polygons_batch[0], np.ndarray): + msg = "Polygons batch must be a list of np.ndarray of np.ndarray" raise TypeError(msg) if len(polygons_batch[0]) == 0: msg = f"Polygons batch must not be empty. Got {polygons_batch}" raise ValueError(msg) - if not isinstance(polygons_batch[0][0], Polygon): - msg = "Polygons batch must be a list of list of datumaro.Polygon" + if not isinstance(polygons_batch[0][0], np.ndarray): + msg = "Polygons batch must be a list of np.ndarray of np.ndarray" raise TypeError(msg) return polygons_batch diff --git a/library/src/otx/data/factory.py b/library/src/otx/data/factory.py index 7f601c4e69..222a4645ef 100644 --- a/library/src/otx/data/factory.py +++ b/library/src/otx/data/factory.py @@ -7,14 +7,21 @@ from typing import TYPE_CHECKING +from datumaro.components.annotation import AnnotationType +from datumaro.experimental.categories import LabelCategories +from datumaro.experimental.legacy import convert_from_legacy + +from otx import LabelInfo, NullLabelInfo from otx.types.image import ImageColorChannel from otx.types.task import OTXTaskType from otx.types.transformer_libs import TransformLibType from .dataset.base import OTXDataset, Transforms +from .dataset.base_new import OTXDataset as OTXDatasetNew if TYPE_CHECKING: - from datumaro import Dataset as DmDataset + from datumaro.components.dataset import Dataset as DmDataset + from datumaro.experimental import Dataset as DatasetNew from otx.config.data import SubsetConfig @@ -41,15 +48,16 @@ class OTXDatasetFactory: @classmethod def create( - cls: type[OTXDatasetFactory], + cls, task: OTXTaskType, - dm_subset: DmDataset, + dm_subset: DmDataset | DatasetNew, cfg_subset: SubsetConfig, data_format: str, image_color_channel: ImageColorChannel = ImageColorChannel.RGB, include_polygons: bool = False, - ignore_index: int = 255, - ) -> OTXDataset: + # TODO(gdlg): Add support for ignore_index again + ignore_index: int = 255, # noqa: ARG003 + ) -> OTXDataset | OTXDatasetNew: """Create OTXDataset.""" transforms = TransformLibFactory.generate(cfg_subset) common_kwargs = { @@ -66,43 +74,70 @@ def create( OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION, ): - from .dataset.anomaly import OTXAnomalyDataset + from .dataset.anomaly_new import OTXAnomalyDataset + + dataset = convert_from_legacy(dm_subset) + common_kwargs["dm_subset"] = dataset return OTXAnomalyDataset(task_type=task, **common_kwargs) if task == OTXTaskType.MULTI_CLASS_CLS: - from .dataset.classification import OTXMulticlassClsDataset + from .dataset.classification_new import OTXMulticlassClsDataset + dataset = convert_from_legacy(dm_subset) + common_kwargs["dm_subset"] = dataset return OTXMulticlassClsDataset(**common_kwargs) if task == OTXTaskType.MULTI_LABEL_CLS: - from .dataset.classification import OTXMultilabelClsDataset + from .dataset.classification_new import OTXMultilabelClsDataset + dataset = convert_from_legacy(dm_subset) + common_kwargs["dm_subset"] = dataset return OTXMultilabelClsDataset(**common_kwargs) if task == OTXTaskType.H_LABEL_CLS: - from .dataset.classification import OTXHlabelClsDataset + from .dataset.classification_new import OTXHlabelClsDataset + dataset = convert_from_legacy(dm_subset) + common_kwargs["dm_subset"] = dataset return OTXHlabelClsDataset(**common_kwargs) if task == OTXTaskType.DETECTION: - from .dataset.detection import OTXDetectionDataset + from .dataset.detection_new import OTXDetectionDataset + dataset = convert_from_legacy(dm_subset) + common_kwargs["dm_subset"] = dataset return OTXDetectionDataset(**common_kwargs) if task in [OTXTaskType.ROTATED_DETECTION, OTXTaskType.INSTANCE_SEGMENTATION]: - from .dataset.instance_segmentation import OTXInstanceSegDataset + from .dataset.instance_segmentation_new import OTXInstanceSegDataset + dataset = convert_from_legacy(dm_subset) + common_kwargs["dm_subset"] = dataset return OTXInstanceSegDataset(include_polygons=include_polygons, **common_kwargs) if task == OTXTaskType.SEMANTIC_SEGMENTATION: - from .dataset.segmentation import OTXSegmentationDataset + from .dataset.segmentation_new import OTXSegmentationDataset - return OTXSegmentationDataset(ignore_index=ignore_index, **common_kwargs) + dataset = convert_from_legacy(dm_subset) + common_kwargs["dm_subset"] = dataset + return OTXSegmentationDataset(**common_kwargs) if task == OTXTaskType.KEYPOINT_DETECTION: - from .dataset.keypoint_detection import OTXKeypointDetectionDataset + from .dataset.keypoint_detection_new import OTXKeypointDetectionDataset + dataset = convert_from_legacy(dm_subset) + common_kwargs["dm_subset"] = dataset return OTXKeypointDetectionDataset(**common_kwargs) raise NotImplementedError(task) + + @staticmethod + def _get_label_categories(dm_subset: DmDataset, data_format: str) -> LabelCategories: + if dm_subset.categories() and data_format == "arrow": + label_info = LabelInfo.from_dm_label_groups_arrow(dm_subset.categories()[AnnotationType.label]) + elif dm_subset.categories(): + label_info = LabelInfo.from_dm_label_groups(dm_subset.categories()[AnnotationType.label]) + else: + label_info = NullLabelInfo() + return LabelCategories(labels=label_info.label_names) diff --git a/library/src/otx/data/module.py b/library/src/otx/data/module.py index 2d19848668..b9503a601f 100644 --- a/library/src/otx/data/module.py +++ b/library/src/otx/data/module.py @@ -15,7 +15,7 @@ from torchvision.transforms.v2 import Normalize from otx.config.data import TileConfig -from otx.data.dataset.tile import OTXTileDatasetFactory +from otx.data.dataset.tile_new import OTXTileDatasetFactory from otx.data.factory import OTXDatasetFactory from otx.data.utils import adapt_input_size_to_dataset, adapt_tile_config, get_adaptive_num_workers, instantiate_sampler from otx.data.utils.pre_filtering import pre_filtering diff --git a/library/src/otx/data/samplers/balanced_sampler.py b/library/src/otx/data/samplers/balanced_sampler.py index 43bc11fae0..1cef96ac69 100644 --- a/library/src/otx/data/samplers/balanced_sampler.py +++ b/library/src/otx/data/samplers/balanced_sampler.py @@ -11,10 +11,9 @@ import torch from torch.utils.data import Sampler -from otx.data.utils import get_idx_list_per_classes - if TYPE_CHECKING: from otx.data.dataset.base import OTXDataset + from otx.data.dataset.base_new import OTXDataset as OTXDatasetNew class BalancedSampler(Sampler): @@ -43,7 +42,7 @@ class BalancedSampler(Sampler): def __init__( self, - dataset: OTXDataset, + dataset: OTXDataset | OTXDatasetNew, efficient_mode: bool = False, num_replicas: int = 1, rank: int = 0, @@ -61,7 +60,8 @@ def __init__( super().__init__(dataset) # img_indices: dict[label: list[idx]] - ann_stats = get_idx_list_per_classes(dataset.dm_subset) + ann_stats = dataset.get_idx_list_per_classes() + self.img_indices = {k: torch.tensor(v, dtype=torch.int64) for k, v in ann_stats.items() if len(v) > 0} self.num_cls = len(self.img_indices.keys()) self.data_length = len(self.dataset) diff --git a/library/src/otx/data/samplers/class_incremental_sampler.py b/library/src/otx/data/samplers/class_incremental_sampler.py index 05e6f65375..68d0f2ee8d 100644 --- a/library/src/otx/data/samplers/class_incremental_sampler.py +++ b/library/src/otx/data/samplers/class_incremental_sampler.py @@ -12,7 +12,6 @@ from torch.utils.data import Sampler from otx.data.dataset.base import OTXDataset -from otx.data.utils import get_idx_list_per_classes class ClassIncrementalSampler(Sampler): @@ -65,7 +64,7 @@ def __init__( super().__init__(dataset) # Need to split new classes dataset indices & old classses dataset indices - ann_stats = get_idx_list_per_classes(dataset.dm_subset, True) + ann_stats = dataset.get_idx_list_per_classes(use_string_label=True) new_indices, old_indices = [], [] for cls in new_classes: new_indices.extend(ann_stats[cls]) diff --git a/library/src/otx/data/transform_libs/torchvision.py b/library/src/otx/data/transform_libs/torchvision.py index 16b25c6559..5ba14e043b 100644 --- a/library/src/otx/data/transform_libs/torchvision.py +++ b/library/src/otx/data/transform_libs/torchvision.py @@ -7,7 +7,6 @@ import ast import copy -import itertools import math import operator import typing @@ -36,6 +35,7 @@ _resize_image_info, _resized_crop_image_info, ) +from otx.data.entity.sample import OTXSample from otx.data.entity.torch import OTXDataItem from otx.data.transform_libs.utils import ( CV2_INTERP_CODES, @@ -1409,9 +1409,8 @@ def _transform_polygons( valid_index = valid_index.numpy() # Filter polygons using valid_index - filtered_polygons = [p for p, keep in zip(inputs.polygons, valid_index) if keep] - - if filtered_polygons: + filtered_polygons = inputs.polygons[valid_index] + if len(filtered_polygons) > 0: inputs.polygons = project_polygons(filtered_polygons, warp_matrix, output_shape) def _recompute_bboxes(self, inputs: OTXDataItem, output_shape: tuple[int, int]) -> None: @@ -1442,14 +1441,13 @@ def _recompute_bboxes(self, inputs: OTXDataItem, output_shape: tuple[int, int]) elif has_polygons: polygons = inputs.polygons - for i, polygon in enumerate(polygons): # type: ignore[arg-type] - points_1d = np.array(polygon.points, dtype=np.float32) - if len(points_1d) % 2 != 0: - continue - points = points_1d.reshape(-1, 2) - x, y, w, h = cv2.boundingRect(points) - bboxes[i] = np.array([x, y, x + w, y + h]) + for i, poly_points in enumerate(polygons): # type: ignore[arg-type] + if poly_points.size > 0: + points = poly_points.astype(np.float32) + if len(points) >= 3: # Need at least 3 points for valid polygon + x, y, w, h = cv2.boundingRect(points) + bboxes[i] = np.array([x, y, x + w, y + h]) inputs.bboxes = tv_tensors.BoundingBoxes( bboxes, @@ -1765,9 +1763,7 @@ def forward(self, *_inputs: OTXDataItem) -> OTXDataItem | None: if len(mosaic_masks) > 0: inputs.masks = np.concatenate(mosaic_masks, axis=0)[inside_inds] if len(mosaic_polygons) > 0: - inputs.polygons = [ - polygon for ind, polygon in zip(inside_inds, itertools.chain(*mosaic_polygons)) if ind - ] # type: ignore[union-attr] + inputs.polygons = np.concatenate(mosaic_polygons, axis=0)[inside_inds] return self.convert(inputs) def _mosaic_combine( @@ -2040,7 +2036,7 @@ def forward(self, *_inputs: OTXDataItem) -> OTXDataItem | None: mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(np.float32) # TODO(ashwinvaidya17): remove this once we have a unified TorchDataItem - if isinstance(retrieve_results, OTXDataItem): + if isinstance(retrieve_results, (OTXDataItem, OTXSample)): retrieve_gt_bboxes_labels = retrieve_results.label else: retrieve_gt_bboxes_labels = retrieve_results.labels @@ -2113,9 +2109,9 @@ def forward(self, *_inputs: OTXDataItem) -> OTXDataItem | None: ) # 8. mix up - mixup_gt_polygons = list(itertools.chain(*[inputs.polygons, retrieve_gt_polygons])) + mixup_gt_polygons = np.concatenate((inputs.polygons, retrieve_gt_polygons)) - inputs.polygons = [mixup_gt_polygons[i] for i in np.where(inside_inds)[0]] + inputs.polygons = mixup_gt_polygons[np.where(inside_inds)[0]] return self.convert(inputs) @@ -2632,8 +2628,16 @@ def _crop_data( ) if (polygons := getattr(inputs, "polygons", None)) is not None and len(polygons) > 0: + # Handle both ragged array and legacy polygon formats + if isinstance(polygons, np.ndarray): + # Filter valid polygons using valid_inds for ragged array + filtered_polygons = polygons[valid_inds.nonzero()[0]] + else: + # Filter valid polygons for legacy format + filtered_polygons = [polygons[i] for i in valid_inds.nonzero()[0]] + inputs.polygons = crop_polygons( - [polygons[i] for i in valid_inds.nonzero()[0]], + filtered_polygons, np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]), *orig_shape, ) diff --git a/library/src/otx/data/transform_libs/utils.py b/library/src/otx/data/transform_libs/utils.py index adae5fb7c6..0adb01a957 100644 --- a/library/src/otx/data/transform_libs/utils.py +++ b/library/src/otx/data/transform_libs/utils.py @@ -10,14 +10,12 @@ import copy import functools import inspect -import itertools import weakref from typing import Sequence import cv2 import numpy as np import torch -from datumaro import Polygon from shapely import geometry from torch import BoolTensor, Tensor @@ -129,6 +127,7 @@ def to_np_image(img: np.ndarray | Tensor | list) -> np.ndarray | list[np.ndarray return img if isinstance(img, list): return [to_np_image(im) for im in img] + return np.ascontiguousarray(img.numpy().transpose(1, 2, 0)) @@ -178,28 +177,37 @@ def rescale_masks( ) -def rescale_polygons(polygons: list[Polygon], scale_factor: float | tuple[float, float]) -> list[Polygon]: +def rescale_polygons(polygons: np.ndarray, scale_factor: float | tuple[float, float]) -> np.ndarray: """Rescale polygons as large as possible while keeping the aspect ratio. Args: - polygons (np.ndarray): Polygons to be rescaled. - scale_factor (float | tuple[float, float]): Scale factor to be applied to polygons with (height, width) + polygons: Object array containing np.ndarray objects of shape (Npoly, 2) + scale_factor: Scale factor to be applied to polygons with (height, width) or single float value. Returns: - (np.ndarray) : The rescaled polygons. + np.ndarray: The rescaled polygons. """ + if len(polygons) == 0: + return polygons + if isinstance(scale_factor, float): w_scale = h_scale = scale_factor else: h_scale, w_scale = scale_factor - for polygon in polygons: - p = np.asarray(polygon.points, dtype=np.float32) - p[0::2] *= w_scale - p[1::2] *= h_scale - polygon.points = p.tolist() - return polygons + rescaled_polygons = np.empty_like(polygons) + for i, poly_points in enumerate(polygons): + if poly_points.size > 0: + scaled_points = poly_points.astype(np.float32) + scaled_points[:, 0] *= w_scale # x coordinates + scaled_points[:, 1] *= h_scale # y coordinates + rescaled_polygons[i] = scaled_points + else: + # Handle empty or invalid polygons + rescaled_polygons[i] = np.array([[0, 0], [0, 0], [0, 0]], dtype=np.float32) + + return rescaled_polygons def rescale_keypoints(keypoints: Tensor, scale_factor: float | tuple[float, float]) -> Tensor: @@ -306,25 +314,45 @@ def translate_masks( def translate_polygons( - polygons: list[Polygon], + polygons: np.ndarray, out_shape: tuple[int, int], offset: int | float, direction: str = "horizontal", border_value: int | float = 0, -) -> list[Polygon]: - """Translate polygons.""" +) -> np.ndarray: + """Translate polygons. + + Args: + polygons: Object array containing np.ndarray objects of shape (Npoly, 2) + out_shape: Output shape (height, width) + offset: Translation offset + direction: Translation direction, "horizontal" or "vertical" + border_value: Border value (only used for legacy compatibility) + + Returns: + np.ndarray: Translated polygons + """ assert ( # noqa: S101 border_value is None or border_value == 0 ), f"Here border_value is not used, and defaultly should be None or 0. got {border_value}." + if len(polygons) == 0: + return polygons + axis = 0 if direction == "horizontal" else 1 out = out_shape[1] if direction == "horizontal" else out_shape[0] - for polygon in polygons: - p = np.asarray(polygon.points) - p[axis::2] = np.clip(p[axis::2] + offset, 0, out) - polygon.points = p.tolist() - return polygons + translated_polygons = np.empty_like(polygons) + for i, poly_points in enumerate(polygons): + if poly_points.size > 0: + translated_points = poly_points.copy() + translated_points[:, axis] = np.clip(translated_points[:, axis] + offset, 0, out) + translated_polygons[i] = translated_points + else: + # Handle empty or invalid polygons + translated_polygons[i] = np.array([[0, 0], [0, 0], [0, 0]], dtype=np.float32) + + return translated_polygons def _get_translate_matrix(offset: int | float, direction: str = "horizontal") -> np.ndarray: @@ -720,19 +748,34 @@ def flip_masks(masks: np.ndarray, direction: str = "horizontal") -> np.ndarray: return np.stack([flip_image(mask, direction=direction) for mask in masks]) -def flip_polygons(polygons: list[Polygon], height: int, width: int, direction: str = "horizontal") -> list[Polygon]: - """Flip polygons alone the given direction.""" - for polygon in polygons: - p = np.asarray(polygon.points) +def flip_polygons(polygons: np.ndarray, height: int, width: int, direction: str = "horizontal") -> np.ndarray: + """Flip polygons along the given direction. + + Args: + polygons: Object array containing np.ndarray objects of shape (Npoly, 2) + height: Image height + width: Image width + direction: Flip direction, "horizontal", "vertical", or "diagonal" + + Returns: + np.ndarray: Flipped polygons + """ + if len(polygons) == 0: + return polygons + + flipped_polygons = np.empty_like(polygons) + for i, poly_points in enumerate(polygons): + flipped_points = poly_points.copy() if direction == "horizontal": - p[0::2] = width - p[0::2] + flipped_points[:, 0] = width - flipped_points[:, 0] # x coordinates elif direction == "vertical": - p[1::2] = height - p[1::2] + flipped_points[:, 1] = height - flipped_points[:, 1] # y coordinates else: - p[0::2] = width - p[0::2] - p[1::2] = height - p[1::2] - polygon.points = p.tolist() - return polygons + flipped_points[:, 0] = width - flipped_points[:, 0] # x coordinates + flipped_points[:, 1] = height - flipped_points[:, 1] # y coordinates + flipped_polygons[i] = flipped_points + + return flipped_polygons def project_bboxes(boxes: Tensor, homography_matrix: Tensor | np.ndarray) -> Tensor: @@ -760,47 +803,46 @@ def project_bboxes(boxes: Tensor, homography_matrix: Tensor | np.ndarray) -> Ten def project_polygons( - polygons: list[Polygon], + polygons: np.ndarray, homography_matrix: np.ndarray, out_shape: tuple[int, int], -) -> list[Polygon]: +) -> np.ndarray: """Transform polygons using a homography matrix. Args: - polygons (list[Polygon]): List of polygons to be transformed. - homography_matrix (np.ndarray): Homography matrix of shape (3, 3) for geometric transformation. - out_shape (tuple[int, int]): Output shape (height, width) for boundary clipping. + polygons: Object array containing np.ndarray objects of shape (Npoly, 2) + homography_matrix: Homography matrix of shape (3, 3) for geometric transformation + out_shape: Output shape (height, width) for boundary clipping Returns: - list[Polygon]: List of transformed polygons. + np.ndarray: Transformed polygons """ - if not polygons: + if len(polygons) == 0: return polygons height, width = out_shape - transformed_polygons = [] - - for polygon in polygons: - points = np.array(polygon.points, dtype=np.float32) + transformed_polygons = np.empty_like(polygons) - if len(points) % 2 != 0: - # Invalid polygon - transformed_polygons.append(Polygon(points=[0, 0, 0, 0, 0, 0])) + for i, poly_points in enumerate(polygons): + if poly_points.size == 0: + transformed_polygons[i] = np.array([[0, 0], [0, 0], [0, 0]], dtype=np.float32) continue + # Convert to homogeneous coordinates + points_h = np.hstack([poly_points, np.ones((poly_points.shape[0], 1), dtype=np.float32)]) # (N, 3) - points_2d = points.reshape(-1, 2) - points_h = np.hstack([points_2d, np.ones((points_2d.shape[0], 1), dtype=np.float32)]) # (N, 3) + # Apply transformation proj = homography_matrix @ points_h.T # (3, N) + # Convert back to Cartesian coordinates denom = proj[2:3] denom[denom == 0] = 1e-6 # avoid divide-by-zero proj_cartesian = (proj[:2] / denom).T # (N, 2) - # Clip + # Clip to image boundaries proj_cartesian[:, 0] = np.clip(proj_cartesian[:, 0], 0, width - 1) proj_cartesian[:, 1] = np.clip(proj_cartesian[:, 1], 0, height - 1) - transformed_polygons.append(Polygon(points=proj_cartesian.flatten().tolist())) + transformed_polygons[i] = proj_cartesian.astype(np.float32) return transformed_polygons @@ -857,8 +899,18 @@ def crop_masks(masks: np.ndarray, bbox: np.ndarray) -> np.ndarray: return masks[:, y1 : y1 + h, x1 : x1 + w] -def crop_polygons(polygons: list[Polygon], bbox: np.ndarray, height: int, width: int) -> list[Polygon]: - """Crop each polygon by the given bbox.""" +def crop_polygons(polygons: np.ndarray, bbox: np.ndarray, height: int, width: int) -> np.ndarray: + """Crop each polygon by the given bbox. + + Args: + polygons: Object array containing np.ndarray objects of shape (Npoly, 2) + bbox: Bounding box as [x1, y1, x2, y2] + height: Original image height + width: Original image width + + Returns: + np.ndarray: Cropped polygons + """ assert isinstance(bbox, np.ndarray) # noqa: S101 assert bbox.ndim == 1 # noqa: S101 @@ -874,21 +926,30 @@ def crop_polygons(polygons: list[Polygon], bbox: np.ndarray, height: int, width: # reference: https://github.com/shapely/shapely/issues/1345 initial_settings = np.seterr() np.seterr(invalid="ignore") - for polygon in polygons: - cropped_poly_per_obj: list[Polygon] = [] - p = np.asarray(polygon.points).copy() - p = geometry.Polygon(p.reshape(-1, 2)).buffer(0.0) + cropped_polygons = np.empty_like(polygons) + + for i, polygon_points in enumerate(polygons): + cropped_poly_per_obj = [] + + # Convert ragged array polygon to shapely polygon + if polygon_points.size == 0: + # Handle empty or invalid polygons + cropped_polygons[i] = np.array([[0, 0], [0, 0], [0, 0]], dtype=np.float32) + continue + + p = geometry.Polygon(polygon_points).buffer(0.0) + # polygon must be valid to perform intersection. if not p.is_valid: # a dummy polygon to avoid misalignment between masks and boxes - polygon.points = [0, 0, 0, 0, 0, 0] + cropped_polygons[i] = np.array([[0, 0], [0, 0], [0, 0]], dtype=np.float32) continue cropped = p.intersection(crop_box) if cropped.is_empty: # a dummy polygon to avoid misalignment between masks and boxes - polygon.points = [0, 0, 0, 0, 0, 0] + cropped_polygons[i] = np.array([[0, 0], [0, 0], [0, 0]], dtype=np.float32) continue cropped = cropped.geoms if isinstance(cropped, geometry.collection.BaseMultipartGeometry) else [cropped] @@ -905,15 +966,17 @@ def crop_polygons(polygons: list[Polygon], bbox: np.ndarray, height: int, width: coords = coords[:-1] coords[:, 0] -= x1 coords[:, 1] -= y1 - cropped_poly_per_obj.append(coords.reshape(-1).tolist()) + cropped_poly_per_obj.append(coords) # a dummy polygon to avoid misalignment between masks and boxes if len(cropped_poly_per_obj) == 0: - cropped_poly_per_obj.append([0, 0, 0, 0, 0, 0]) + cropped_polygons[i] = np.array([[0, 0], [0, 0], [0, 0]], dtype=np.float32) + else: + # Concatenate all cropped polygons for this object into a single array + cropped_polygons[i] = np.concatenate(cropped_poly_per_obj, axis=0) - polygon.points = list(itertools.chain(*cropped_poly_per_obj)) np.seterr(**initial_settings) - return polygons + return cropped_polygons def get_bboxes_from_masks(masks: Tensor) -> np.ndarray: @@ -933,20 +996,32 @@ def get_bboxes_from_masks(masks: Tensor) -> np.ndarray: return bboxes -def get_bboxes_from_polygons(polygons: list[Polygon], height: int, width: int) -> np.ndarray: - """Create boxes from polygons.""" +def get_bboxes_from_polygons(polygons: np.ndarray, height: int, width: int) -> np.ndarray: + """Create boxes from polygons. + + Args: + polygons: Ragged array of (Npoly, 2) arrays + height: Image height + width: Image width + + Returns: + np.ndarray: Bounding boxes in XYXY format, shape (N, 4) + """ num_polygons = len(polygons) boxes = np.zeros((num_polygons, 4), dtype=np.float32) - for idx, polygon in enumerate(polygons): - # simply use a number that is big enough for comparison with coordinates - xy_min = np.array([width * 2, height * 2], dtype=np.float32) - xy_max = np.zeros(2, dtype=np.float32) - - xy = np.array(polygon.points).reshape(-1, 2).astype(np.float32) - xy_min = np.minimum(xy_min, np.min(xy, axis=0)) - xy_max = np.maximum(xy_max, np.max(xy, axis=0)) - boxes[idx, :2] = xy_min - boxes[idx, 2:] = xy_max + + ref_xy_min = np.array([width * 2, height * 2], dtype=np.float32) + ref_xy_max = np.zeros(2, dtype=np.float32) + + for idx, poly_points in enumerate(polygons): + if poly_points.size > 0: + xy_min = np.minimum(ref_xy_min, np.min(poly_points, axis=0)) + xy_max = np.maximum(ref_xy_max, np.max(poly_points, axis=0)) + boxes[idx, :2] = xy_min + boxes[idx, 2:] = xy_max + else: + # Handle empty or invalid polygons + boxes[idx] = [0, 0, 0, 0] return boxes diff --git a/library/src/otx/data/utils/__init__.py b/library/src/otx/data/utils/__init__.py index bc2ed250b8..31242128b2 100644 --- a/library/src/otx/data/utils/__init__.py +++ b/library/src/otx/data/utils/__init__.py @@ -7,7 +7,6 @@ adapt_input_size_to_dataset, adapt_tile_config, get_adaptive_num_workers, - get_idx_list_per_classes, import_object_from_module, instantiate_sampler, ) @@ -17,6 +16,5 @@ "adapt_input_size_to_dataset", "instantiate_sampler", "get_adaptive_num_workers", - "get_idx_list_per_classes", "import_object_from_module", ] diff --git a/library/src/otx/data/utils/structures/mask/mask_target.py b/library/src/otx/data/utils/structures/mask/mask_target.py index 75f310e40a..a39db80b4b 100644 --- a/library/src/otx/data/utils/structures/mask/mask_target.py +++ b/library/src/otx/data/utils/structures/mask/mask_target.py @@ -14,7 +14,6 @@ import numpy as np import torch -from datumaro.components.annotation import Polygon from torch import Tensor from torch.nn.modules.utils import _pair from torchvision import tv_tensors @@ -25,7 +24,7 @@ def mask_target( pos_proposals_list: list[Tensor], pos_assigned_gt_inds_list: list[Tensor], - gt_masks_list: list[list[Polygon]] | list[tv_tensors.Mask], + gt_masks_list: list[np.ndarray] | list[tv_tensors.Mask], mask_size: int, meta_infos: list[dict], ) -> Tensor: @@ -36,8 +35,7 @@ def mask_target( images, each has shape (num_pos, 4). pos_assigned_gt_inds_list (list[Tensor]): Assigned GT indices for each positive proposals, each has shape (num_pos,). - gt_masks_list (list[list[Polygon]] or list[tv_tensors.Mask]): Ground truth masks of - each image. + gt_masks_list (list[np.ndarray] or list[tv_tensors.Mask]): Ground truth masks or polygons. mask_size (int): The mask size. meta_infos (list[dict]): Meta information of each image. @@ -62,7 +60,7 @@ def mask_target( def mask_target_single( pos_proposals: Tensor, pos_assigned_gt_inds: Tensor, - gt_masks: list[Polygon] | tv_tensors.Mask, + gt_masks: np.ndarray | tv_tensors.Mask, mask_size: list[int], meta_info: dict, ) -> Tensor: @@ -71,7 +69,7 @@ def mask_target_single( Args: pos_proposals (Tensor): Positive proposals, has shape (num_pos, 4). pos_assigned_gt_inds (Tensor): Assigned GT indices for positive proposals, has shape (num_pos,). - gt_masks (list[Polygon] or tv_tensors.Mask): Ground truth masks as list of polygons or tv_tensors.Mask. + gt_masks (np.ndarray or tv_tensors.Mask): Ground truth masks as polygons or tv_tensors.Mask. mask_size (list[int]): The mask size. meta_info (dict): Meta information of the image. @@ -83,7 +81,7 @@ def mask_target_single( warnings.warn("No ground truth masks are provided!", stacklevel=2) return pos_proposals.new_zeros((0, *mask_size)) - if isinstance(gt_masks[0], Polygon): + if isinstance(gt_masks, np.ndarray): crop_and_resize = crop_and_resize_polygons elif isinstance(gt_masks, tv_tensors.Mask): crop_and_resize = crop_and_resize_masks diff --git a/library/src/otx/data/utils/structures/mask/mask_util.py b/library/src/otx/data/utils/structures/mask/mask_util.py index 0d2dec0aa3..ff43788628 100644 --- a/library/src/otx/data/utils/structures/mask/mask_util.py +++ b/library/src/otx/data/utils/structures/mask/mask_util.py @@ -10,7 +10,6 @@ import numpy as np import pycocotools.mask as mask_utils import torch -from datumaro import Polygon from torchvision.ops import roi_align if TYPE_CHECKING: @@ -18,44 +17,45 @@ def polygon_to_bitmap( - polygons: list[Polygon], + polygons: np.ndarray, height: int, width: int, ) -> np.ndarray: - """Convert a list of polygons to a bitmap mask. + """Convert polygons to a bitmap mask. Args: - polygons (list[Polygon]): List of Datumaro Polygon objects. - height (int): bitmap height - width (int): bitmap width + polygons: a ragged array containing np.ndarray objects of shape (Npoly, 2) + height: bitmap height + width: bitmap width Returns: np.ndarray: bitmap masks """ - polygons = [polygon.points for polygon in polygons] - rles = mask_utils.frPyObjects(polygons, height, width) + # Convert to list of flat point arrays for pycocotools + polygon_points = [points.reshape(-1) for points in polygons] + rles = mask_utils.frPyObjects(polygon_points, height, width) return mask_utils.decode(rles).astype(bool).transpose((2, 0, 1)) def polygon_to_rle( - polygons: list[Polygon], + polygons: np.ndarray, height: int, width: int, ) -> list[dict]: - """Convert a list of polygons to a list of RLE masks. + """Convert polygons to a list of RLE masks. Args: - polygons (list[Polygon]): List of Datumaro Polygon objects. - height (int): bitmap height - width (int): bitmap width + polygons: a ragged array containing np.ndarray objects of shape (Npoly, 2) + height: bitmap height + width: bitmap width Returns: list[dict]: List of RLE masks. """ - polygons = [polygon.points for polygon in polygons] - if len(polygons): - return mask_utils.frPyObjects(polygons, height, width) - return [] + # Convert to list of flat point arrays for pycocotools + polygon_points = [points.reshape(-1) for points in polygons] + + return mask_utils.frPyObjects(polygon_points, height, width) def encode_rle(mask: torch.Tensor) -> dict: @@ -96,20 +96,31 @@ def encode_rle(mask: torch.Tensor) -> dict: def crop_and_resize_polygons( - annos: list[Polygon], + annos: np.ndarray, bboxes: np.ndarray, out_shape: tuple, inds: np.ndarray, device: str = "cpu", ) -> torch.Tensor: - """Crop and resize polygons to the target size.""" + """Crop and resize polygons to the target size. + + Args: + annos: Ragged array containing np.ndarray objects of shape (Npoly, 2) + bboxes: Bounding boxes array of shape (N, 4) + out_shape: Output shape (height, width) + inds: Indices array + device: Target device + + Returns: + torch.Tensor: Resized polygon masks + """ out_h, out_w = out_shape if len(annos) == 0: return torch.empty((0, *out_shape), dtype=torch.float, device=device) - resized_polygons = [] + resized_polygons = np.empty(len(bboxes), dtype=object) for i in range(len(bboxes)): - polygon = annos[inds[i]] + polygon_points = annos[inds[i]] bbox = bboxes[i, :] x1, y1, x2, y2 = bbox w = np.maximum(x2 - x1, 1) @@ -117,21 +128,17 @@ def crop_and_resize_polygons( h_scale = out_h / max(h, 0.1) # avoid too large scale w_scale = out_w / max(w, 0.1) - points = polygon.points - points = points.copy() - points = np.array(points) - # crop - # pycocotools will clip the boundary - points[0::2] = points[0::2] - bbox[0] - points[1::2] = points[1::2] - bbox[1] - - # resize - points[0::2] = points[0::2] * w_scale - points[1::2] = points[1::2] * h_scale + # Crop: translate points relative to bbox origin + cropped_points = polygon_points.copy() + cropped_points[:, 0] -= x1 # x coordinates + cropped_points[:, 1] -= y1 # y coordinates - resized_polygon = Polygon(points.tolist()) + # Resize: scale points to output size + resized_points = cropped_points.copy() + resized_points[:, 0] *= w_scale + resized_points[:, 1] *= h_scale - resized_polygons.append(resized_polygon) + resized_polygons[i] = resized_points mask_targets = polygon_to_bitmap(resized_polygons, *out_shape) diff --git a/library/src/otx/data/utils/utils.py b/library/src/otx/data/utils/utils.py index 769fc3ec7f..1d2c5eeb2f 100644 --- a/library/src/otx/data/utils/utils.py +++ b/library/src/otx/data/utils/utils.py @@ -15,14 +15,13 @@ import cv2 import numpy as np import torch -from datumaro.components.annotation import AnnotationType, Bbox, ExtractedMask, LabelCategories, Polygon +from datumaro.components.annotation import AnnotationType, Bbox, ExtractedMask, Polygon from datumaro.components.annotation import Shape as _Shape from otx.types import OTXTaskType from otx.utils.device import is_xpu_available if TYPE_CHECKING: - from datumaro import Dataset as DmDataset from datumaro import DatasetSubset from torch.utils.data import Dataset, Sampler @@ -322,22 +321,6 @@ def get_adaptive_num_workers(num_dataloader: int = 1) -> int | None: return min(cpu_count() // (num_dataloader * num_devices), 8) # max available num_workers is 8 -def get_idx_list_per_classes(dm_dataset: DmDataset, use_string_label: bool = False) -> dict[int | str, list[int]]: - """Compute class statistics.""" - stats: dict[int | str, list[int]] = defaultdict(list) - labels = dm_dataset.categories().get(AnnotationType.label, LabelCategories()) - for item_idx, item in enumerate(dm_dataset): - for ann in item.annotations: - if use_string_label: - stats[labels.items[ann.label].name].append(item_idx) - else: - stats[ann.label].append(item_idx) - # Remove duplicates in label stats idx: O(n) - for k in stats: - stats[k] = list(dict.fromkeys(stats[k])) - return stats - - def import_object_from_module(obj_path: str) -> Any: # noqa: ANN401 """Get object from import format string.""" module_name, obj_name = obj_path.rsplit(".", 1) diff --git a/library/src/otx/types/export.py b/library/src/otx/types/export.py index 5e1c9d5685..dd0d766878 100644 --- a/library/src/otx/types/export.py +++ b/library/src/otx/types/export.py @@ -164,7 +164,6 @@ def to_metadata(self) -> dict[tuple[str, str], str]: ("model_info", "tile_size"): str(self.tile_config.tile_size[0]), ("model_info", "tiles_overlap"): str(self.tile_config.overlap), ("model_info", "max_pred_number"): str(self.tile_config.max_num_instances), - ("model_info", "tile_with_full_img"): str(self.tile_config.with_full_img), }, ) diff --git a/library/src/otx/types/label.py b/library/src/otx/types/label.py index 56a30d4433..50bbc7915a 100644 --- a/library/src/otx/types/label.py +++ b/library/src/otx/types/label.py @@ -8,12 +8,13 @@ import copy import json from dataclasses import asdict, dataclass -from typing import TYPE_CHECKING, Any +from typing import Any -from datumaro.components.annotation import GroupType - -if TYPE_CHECKING: - from datumaro import Label, LabelCategories +from datumaro.experimental.categories import ( + GroupType, + HierarchicalLabelCategories, + HierarchicalLabelCategory, +) __all__ = [ "LabelInfo", @@ -64,7 +65,7 @@ def from_num_classes(cls, num_classes: int) -> LabelInfo: ) @classmethod - def from_dm_label_groups(cls, dm_label_categories: LabelCategories) -> LabelInfo: + def from_dm_label_groups(cls, dm_label_categories: HierarchicalLabelCategories) -> LabelInfo: """Create this object from the datumaro label groups. Args: @@ -89,14 +90,9 @@ def from_dm_label_groups(cls, dm_label_categories: LabelCategories) -> LabelInfo ) @classmethod - def from_dm_label_groups_arrow(cls, dm_label_categories: LabelCategories) -> LabelInfo: + def from_dm_label_groups_arrow(cls, dm_label_categories: HierarchicalLabelCategories) -> LabelInfo: """Overload to support datumaro's arrow format.""" - label_names = [] - for item in dm_label_categories.items: - for attr in item.attributes: - if attr.startswith("__name__"): - label_names.append(attr[len("__name__") :]) - break + label_names = [item.label_semantics["name"] for item in dm_label_categories.items] if len(label_names) != len(dm_label_categories.items): msg = "Wrong arrow format: can not extract label names from attributes" @@ -105,7 +101,9 @@ def from_dm_label_groups_arrow(cls, dm_label_categories: LabelCategories) -> Lab id_to_name_mapping = {item.name: label_names[i] for i, item in enumerate(dm_label_categories.items)} for label_group in dm_label_categories.label_groups: - label_group.labels = [id_to_name_mapping.get(label, label) for label in label_group.labels] + object.__setattr__( + label_group, "labels", [id_to_name_mapping.get(label, label) for label in label_group.labels] + ) label_groups = [label_group.labels for label_group in dm_label_categories.label_groups] if len(label_groups) == 0: # Single-label classification @@ -164,7 +162,6 @@ class HLabelInfo(LabelInfo): Args: num_multiclass_heads: The number of multiclass heads in the hierarchy. num_multilabel_classes: The number of multilabel classes. - head_to_logits_range: The logit range for each head as a dictionary mapping head indices to (start, end) tuples. num_single_label_classes: The number of single label classes. class_to_group_idx: Dictionary mapping class names to (head_index, label_index) @@ -205,20 +202,20 @@ class HLabelInfo(LabelInfo): head_idx_to_logits_range: dict[str, tuple[int, int]] num_single_label_classes: int class_to_group_idx: dict[str, tuple[int, int]] - all_groups: list[list[str]] + all_groups: list[tuple[str, ...]] label_to_idx: dict[str, int] label_tree_edges: list[list[str]] empty_multiclass_head_indices: list[int] @classmethod - def from_dm_label_groups(cls, dm_label_categories: LabelCategories) -> HLabelInfo: + def from_dm_label_groups(cls, dm_label_categories: HierarchicalLabelCategories) -> HLabelInfo: """Generate HLabelData from the Datumaro LabelCategories. Args: dm_label_categories (LabelCategories): the label categories of datumaro. """ - def get_exclusive_group_info(exclusive_groups: list[Label | list[Label]]) -> dict[str, Any]: + def get_exclusive_group_info(exclusive_groups: list[tuple[str, ...]]) -> dict[str, Any]: """Get exclusive group information.""" last_logits_pos = 0 num_single_label_classes = 0 @@ -240,7 +237,7 @@ def get_exclusive_group_info(exclusive_groups: list[Label | list[Label]]) -> dic } def get_single_label_group_info( - single_label_groups: list[Label | list[Label]], + single_label_groups: list, num_exclusive_groups: int, ) -> dict[str, Any]: """Get single label group information.""" @@ -270,30 +267,28 @@ def put_key_values(src: dict, dst: dict) -> None: put_key_values(single_label_ctoi, class_to_idx) return class_to_idx - def get_label_tree_edges(dm_label_items: list[LabelCategories]) -> list[list[str]]: + def get_label_tree_edges(dm_label_items: tuple[HierarchicalLabelCategory, ...]) -> list[list[str]]: """Get label tree edges information. Each edges represent [child, parent].""" return [[item.name, item.parent] for item in dm_label_items if item.parent != ""] def convert_labels_if_needed( - dm_label_categories: LabelCategories, + dm_label_categories: HierarchicalLabelCategories, label_names: list[str], - ) -> list[list[str]]: + ) -> list[tuple[str, ...]]: # Check if the labels need conversion and create name to ID mapping if required name_to_id_mapping = None for label_group in dm_label_categories.label_groups: if label_group.labels and label_group.labels[0] not in label_names: name_to_id_mapping = { - attr[len("__name__") :]: category.name - for category in dm_label_categories.items - for attr in category.attributes - if attr.startswith("__name__") + category.label_semantics["name"]: category.name for category in dm_label_categories.items } - break # If mapping exists, update the labels if name_to_id_mapping: for label_group in dm_label_categories.label_groups: - label_group.labels = [name_to_id_mapping.get(label, label) for label in label_group.labels] + object.__setattr__( + label_group, "labels", [name_to_id_mapping.get(label, label) for label in label_group.labels] + ) # Retrieve all label groups after conversion return [group.labels for group in dm_label_categories.label_groups] @@ -318,7 +313,7 @@ def convert_labels_if_needed( return HLabelInfo( label_names=label_names, - label_groups=exclusive_groups + single_label_groups, + label_groups=exclusive_groups + single_label_groups, # type: ignore[arg-type] num_multiclass_heads=exclusive_group_info["num_multiclass_heads"], num_multilabel_classes=single_label_group_info["num_multilabel_classes"], head_idx_to_logits_range=exclusive_group_info["head_idx_to_logits_range"], @@ -332,7 +327,7 @@ def convert_labels_if_needed( ) @classmethod - def from_dm_label_groups_arrow(cls, dm_label_categories: LabelCategories) -> HLabelInfo: + def from_dm_label_groups_arrow(cls, dm_label_categories: HierarchicalLabelCategories) -> HLabelInfo: """Generate HLabelData from the Datumaro LabelCategories. Arrow-specific implementation. Args: @@ -344,21 +339,17 @@ def from_dm_label_groups_arrow(cls, dm_label_categories: LabelCategories) -> HLa for label_group in dm_label_categories.label_groups: if label_group.group_type == GroupType.RESTRICTED: empty_label_name = label_group.labels[0] - - dm_label_categories.label_groups = [ - group for group in dm_label_categories.label_groups if group.group_type != GroupType.RESTRICTED - ] + label_groups = [group for group in dm_label_categories.label_groups if group.group_type != GroupType.RESTRICTED] + object.__setattr__(dm_label_categories, "label_groups", label_groups) empty_label_id = None label_names = [] for item in dm_label_categories.items: - for attr in item.attributes: - if attr.startswith("__name__"): - name = attr[len("__name__") :] - if name == empty_label_name: - empty_label_id = item.name - label_names.append(name) - break + name = item.label_semantics["name"] + + if name == empty_label_name: + empty_label_id = item.name + label_names.append(name) if len(label_names) != len(dm_label_categories.items): msg = "Wrong arrow file: can not extract label names from attributes" @@ -366,17 +357,23 @@ def from_dm_label_groups_arrow(cls, dm_label_categories: LabelCategories) -> HLa if empty_label_name is not None: label_names.remove(empty_label_name) - dm_label_categories.items = [item for item in dm_label_categories.items if item.name != empty_label_id] + + object.__setattr__( + dm_label_categories, "items", [item for item in dm_label_categories.items if item.name != empty_label_id] + ) + label_ids = [item.name for item in dm_label_categories.items] id_to_name_mapping = {item.name: label_names[i] for i, item in enumerate(dm_label_categories.items)} for i, item in enumerate(dm_label_categories.items): - item.name = label_names[i] - item.parent = id_to_name_mapping.get(item.parent, item.parent) + object.__setattr__(dm_label_categories, "name", label_names[i]) + object.__setattr__(dm_label_categories, "parent", id_to_name_mapping.get(item.parent, item.parent)) for label_group in dm_label_categories.label_groups: - label_group.labels = [id_to_name_mapping.get(label, label) for label in label_group.labels] + object.__setattr__( + label_group, "labels", [id_to_name_mapping.get(label, label) for label in label_group.labels] + ) obj = cls.from_dm_label_groups(dm_label_categories) obj.label_ids = label_ids diff --git a/library/tests/conftest.py b/library/tests/conftest.py index 69d35a1154..425a236da3 100644 --- a/library/tests/conftest.py +++ b/library/tests/conftest.py @@ -5,10 +5,10 @@ from collections import defaultdict from pathlib import Path +import numpy as np import pytest import torch import yaml -from datumaro import Polygon from torch import LongTensor from torchvision import tv_tensors from torchvision.tv_tensors import Image, Mask @@ -267,7 +267,9 @@ def fxt_inst_seg_data_entity() -> tuple[tuple, OTXDataItem, OTXDataBatch]: fake_bboxes = tv_tensors.BoundingBoxes(data=torch.Tensor([0, 0, 5, 5]), format="xyxy", canvas_size=(10, 10)) fake_labels = LongTensor([1]) fake_masks = Mask(torch.randint(low=0, high=255, size=(1, *img_size), dtype=torch.uint8)) - fake_polygons = [Polygon(points=[1, 1, 2, 2, 3, 3, 4, 4])] + fake_polygons = np.empty(shape=(1,), dtype=object) + fake_polygons[0] = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]) + # define data entity single_data_entity = OTXDataItem( image=fake_image, diff --git a/library/tests/test_helpers.py b/library/tests/test_helpers.py index 313b6f0666..faed389f87 100644 --- a/library/tests/test_helpers.py +++ b/library/tests/test_helpers.py @@ -17,9 +17,6 @@ from datumaro.components.errors import MediaTypeError from datumaro.components.exporter import Exporter from datumaro.components.media import Image -from datumaro.plugins.data_formats.common_semantic_segmentation import ( - CommonSemanticSegmentationPath, -) from datumaro.util.definitions import DEFAULT_SUBSET_NAME from datumaro.util.image import save_image from datumaro.util.meta_file_util import save_meta_file @@ -122,8 +119,8 @@ def _apply_impl(self) -> None: subset_dir = Path(save_dir, _subset_name) subset_dir.mkdir(parents=True, exist_ok=True) - mask_dir = subset_dir / CommonSemanticSegmentationPath.MASKS_DIR - img_dir = subset_dir / CommonSemanticSegmentationPath.IMAGES_DIR + mask_dir = subset_dir / "masks" + img_dir = subset_dir / "images" for item in subset: self._export_item_annotation(item, mask_dir) if self._save_media: diff --git a/library/tests/unit/backend/native/models/instance_segmentation/heads/test_rtmdet_inst_head.py b/library/tests/unit/backend/native/models/instance_segmentation/heads/test_rtmdet_inst_head.py index 382db10ce6..2fd466f77b 100644 --- a/library/tests/unit/backend/native/models/instance_segmentation/heads/test_rtmdet_inst_head.py +++ b/library/tests/unit/backend/native/models/instance_segmentation/heads/test_rtmdet_inst_head.py @@ -7,9 +7,9 @@ from functools import partial from unittest.mock import Mock +import numpy as np import pytest import torch -from datumaro import Polygon from torch import nn from otx.backend.native.models.common.utils.assigners import DynamicSoftLabelAssigner @@ -124,6 +124,11 @@ def test_prepare_loss_inputs(self, mocker, rtmdet_ins_head: RTMDetInstHead) -> N mocker.patch.object(rtmdet_ins_head, "_mask_predict_by_feat_single", return_value=torch.randn(4, 80, 80)) x = (torch.randn(2, 96, 80, 80), torch.randn(2, 96, 40, 40), torch.randn(2, 96, 20, 20)) + + polygons = [np.empty((1,), dtype=object), np.empty((1,), dtype=object)] + polygons[0] = np.array([[0, 0], [0, 1], [1, 1], [1, 0]]) + polygons[1] = np.array([[0, 0], [0, 1], [1, 1], [1, 0]]) + entity = OTXDataBatch( batch_size=2, images=[torch.randn(3, 640, 640), torch.randn(3, 640, 640)], @@ -134,7 +139,7 @@ def test_prepare_loss_inputs(self, mocker, rtmdet_ins_head: RTMDetInstHead) -> N bboxes=[torch.randn(2, 4), torch.randn(3, 4)], labels=[torch.randint(0, 3, (2,)), torch.randint(0, 3, (3,))], masks=[torch.zeros(2, 640, 640), torch.zeros(3, 640, 640)], - polygons=[[Polygon(points=[0, 0, 0, 1, 1, 1, 1, 0])], [Polygon(points=[0, 0, 0, 1, 1, 1, 1, 0])]], + polygons=polygons, ) results = rtmdet_ins_head.prepare_loss_inputs(x, entity) diff --git a/library/tests/unit/backend/native/utils/test_tile.py b/library/tests/unit/backend/native/utils/test_tile.py deleted file mode 100644 index a6d6163278..0000000000 --- a/library/tests/unit/backend/native/utils/test_tile.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -from __future__ import annotations - -from unittest.mock import MagicMock - -import numpy as np -from datumaro import Image -from datumaro.plugins.tiling.util import xywh_to_x1y1x2y2 -from model_api.models import Model -from model_api.tilers import Tiler - -from otx.data.dataset.tile import OTXTileTransform - - -def test_tile_transform_consistency(mocker): - # Test that OV tiler and PyTorch tile transform are consistent - rng = np.random.default_rng() - rnd_tile_size = rng.integers(low=100, high=500) - rnd_tile_overlap = min(rng.random(), 0.9) - image_size = rng.integers(low=1000, high=5000) - np_image = np.zeros((image_size, image_size, 3), dtype=np.uint8) - dm_image = Image.from_numpy(np_image) - - mock_model = MagicMock(spec=Model) - mocker.patch("model_api.tilers.tiler.Tiler.__init__", return_value=None) - mocker.patch.multiple(Tiler, __abstractmethods__=set()) - - tiler = Tiler(model=mock_model) - tiler.tile_with_full_img = True - tiler.tile_size = rnd_tile_size - tiler.tiles_overlap = rnd_tile_overlap - - mocker.patch("otx.data.dataset.tile.OTXTileTransform.__init__", return_value=None) - tile_transform = OTXTileTransform() - tile_transform._tile_size = (rnd_tile_size, rnd_tile_size) - tile_transform._overlap = (rnd_tile_overlap, rnd_tile_overlap) - tile_transform.with_full_img = True - - dm_rois = [xywh_to_x1y1x2y2(*roi) for roi in tile_transform._extract_rois(dm_image)] - ov_tiler_rois = tiler._tile(np_image) - - assert len(dm_rois) == len(ov_tiler_rois) - for dm_roi in dm_rois: - assert list(dm_roi) in ov_tiler_rois diff --git a/library/tests/unit/data/conftest.py b/library/tests/unit/data/conftest.py index 2932bc055d..5c8aee1db2 100644 --- a/library/tests/unit/data/conftest.py +++ b/library/tests/unit/data/conftest.py @@ -2,8 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations -import uuid -from pathlib import Path from typing import TYPE_CHECKING from unittest.mock import MagicMock @@ -38,26 +36,19 @@ from otx.data.dataset.base import OTXDataset _LABEL_NAMES = ["Non-Rigid", "Rigid", "Rectangle", "Triangle", "Circle", "Lion", "Panda"] +_ANOMALY_LABEL_NAMES = ["good", "anomaly"] -@pytest.fixture(params=["bytes", "file"]) -def fxt_dm_item(request, tmpdir) -> DatasetItem: +@pytest.fixture() +def fxt_dm_item() -> DatasetItem: np_img = np.zeros(shape=(10, 10, 3), dtype=np.uint8) np_img[:, :, 0] = 0 # Set 0 for B channel np_img[:, :, 1] = 1 # Set 1 for G channel np_img[:, :, 2] = 2 # Set 2 for R channel - if request.param == "bytes": - _, np_bytes = cv2.imencode(".png", np_img) - media = Image.from_bytes(np_bytes.tobytes()) - media.path = "" - elif request.param == "file": - fname = str(uuid.uuid4()) - fpath = str(Path(tmpdir) / f"{fname}.png") - cv2.imwrite(fpath, np_img) - media = Image.from_file(fpath) - else: - raise ValueError(request.param) + _, np_bytes = cv2.imencode(".png", np_img) + media = Image.from_bytes(np_bytes.tobytes()) + media.path = "" return DatasetItem( id="item", @@ -72,24 +63,61 @@ def fxt_dm_item(request, tmpdir) -> DatasetItem: ) -@pytest.fixture(params=["bytes", "file"]) -def fxt_dm_item_bbox_only(request, tmpdir) -> DatasetItem: +@pytest.fixture() +def fxt_classification_dm_item() -> DatasetItem: + np_img = np.zeros(shape=(10, 10, 3), dtype=np.uint8) + np_img[:, :, 0] = 0 # Set 0 for B channel + np_img[:, :, 1] = 1 # Set 1 for G channel + np_img[:, :, 2] = 2 # Set 2 for R channel + + _, np_bytes = cv2.imencode(".png", np_img) + media = Image.from_bytes(np_bytes.tobytes()) + media.path = "" + + return DatasetItem( + id="item", + subset="train", + media=media, + annotations=[ + Label(label=0), + ], + ) + + +@pytest.fixture() +def fxt_anomaly_dm_item() -> DatasetItem: np_img = np.zeros(shape=(10, 10, 3), dtype=np.uint8) np_img[:, :, 0] = 0 # Set 0 for B channel np_img[:, :, 1] = 1 # Set 1 for G channel np_img[:, :, 2] = 2 # Set 2 for R channel - if request.param == "bytes": - _, np_bytes = cv2.imencode(".png", np_img) - media = Image.from_bytes(np_bytes.tobytes()) - media.path = "" - elif request.param == "file": - fname = str(uuid.uuid4()) - fpath = str(Path(tmpdir) / f"{fname}.png") - cv2.imwrite(fpath, np_img) - media = Image.from_file(fpath) - else: - raise ValueError(request.param) + _, np_bytes = cv2.imencode(".png", np_img) + media = Image.from_bytes(np_bytes.tobytes()) + media.path = "" + + return DatasetItem( + id="item", + subset="train", + media=media, + annotations=[ + Label(label=0), + Bbox(x=200, y=200, w=1, h=1, label=0), + Mask(label=0, image=np.eye(10, dtype=np.uint8)), + Polygon(points=[399.0, 570.0, 397.0, 572.0, 397.0, 573.0, 394.0, 576.0], label=0), + ], + ) + + +@pytest.fixture() +def fxt_detection_dm_item() -> DatasetItem: + np_img = np.zeros(shape=(10, 10, 3), dtype=np.uint8) + np_img[:, :, 0] = 0 # Set 0 for B channel + np_img[:, :, 1] = 1 # Set 1 for G channel + np_img[:, :, 2] = 2 # Set 2 for R channel + + _, np_bytes = cv2.imencode(".png", np_img) + media = Image.from_bytes(np_bytes.tobytes()) + media.path = "" return DatasetItem( id="item", @@ -103,12 +131,37 @@ def fxt_dm_item_bbox_only(request, tmpdir) -> DatasetItem: ) +@pytest.fixture() +def fxt_segmentation_dm_item() -> DatasetItem: + np_img = np.zeros(shape=(10, 10, 3), dtype=np.uint8) + np_img[:, :, 0] = 0 # Set 0 for B channel + np_img[:, :, 1] = 1 # Set 1 for G channel + np_img[:, :, 2] = 2 # Set 2 for R channel + + _, np_bytes = cv2.imencode(".png", np_img) + media = Image.from_bytes(np_bytes.tobytes()) + media.path = "" + + return DatasetItem( + id="item", + subset="train", + media=media, + annotations=[ + Mask(label=0, image=np.eye(10, dtype=np.uint8)), + Polygon(points=[399.0, 570.0, 397.0, 572.0, 397.0, 573.0, 394.0, 576.0], label=0), + ], + ) + + @pytest.fixture() def fxt_mock_dm_subset(mocker: MockerFixture, fxt_dm_item: DatasetItem) -> MagicMock: mock_dm_subset = mocker.MagicMock(spec=DmDataset) mock_dm_subset.__getitem__.return_value = fxt_dm_item + mock_dm_subset.__iter__.return_value = [fxt_dm_item] mock_dm_subset.__len__.return_value = 1 mock_dm_subset.categories().__getitem__.return_value = LabelCategories.from_iterable(_LABEL_NAMES) + mock_dm_subset.categories().get.return_value = LabelCategories.from_iterable(_LABEL_NAMES) + mock_dm_subset.media_type.return_value = Image mock_dm_subset.ann_types.return_value = [ AnnotationType.label, AnnotationType.bbox, @@ -119,15 +172,64 @@ def fxt_mock_dm_subset(mocker: MockerFixture, fxt_dm_item: DatasetItem) -> Magic @pytest.fixture() -def fxt_mock_det_dm_subset(mocker: MockerFixture, fxt_dm_item_bbox_only: DatasetItem) -> MagicMock: +def fxt_mock_classification_dm_subset(mocker: MockerFixture, fxt_classification_dm_item: DatasetItem) -> MagicMock: mock_dm_subset = mocker.MagicMock(spec=DmDataset) - mock_dm_subset.__getitem__.return_value = fxt_dm_item_bbox_only + mock_dm_subset.__getitem__.return_value = fxt_classification_dm_item + mock_dm_subset.__iter__.return_value = [fxt_classification_dm_item] mock_dm_subset.__len__.return_value = 1 mock_dm_subset.categories().__getitem__.return_value = LabelCategories.from_iterable(_LABEL_NAMES) + mock_dm_subset.categories().get.return_value = LabelCategories.from_iterable(_LABEL_NAMES) + mock_dm_subset.media_type.return_value = Image + mock_dm_subset.ann_types.return_value = [ + AnnotationType.label, + ] + return mock_dm_subset + + +@pytest.fixture() +def fxt_mock_anomaly_dm_subset(mocker: MockerFixture, fxt_anomaly_dm_item: DatasetItem) -> MagicMock: + mock_dm_subset = mocker.MagicMock(spec=DmDataset) + mock_dm_subset.__getitem__.return_value = fxt_anomaly_dm_item + mock_dm_subset.__iter__.return_value = [fxt_anomaly_dm_item] + mock_dm_subset.__len__.return_value = 1 + mock_dm_subset.categories().__getitem__.return_value = LabelCategories.from_iterable(_ANOMALY_LABEL_NAMES) + mock_dm_subset.categories().get.return_value = LabelCategories.from_iterable(_ANOMALY_LABEL_NAMES) + mock_dm_subset.media_type.return_value = Image + mock_dm_subset.ann_types.return_value = [ + AnnotationType.label, + AnnotationType.bbox, + AnnotationType.mask, + AnnotationType.polygon, + ] + return mock_dm_subset + + +@pytest.fixture() +def fxt_mock_detection_dm_subset(mocker: MockerFixture, fxt_detection_dm_item: DatasetItem) -> MagicMock: + mock_dm_subset = mocker.MagicMock(spec=DmDataset) + mock_dm_subset.__getitem__.return_value = fxt_detection_dm_item + mock_dm_subset.__iter__.return_value = [fxt_detection_dm_item] + mock_dm_subset.__len__.return_value = 1 + mock_dm_subset.categories().__getitem__.return_value = LabelCategories.from_iterable(_LABEL_NAMES) + mock_dm_subset.categories().get.return_value = LabelCategories.from_iterable(_LABEL_NAMES) + mock_dm_subset.media_type.return_value = Image mock_dm_subset.ann_types.return_value = [AnnotationType.bbox] return mock_dm_subset +@pytest.fixture() +def fxt_mock_segmentation_dm_subset(mocker: MockerFixture, fxt_segmentation_dm_item: DatasetItem) -> MagicMock: + mock_dm_subset = mocker.MagicMock(spec=DmDataset) + mock_dm_subset.__getitem__.return_value = fxt_segmentation_dm_item + mock_dm_subset.__iter__.return_value = [fxt_segmentation_dm_item] + mock_dm_subset.__len__.return_value = 1 + mock_dm_subset.categories().__getitem__.return_value = LabelCategories.from_iterable(_LABEL_NAMES) + mock_dm_subset.categories().get.return_value = LabelCategories.from_iterable(_LABEL_NAMES) + mock_dm_subset.media_type.return_value = Image + mock_dm_subset.ann_types.return_value = [AnnotationType.polygon, AnnotationType.mask] + return mock_dm_subset + + @pytest.fixture( params=[ (OTXHlabelClsDataset, OTXDataItem, {}), diff --git a/library/tests/unit/data/dataset/test_base_new.py b/library/tests/unit/data/dataset/test_base_new.py new file mode 100644 index 0000000000..a274e132f7 --- /dev/null +++ b/library/tests/unit/data/dataset/test_base_new.py @@ -0,0 +1,261 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for base_new OTXDataset.""" + +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest +import torch +from datumaro.experimental import Dataset + +from otx.data.dataset.base_new import OTXDataset, _default_collate_fn +from otx.data.entity.sample import OTXSample +from otx.data.entity.torch.torch import OTXDataBatch + + +class TestDefaultCollateFn: + """Test _default_collate_fn function.""" + + def test_collate_with_torch_tensors(self): + """Test collating items with torch tensor images.""" + # Create mock samples with torch tensor images + sample1 = Mock(spec=OTXSample) + sample1.image = torch.randn(3, 224, 224) + sample1.label = torch.tensor(0) + sample1.masks = None + sample1.bboxes = None + sample1.keypoints = None + sample1.polygons = None + sample1.img_info = None + + sample2 = Mock(spec=OTXSample) + sample2.image = torch.randn(3, 224, 224) + sample2.label = torch.tensor(1) + sample2.masks = None + sample2.bboxes = None + sample2.keypoints = None + sample2.polygons = None + sample2.img_info = None + + items = [sample1, sample2] + result = _default_collate_fn(items) + + assert isinstance(result, OTXDataBatch) + assert result.batch_size == 2 + assert isinstance(result.images, torch.Tensor) + assert result.images.shape == (2, 3, 224, 224) + assert result.images.dtype == torch.float32 + assert result.labels == [torch.tensor(0), torch.tensor(1)] + + def test_collate_with_different_image_shapes(self): + """Test collating items with different image shapes.""" + sample1 = Mock(spec=OTXSample) + sample1.image = torch.randn(3, 224, 224) + sample1.label = None + sample1.masks = None + sample1.bboxes = None + sample1.keypoints = None + sample1.polygons = None + sample1.img_info = None + + sample2 = Mock(spec=OTXSample) + sample2.image = torch.randn(3, 256, 256) + sample2.label = None + sample2.masks = None + sample2.bboxes = None + sample2.keypoints = None + sample2.polygons = None + sample2.img_info = None + + items = [sample1, sample2] + result = _default_collate_fn(items) + + # When shapes are different, should return list instead of stacked tensor + assert isinstance(result.images, list) + assert len(result.images) == 2 + assert result.labels is None + + +class TestOTXDataset: + """Test OTXDataset class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_dm_subset = Mock(spec=Dataset) + self.mock_dm_subset.__len__ = Mock(return_value=100) + + # Mock schema attributes for label_info + mock_schema = Mock() + mock_attributes = {"label": Mock()} + mock_attributes["label"].categories = Mock() + # Configure labels to be a list with proper length support + mock_attributes["label"].categories.labels = ["class_0", "class_1", "class_2"] + mock_schema.attributes = mock_attributes + self.mock_dm_subset.schema = mock_schema + + self.mock_transforms = Mock() + + def test_sample_another_idx(self): + """Test _sample_another_idx method.""" + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + with patch("numpy.random.randint", return_value=42): + idx = dataset._sample_another_idx() + assert idx == 42 + + def test_apply_transforms_with_compose(self): + """Test _apply_transforms with Compose transforms.""" + from otx.data.transform_libs.torchvision import Compose + + mock_compose = Mock(spec=Compose) + mock_entity = Mock(spec=OTXSample) + mock_result = Mock() + mock_compose.return_value = mock_result + + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=mock_compose, + data_format="arrow", + to_tv_image=True, + ) + + result = dataset._apply_transforms(mock_entity) + + mock_entity.as_tv_image.assert_called_once() + mock_compose.assert_called_once_with(mock_entity) + assert result == mock_result + + def test_apply_transforms_with_callable(self): + """Test _apply_transforms with callable transform.""" + mock_transform = Mock() + mock_entity = Mock(spec=OTXSample) + mock_result = Mock() + mock_transform.return_value = mock_result + + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=mock_transform, + data_format="arrow", + ) + + result = dataset._apply_transforms(mock_entity) + + mock_transform.assert_called_once_with(mock_entity) + assert result == mock_result + + def test_apply_transforms_with_list(self): + """Test _apply_transforms with list of transforms.""" + transform1 = Mock() + transform2 = Mock() + + mock_entity = Mock(spec=OTXSample) + intermediate_result = Mock() + final_result = Mock() + + transform1.return_value = intermediate_result + transform2.return_value = final_result + + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=[transform1, transform2], + data_format="arrow", + ) + + result = dataset._apply_transforms(mock_entity) + + transform1.assert_called_once_with(mock_entity) + transform2.assert_called_once_with(intermediate_result) + assert result == final_result + + def test_apply_transforms_with_list_returns_none(self): + """Test _apply_transforms with list that returns None.""" + transform1 = Mock() + transform2 = Mock() + + mock_entity = Mock(spec=OTXSample) + transform1.return_value = None # First transform returns None + + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=[transform1, transform2], + data_format="arrow", + ) + + result = dataset._apply_transforms(mock_entity) + + transform1.assert_called_once_with(mock_entity) + transform2.assert_not_called() # Should not be called since first returned None + assert result is None + + def test_iterable_transforms_with_non_list(self): + """Test _iterable_transforms with non-list iterable raises TypeError.""" + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + mock_entity = Mock(spec=OTXSample) + dataset.transforms = "not_a_list" # String is iterable but not a list + + with pytest.raises(TypeError): + dataset._iterable_transforms(mock_entity) + + def test_getitem_success(self): + """Test __getitem__ with successful retrieval.""" + mock_item = Mock() + self.mock_dm_subset.__getitem__ = Mock(return_value=mock_item) + + mock_transformed_item = Mock(spec=OTXSample) + + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + with patch.object(dataset, "_apply_transforms", return_value=mock_transformed_item): + result = dataset[5] + + self.mock_dm_subset.__getitem__.assert_called_once_with(5) + assert result == mock_transformed_item + + def test_getitem_with_refetch(self): + """Test __getitem__ with failed first attempt requiring refetch.""" + mock_item = Mock() + self.mock_dm_subset.__getitem__ = Mock(return_value=mock_item) + + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + max_refetch=2, + ) + + mock_transformed_item = Mock(spec=OTXSample) + + # First call returns None, second returns valid item + with patch.object(dataset, "_apply_transforms", side_effect=[None, mock_transformed_item]), patch.object( + dataset, "_sample_another_idx", return_value=10 + ): + result = dataset[5] + + assert result == mock_transformed_item + assert dataset._apply_transforms.call_count == 2 + + def test_collate_fn_property(self): + """Test collate_fn property returns _default_collate_fn.""" + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + assert dataset.collate_fn == _default_collate_fn diff --git a/library/tests/unit/data/dataset/test_classification.py b/library/tests/unit/data/dataset/test_classification.py index c6a62ecea9..0d790d7a58 100644 --- a/library/tests/unit/data/dataset/test_classification.py +++ b/library/tests/unit/data/dataset/test_classification.py @@ -28,10 +28,10 @@ def test_get_item( def test_get_item_from_bbox_dataset( self, - fxt_mock_det_dm_subset, + fxt_mock_detection_dm_subset, ) -> None: dataset = OTXMulticlassClsDataset( - dm_subset=fxt_mock_det_dm_subset, + dm_subset=fxt_mock_detection_dm_subset, transforms=[lambda x: x], max_refetch=3, ) @@ -52,10 +52,10 @@ def test_get_item( def test_get_item_from_bbox_dataset( self, - fxt_mock_det_dm_subset, + fxt_mock_detection_dm_subset, ) -> None: dataset = OTXMultilabelClsDataset( - dm_subset=fxt_mock_det_dm_subset, + dm_subset=fxt_mock_detection_dm_subset, transforms=[lambda x: x], max_refetch=3, ) @@ -92,12 +92,12 @@ def test_get_item( def test_get_item_from_bbox_dataset( self, mocker, - fxt_mock_det_dm_subset, + fxt_mock_detection_dm_subset, fxt_mock_hlabelinfo, ) -> None: mocker.patch.object(HLabelInfo, "from_dm_label_groups", return_value=fxt_mock_hlabelinfo) dataset = OTXHlabelClsDataset( - dm_subset=fxt_mock_det_dm_subset, + dm_subset=fxt_mock_detection_dm_subset, transforms=[lambda x: x], max_refetch=3, ) diff --git a/library/tests/unit/data/dataset/test_classification_new.py b/library/tests/unit/data/dataset/test_classification_new.py new file mode 100644 index 0000000000..25ba230c46 --- /dev/null +++ b/library/tests/unit/data/dataset/test_classification_new.py @@ -0,0 +1,68 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for classification_new dataset.""" + +from __future__ import annotations + +from unittest.mock import Mock + +from datumaro.experimental import Dataset + +from otx.data.dataset.classification_new import OTXMulticlassClsDataset +from otx.data.entity.sample import ClassificationSample + + +class TestOTXMulticlassClsDataset: + """Test OTXMulticlassClsDataset class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_dm_subset = Mock(spec=Dataset) + self.mock_dm_subset.__len__ = Mock(return_value=10) + + # Mock schema attributes for label_info + mock_schema = Mock() + mock_attributes = {"label": Mock()} + mock_attributes["label"].categories = Mock() + # Configure labels to be a list with proper length support + mock_attributes["label"].categories.labels = ["class_0", "class_1", "class_2"] + mock_schema.attributes = mock_attributes + self.mock_dm_subset.schema = mock_schema + + self.mock_transforms = Mock() + + def test_init_sets_sample_type(self): + """Test that initialization sets sample_type to ClassificationSample.""" + dataset = OTXMulticlassClsDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + assert dataset.sample_type == ClassificationSample + + def test_get_idx_list_per_classes_single_class(self): + """Test get_idx_list_per_classes with single class.""" + # Mock dataset items with labels + mock_items = [] + for _ in range(5): + mock_item = Mock() + mock_item.label.item.return_value = 0 # All items have label 0 + mock_items.append(mock_item) + + self.mock_dm_subset.__getitem__ = Mock(side_effect=mock_items) + + dataset = OTXMulticlassClsDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + # Override length for this test + dataset.dm_subset.__len__ = Mock(return_value=5) + + result = dataset.get_idx_list_per_classes() + + expected = {0: [0, 1, 2, 3, 4]} + assert result == expected diff --git a/library/tests/unit/data/dataset/test_detection_new.py b/library/tests/unit/data/dataset/test_detection_new.py new file mode 100644 index 0000000000..aa024a39c3 --- /dev/null +++ b/library/tests/unit/data/dataset/test_detection_new.py @@ -0,0 +1,83 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for detection_new dataset.""" + +from __future__ import annotations + +from unittest.mock import Mock + +from datumaro.experimental import Dataset + +from otx.data.dataset.detection_new import OTXDetectionDataset +from otx.data.entity.sample import DetectionSample + + +class TestOTXDetectionDataset: + """Test OTXDetectionDataset class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_dm_subset = Mock(spec=Dataset) + self.mock_dm_subset.__len__ = Mock(return_value=10) + self.mock_dm_subset.convert_to_schema = Mock(return_value=self.mock_dm_subset) + + # Mock schema attributes for label_info + mock_schema = Mock() + mock_attributes = {"label": Mock()} + mock_attributes["label"].categories = Mock() + # Configure labels to be a list with proper length support + mock_attributes["label"].categories.labels = ["class_0", "class_1", "class_2"] + mock_schema.attributes = mock_attributes + self.mock_dm_subset.schema = mock_schema + + self.mock_transforms = Mock() + + def test_init_sets_sample_type(self): + """Test that initialization sets sample_type to DetectionSample.""" + dataset = OTXDetectionDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + assert dataset.sample_type == DetectionSample + + def test_get_idx_list_per_classes_multiple_classes_per_item(self): + """Test get_idx_list_per_classes with multiple classes per item.""" + # Mock dataset items with multiple labels per item + mock_items = [] + # Item 0: classes [0, 1] + mock_item0 = Mock() + mock_item0.label.tolist.return_value = [0, 1] + mock_items.append(mock_item0) + + # Item 1: class [1] + mock_item1 = Mock() + mock_item1.label.tolist.return_value = [1] + mock_items.append(mock_item1) + + # Item 2: classes [0, 2] + mock_item2 = Mock() + mock_item2.label.tolist.return_value = [0, 2] + mock_items.append(mock_item2) + + self.mock_dm_subset.__getitem__ = Mock(side_effect=mock_items) + + dataset = OTXDetectionDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + # Override length for this test + dataset.dm_subset.__len__ = Mock(return_value=3) + + result = dataset.get_idx_list_per_classes() + + expected = { + 0: [0, 2], # Items 0 and 2 contain class 0 + 1: [0, 1], # Items 0 and 1 contain class 1 + 2: [2], # Item 2 contains class 2 + } + assert result == expected diff --git a/library/tests/unit/data/dataset/test_segmentation.py b/library/tests/unit/data/dataset/test_segmentation.py index a415ad25ae..d49c675a48 100644 --- a/library/tests/unit/data/dataset/test_segmentation.py +++ b/library/tests/unit/data/dataset/test_segmentation.py @@ -22,10 +22,10 @@ def test_get_item( def test_get_item_from_bbox_dataset( self, - fxt_mock_det_dm_subset, + fxt_mock_detection_dm_subset, ) -> None: dataset = OTXSegmentationDataset( - dm_subset=fxt_mock_det_dm_subset, + dm_subset=fxt_mock_detection_dm_subset, transforms=[lambda x: x], max_refetch=3, ) diff --git a/library/tests/unit/data/samplers/test_balanced_sampler.py b/library/tests/unit/data/samplers/test_balanced_sampler.py index 43b8810c3b..768fad8ef4 100644 --- a/library/tests/unit/data/samplers/test_balanced_sampler.py +++ b/library/tests/unit/data/samplers/test_balanced_sampler.py @@ -12,7 +12,6 @@ from otx.data.dataset.base import OTXDataset from otx.data.samplers.balanced_sampler import BalancedSampler -from otx.data.utils import get_idx_list_per_classes @pytest.fixture() @@ -81,7 +80,7 @@ def test_sampler_iter_with_multiple_replicas(self, fxt_imbalanced_dataset): def test_compute_class_statistics(self, fxt_imbalanced_dataset): # Compute class statistics - stats = get_idx_list_per_classes(fxt_imbalanced_dataset.dm_subset) + stats = fxt_imbalanced_dataset.get_idx_list_per_classes() # Check the expected results assert stats == {0: list(range(100)), 1: list(range(100, 108))} @@ -90,7 +89,7 @@ def test_sampler_iter_per_class(self, fxt_imbalanced_dataset): batch_size = 4 sampler = BalancedSampler(fxt_imbalanced_dataset) - stats = get_idx_list_per_classes(fxt_imbalanced_dataset.dm_subset) + stats = fxt_imbalanced_dataset.get_idx_list_per_classes() class_0_idx = stats[0] class_1_idx = stats[1] list_iter = list(iter(sampler)) diff --git a/library/tests/unit/data/samplers/test_class_incremental_sampler.py b/library/tests/unit/data/samplers/test_class_incremental_sampler.py index cd2f34b8e5..f031f58265 100644 --- a/library/tests/unit/data/samplers/test_class_incremental_sampler.py +++ b/library/tests/unit/data/samplers/test_class_incremental_sampler.py @@ -10,7 +10,6 @@ from otx.data.dataset.base import OTXDataset from otx.data.samplers.class_incremental_sampler import ClassIncrementalSampler -from otx.data.utils import get_idx_list_per_classes @pytest.fixture() @@ -107,7 +106,7 @@ def test_sampler_iter_per_class(self, fxt_old_new_dataset): new_classes=["2"], ) - stats = get_idx_list_per_classes(fxt_old_new_dataset.dm_subset, True) + stats = fxt_old_new_dataset.get_idx_list_per_classes(True) old_idx = stats["0"] + stats["1"] new_idx = stats["2"] list_iter = list(iter(sampler)) diff --git a/library/tests/unit/data/test_factory.py b/library/tests/unit/data/test_factory.py index 3c24b1c774..cc2cf8c94c 100644 --- a/library/tests/unit/data/test_factory.py +++ b/library/tests/unit/data/test_factory.py @@ -6,16 +6,16 @@ import pytest from otx.config.data import SubsetConfig -from otx.data.dataset.anomaly import OTXAnomalyDataset +from otx.data.dataset.anomaly_new import OTXAnomalyDataset from otx.data.dataset.classification import ( HLabelInfo, OTXHlabelClsDataset, - OTXMulticlassClsDataset, OTXMultilabelClsDataset, ) -from otx.data.dataset.detection import OTXDetectionDataset -from otx.data.dataset.instance_segmentation import OTXInstanceSegDataset -from otx.data.dataset.segmentation import OTXSegmentationDataset +from otx.data.dataset.classification_new import OTXMulticlassClsDataset +from otx.data.dataset.detection_new import OTXDetectionDataset +from otx.data.dataset.instance_segmentation_new import OTXInstanceSegDataset +from otx.data.dataset.segmentation_new import OTXSegmentationDataset from otx.data.factory import OTXDatasetFactory, TransformLibFactory from otx.data.transform_libs.torchvision import TorchVisionTransformLib from otx.types.image import ImageColorChannel @@ -40,37 +40,39 @@ def test_generate(self, lib_type, lib, mocker) -> None: class TestOTXDatasetFactory: @pytest.mark.parametrize( - ("task_type", "dataset_cls"), + ("task_type", "dataset_cls", "dm_subset_fxt_name"), [ - (OTXTaskType.MULTI_CLASS_CLS, OTXMulticlassClsDataset), - (OTXTaskType.MULTI_LABEL_CLS, OTXMultilabelClsDataset), - (OTXTaskType.H_LABEL_CLS, OTXHlabelClsDataset), - (OTXTaskType.DETECTION, OTXDetectionDataset), - (OTXTaskType.ROTATED_DETECTION, OTXInstanceSegDataset), - (OTXTaskType.INSTANCE_SEGMENTATION, OTXInstanceSegDataset), - (OTXTaskType.SEMANTIC_SEGMENTATION, OTXSegmentationDataset), - (OTXTaskType.ANOMALY, OTXAnomalyDataset), - (OTXTaskType.ANOMALY_CLASSIFICATION, OTXAnomalyDataset), - (OTXTaskType.ANOMALY_DETECTION, OTXAnomalyDataset), - (OTXTaskType.ANOMALY_SEGMENTATION, OTXAnomalyDataset), + (OTXTaskType.MULTI_CLASS_CLS, OTXMulticlassClsDataset, "fxt_mock_classification_dm_subset"), + (OTXTaskType.MULTI_LABEL_CLS, OTXMultilabelClsDataset, "fxt_mock_classification_dm_subset"), + (OTXTaskType.H_LABEL_CLS, OTXHlabelClsDataset, "fxt_mock_classification_dm_subset"), + (OTXTaskType.DETECTION, OTXDetectionDataset, "fxt_mock_detection_dm_subset"), + (OTXTaskType.ROTATED_DETECTION, OTXInstanceSegDataset, "fxt_mock_segmentation_dm_subset"), + (OTXTaskType.INSTANCE_SEGMENTATION, OTXInstanceSegDataset, "fxt_mock_segmentation_dm_subset"), + (OTXTaskType.SEMANTIC_SEGMENTATION, OTXSegmentationDataset, "fxt_mock_segmentation_dm_subset"), + (OTXTaskType.ANOMALY, OTXAnomalyDataset, "fxt_mock_anomaly_dm_subset"), + (OTXTaskType.ANOMALY_CLASSIFICATION, OTXAnomalyDataset, "fxt_mock_anomaly_dm_subset"), + (OTXTaskType.ANOMALY_DETECTION, OTXAnomalyDataset, "fxt_mock_anomaly_dm_subset"), + (OTXTaskType.ANOMALY_SEGMENTATION, OTXAnomalyDataset, "fxt_mock_anomaly_dm_subset"), ], ) def test_create( self, + request, fxt_mock_hlabelinfo, - fxt_mock_dm_subset, task_type, dataset_cls, + dm_subset_fxt_name, mocker, ) -> None: mocker.patch.object(TransformLibFactory, "generate", return_value=None) + dm_subset = request.getfixturevalue(dm_subset_fxt_name) cfg_subset = mocker.MagicMock(spec=SubsetConfig) image_color_channel = ImageColorChannel.BGR mocker.patch.object(HLabelInfo, "from_dm_label_groups", return_value=fxt_mock_hlabelinfo) assert isinstance( OTXDatasetFactory.create( task=task_type, - dm_subset=fxt_mock_dm_subset, + dm_subset=dm_subset, cfg_subset=cfg_subset, image_color_channel=image_color_channel, data_format="", diff --git a/library/tests/unit/data/test_tiling.py b/library/tests/unit/data/test_tiling.py index 7a3db6f583..bfc4eb3681 100644 --- a/library/tests/unit/data/test_tiling.py +++ b/library/tests/unit/data/test_tiling.py @@ -9,10 +9,7 @@ import numpy as np import pytest -import shapely.geometry as sg import torch -from datumaro import Dataset as DmDataset -from datumaro import Polygon from model_api.models import Model from model_api.models.result import ImageResultWithSoftPrediction from model_api.tilers import SemanticSegmentationTiler @@ -28,7 +25,6 @@ SubsetConfig, TileConfig, ) -from otx.data.dataset.tile import OTXTileTransform from otx.data.entity.tile import TileBatchDetDataEntity, TileBatchInstSegDataEntity, TileBatchSegDataEntity from otx.data.entity.torch import OTXDataBatch, OTXPredBatch from otx.data.module import OTXDataModule @@ -235,73 +231,6 @@ def inst_seg_dummy_forward(self, x: OTXDataBatch) -> OTXPredBatch: return pred_entity - @pytest.mark.parametrize( - "task", - [OTXTaskType.DETECTION, OTXTaskType.INSTANCE_SEGMENTATION, OTXTaskType.SEMANTIC_SEGMENTATION], - ) - def test_tile_transform(self, task, fxt_data_roots): - if task in (OTXTaskType.INSTANCE_SEGMENTATION, OTXTaskType.DETECTION): - dataset_format = "coco_instances" - elif task == OTXTaskType.SEMANTIC_SEGMENTATION: - dataset_format = "common_semantic_segmentation_with_subset_dirs" - else: - pytest.skip("Task not supported") - - data_root = str(fxt_data_roots[task]) - dataset = DmDataset.import_from(data_root, format=dataset_format) - - rng = np.random.default_rng() - tile_size = rng.integers(low=50, high=128, size=(2,)) - overlap = rng.random(2) - overlap = overlap.clip(0, 0.9) - threshold_drop_ann = rng.random() - tiled_dataset = DmDataset.import_from(data_root, format=dataset_format) - tiled_dataset.transform( - OTXTileTransform, - tile_size=tile_size, - overlap=overlap, - threshold_drop_ann=threshold_drop_ann, - with_full_img=True, - ) - - h_stride = max(int((1 - overlap[0]) * tile_size[0]), 1) - w_stride = max(int((1 - overlap[1]) * tile_size[1]), 1) - - num_tiles = 0 - for dataset_item in dataset: - height, width = dataset_item.media.data.shape[:2] - for _ in range(0, height, h_stride): - for _ in range(0, width, w_stride): - num_tiles += 1 - - assert len(tiled_dataset) == num_tiles + len(dataset), "Incorrect number of tiles" - - tiled_dataset = DmDataset.import_from(data_root, format=dataset_format) - tiled_dataset.transform( - OTXTileTransform, - tile_size=tile_size, - overlap=overlap, - threshold_drop_ann=threshold_drop_ann, - with_full_img=False, - ) - assert len(tiled_dataset) == num_tiles, "Incorrect number of tiles" - - def test_tile_polygon_func(self): - points = np.array([(1, 2), (3, 5), (4, 2), (4, 6), (1, 6)]) - polygon = Polygon(points=points.flatten().tolist()) - roi = sg.Polygon([(0, 0), (5, 0), (5, 5), (0, 5)]) - - inter_polygon = OTXTileTransform._tile_polygon(polygon, roi, threshold_drop_ann=0.0) - assert isinstance(inter_polygon, Polygon), "Intersection should be a Polygon" - assert inter_polygon.get_area() > 0, "Intersection area should be greater than 0" - - assert ( - OTXTileTransform._tile_polygon(polygon, roi, threshold_drop_ann=1.0) is None - ), "Intersection should be None" - - invalid_polygon = Polygon(points=[0, 0, 5, 0, 5, 5, 5, 0]) - assert OTXTileTransform._tile_polygon(invalid_polygon, roi) is None, "Invalid polygon should be None" - def test_adaptive_tiling(self, fxt_data_config): for task, data_config in fxt_data_config.items(): # Enable tile adapter @@ -346,6 +275,7 @@ def test_adaptive_tiling(self, fxt_data_config): else: pytest.skip("Task not supported") + @pytest.mark.xfail(reason="Tiling not yet implemented with new dataset") def test_tile_sampler(self, fxt_data_config): for task, data_config in fxt_data_config.items(): rng = np.random.default_rng() @@ -380,6 +310,7 @@ def test_tile_sampler(self, fxt_data_config): assert sampled_count == count, "Sampled count should be equal to the count of the dataloader batch size" + @pytest.mark.xfail(reason="Tiling not yet implemented with new dataset") def test_train_dataloader(self, fxt_data_config) -> None: for task, data_config in fxt_data_config.items(): # Enable tile adapter @@ -400,6 +331,7 @@ def test_train_dataloader(self, fxt_data_config) -> None: else: pytest.skip("Task not supported") + @pytest.mark.xfail(reason="Tiling not yet implemented with new dataset") def test_val_dataloader(self, fxt_data_config) -> None: for task, data_config in fxt_data_config.items(): # Enable tile adapter @@ -420,6 +352,7 @@ def test_val_dataloader(self, fxt_data_config) -> None: else: pytest.skip("Task not supported") + @pytest.mark.xfail(reason="Tiling not yet implemented with new dataset") def test_det_tile_merge(self, fxt_data_config): data_config = fxt_data_config[OTXTaskType.DETECTION] model = ATSS( @@ -443,6 +376,7 @@ def test_det_tile_merge(self, fxt_data_config): for batch in tile_datamodule.val_dataloader(): model.forward_tiles(batch) + @pytest.mark.xfail(reason="Tiling not yet implemented with new dataset") def test_explain_det_tile_merge(self, fxt_data_config): data_config = fxt_data_config[OTXTaskType.DETECTION] model = ATSS( @@ -468,6 +402,7 @@ def test_explain_det_tile_merge(self, fxt_data_config): assert prediction.saliency_map[0].ndim == 3 self.explain_mode = False + @pytest.mark.xfail(reason="Tiling not yet implemented with new dataset") def test_instseg_tile_merge(self, fxt_data_config): data_config = fxt_data_config[OTXTaskType.INSTANCE_SEGMENTATION] model = MaskRCNN( @@ -491,6 +426,7 @@ def test_instseg_tile_merge(self, fxt_data_config): for batch in tile_datamodule.val_dataloader(): model.forward_tiles(batch) + @pytest.mark.xfail(reason="Tiling not yet implemented with new dataset") def test_explain_instseg_tile_merge(self, fxt_data_config): data_config = fxt_data_config[OTXTaskType.INSTANCE_SEGMENTATION] model = MaskRCNN( @@ -516,6 +452,7 @@ def test_explain_instseg_tile_merge(self, fxt_data_config): assert prediction.saliency_map[0].ndim == 3 self.explain_mode = False + @pytest.mark.xfail(reason="Tiling not yet implemented with new dataset") def test_seg_tile_merge(self, fxt_data_config): data_config = fxt_data_config[OTXTaskType.SEMANTIC_SEGMENTATION] model = LiteHRNet( diff --git a/library/tests/unit/data/transform_libs/test_torchvision.py b/library/tests/unit/data/transform_libs/test_torchvision.py index 966f4f8c1a..6bbacc0b0f 100644 --- a/library/tests/unit/data/transform_libs/test_torchvision.py +++ b/library/tests/unit/data/transform_libs/test_torchvision.py @@ -10,7 +10,6 @@ import numpy as np import pytest import torch -from datumaro import Polygon from torch import LongTensor from torchvision import tv_tensors from torchvision.transforms.v2 import ToDtype @@ -130,10 +129,12 @@ def det_data_entity_with_polygons() -> OTXDataItem: fake_masks = tv_tensors.Mask(masks) # Create corresponding polygons - fake_polygons = [ - Polygon(points=[10, 10, 50, 10, 50, 50, 10, 50]), # Rectangle polygon for first object - Polygon(points=[60, 60, 100, 60, 100, 100, 60, 100]), # Rectangle polygon for second object - ] + fake_polygons = np.array( + [ + np.array([[10, 10], [50, 10], [50, 50], [10, 50]]), # Rectangle polygon for first object + np.array([[60, 60], [100, 60], [100, 100], [60, 100]]), # Rectangle polygon for second object + ] + ) return OTXDataItem( image=tv_tensors.Image(fake_image), @@ -257,8 +258,7 @@ def test_forward_bboxes_masks_polygons( assert all( [ # noqa: C419 np.all( - np.array(rp.points).reshape(-1, 2) - == np.array(fp.points).reshape(-1, 2) * np.array([results.img_info.scale_factor[::-1]]), + rp == fp * np.array([results.img_info.scale_factor[::-1]]), ) for rp, fp in zip(results.polygons, fxt_inst_seg_data_entity[0].polygons) ], @@ -293,15 +293,15 @@ def test_forward( assert torch.all(tv_tensors.Mask(results.masks).flip(-1) == fxt_inst_seg_data_entity[0].masks) # test polygons - def revert_hflip(polygon: list[float], width: int) -> list[float]: - p = np.asarray(polygon.points) - p[0::2] = width - p[0::2] - return p.tolist() + def revert_hflip(polygon: np.ndarray, width: int) -> np.ndarray: + polygon[:, 0] = width - polygon[:, 0] + return polygon width = results.img_info.img_shape[1] polygons_results = deepcopy(results.polygons) - polygons_results = [Polygon(points=revert_hflip(polygon, width)) for polygon in polygons_results] - assert polygons_results == fxt_inst_seg_data_entity[0].polygons + polygons_results = [revert_hflip(polygon, width) for polygon in polygons_results] + for polygon, expected_polygon in zip(polygons_results, fxt_inst_seg_data_entity[0].polygons): + assert np.all(polygon == expected_polygon) class TestPhotoMetricDistortion: @@ -406,8 +406,8 @@ def test_forward_with_polygons_transform_enabled( # Check that polygons are still valid (even number of coordinates) for polygon in results.polygons: - assert len(polygon.points) % 2 == 0 # Should have even number of coordinates - assert len(polygon.points) >= 6 # Should have at least 3 points (6 coordinates) + assert polygon.shape[1] == 2 # Should have (x,y) coordinates + assert polygon.shape[0] >= 3 # Should have at least 3 points def test_forward_with_masks_and_polygons_transform_enabled( self, @@ -502,15 +502,13 @@ def test_polygon_coordinates_validity( height, width = results.image.shape[:2] for polygon in results.polygons: - points = np.array(polygon.points).reshape(-1, 2) - # Check that x coordinates are within [0, width] - assert np.all(points[:, 0] >= 0) - assert np.all(points[:, 0] <= width) + assert np.all(polygon[:, 0] >= 0) + assert np.all(polygon[:, 0] <= width) # Check that y coordinates are within [0, height] - assert np.all(points[:, 1] >= 0) - assert np.all(points[:, 1] <= height) + assert np.all(polygon[:, 1] >= 0) + assert np.all(polygon[:, 1] <= height) @pytest.mark.parametrize("transform_polygon", [True, False]) def test_polygon_transform_parameter_effect( @@ -958,7 +956,7 @@ def iseg_entity(self) -> OTXDataItem: ), label=torch.LongTensor([0, 1]), masks=tv_tensors.Mask(np.zeros((2, 10, 10), np.uint8)), - polygons=[Polygon(points=[0, 0, 0, 7, 7, 7, 7, 0]), Polygon(points=[2, 3, 2, 9, 9, 9, 9, 3])], + polygons=np.array([np.array([[0, 0], [0, 7], [7, 7], [7, 0]]), np.array([[2, 3], [2, 9], [9, 9], [9, 3]])]), ) def test_init_invalid_crop_type(self) -> None: diff --git a/library/tests/unit/data/utils/test_utils.py b/library/tests/unit/data/utils/test_utils.py index 69d2b837f3..79cfaefe19 100644 --- a/library/tests/unit/data/utils/test_utils.py +++ b/library/tests/unit/data/utils/test_utils.py @@ -5,7 +5,6 @@ from __future__ import annotations -from collections import defaultdict from unittest.mock import MagicMock import cv2 @@ -23,8 +22,6 @@ compute_robust_scale_statistics, compute_robust_statistics, get_adaptive_num_workers, - get_idx_list_per_classes, - import_object_from_module, ) @@ -239,29 +236,3 @@ def fxt_dm_dataset() -> DmDataset: ] return DmDataset.from_iterable(dataset_items, categories=["0", "1"]) - - -def test_get_idx_list_per_classes(fxt_dm_dataset): - # Call the function under test - result = get_idx_list_per_classes(fxt_dm_dataset) - - # Assert the expected output - expected_result = defaultdict(list) - expected_result[0] = list(range(100)) - expected_result[1] = list(range(100, 108)) - assert result == expected_result - - # Call the function under test with use_string_label - result = get_idx_list_per_classes(fxt_dm_dataset, use_string_label=True) - - # Assert the expected output - expected_result = defaultdict(list) - expected_result["0"] = list(range(100)) - expected_result["1"] = list(range(100, 108)) - assert result == expected_result - - -def test_import_object_from_module(): - obj_path = "otx.data.utils.get_idx_list_per_classes" - obj = import_object_from_module(obj_path) - assert obj == get_idx_list_per_classes diff --git a/library/tests/unit/tools/test_converter.py b/library/tests/unit/tools/test_converter.py index f1856bbcd4..db02a65c3a 100644 --- a/library/tests/unit/tools/test_converter.py +++ b/library/tests/unit/tools/test_converter.py @@ -112,6 +112,7 @@ def test_classification_augs(self, tmp_path): assert engine.datamodule.train_dataloader().dataset.transforms is not None assert len(engine.datamodule.train_dataloader().dataset.transforms.transforms) == 9 + @pytest.mark.xfail(reason="Tiling not yet implemented with new dataset") def test_detection_augs(self, tmp_path): supported_augs_list_for_configuration = [ "otx.data.transform_libs.torchvision.MinIoURandomCrop",