From 845b63fa127118e1f76acf46e99c85a1a6241b8a Mon Sep 17 00:00:00 2001 From: kprokofi Date: Wed, 12 Nov 2025 21:39:47 +0900 Subject: [PATCH 1/6] add some kornia augs --- library/src/otx/backend/native/models/base.py | 165 +++++++++++++++++- .../backend/native/models/detection/atss.py | 11 ++ .../backend/native/models/detection/base.py | 25 ++- .../backend/native/models/detection/d_fine.py | 15 ++ .../backend/native/models/detection/yolox.py | 19 ++ library/src/otx/data/entity/torch/torch.py | 7 +- .../src/otx/data/entity/torch/validations.py | 2 + .../src/otx/recipe/_base_/data/detection.yaml | 74 ++------ .../recipe/detection/atss_mobilenetv2.yaml | 4 + library/src/otx/recipe/detection/yolox_x.yaml | 131 +++++++------- 10 files changed, 317 insertions(+), 136 deletions(-) diff --git a/library/src/otx/backend/native/models/base.py b/library/src/otx/backend/native/models/base.py index 23ee4b52818..9a25f5b3e28 100644 --- a/library/src/otx/backend/native/models/base.py +++ b/library/src/otx/backend/native/models/base.py @@ -9,8 +9,10 @@ import inspect import logging +import threading import warnings from abc import abstractmethod +from contextlib import contextmanager from dataclasses import asdict, dataclass from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence @@ -21,6 +23,7 @@ from torch.optim.lr_scheduler import ConstantLR from torch.optim.sgd import SGD from torchmetrics import Metric, MetricCollection +from kornia.augmentation.container import AugmentationSequential from otx import __version__ from otx.backend.native.optimizers.callable import OptimizerCallableSupportAdaptiveBS @@ -126,6 +129,7 @@ def __init__( metric: MetricCallable = NullMetricCallable, torch_compile: bool = False, tile_config: TileConfig | dict = TileConfig(enable_tiler=False), + enable_async_streams: bool = False, ) -> None: """Initialize the base model with the given parameters. @@ -141,6 +145,8 @@ def __init__( metric (MetricCallable, optional): Callable for the metric. Defaults to NullMetricCallable. torch_compile (bool, optional): Flag to indicate if torch.compile should be used. Defaults to False. tile_config (TileConfig, optional): Configuration for tiling. Defaults to TileConfig(enable_tiler=False). + enable_async_streams (bool, optional): Enable asynchronous CUDA/XPU streams for concurrent operations. + Defaults to True. Returns: None @@ -158,7 +164,6 @@ def __init__( self.scheduler_callable = ensure_callable(scheduler) self.metric_callable = ensure_callable(metric) self._task = task - self.torch_compile = torch_compile self._explain_mode = False @@ -166,6 +171,14 @@ def __init__( if isinstance(tile_config, dict): tile_config = TileConfig(**tile_config) self._tile_config = tile_config.clone() + + # Stream support for concurrent operations + self.enable_async_streams = enable_async_streams + self._augmentation_stream: torch.cuda.Stream | Any | None = None + self._compute_stream: torch.cuda.Stream | Any | None = None + self._stream_lock = threading.Lock() + self._augmentation_event: torch.cuda.Event | Any | None = None + self.save_hyperparameters( logger=False, ignore=["optimizer", "scheduler", "metric", "label_info", "tile_config", "data_input_params"], @@ -317,6 +330,20 @@ def on_test_epoch_end(self) -> None: """Callback triggered when the test epoch ends.""" self._log_metrics(self.metric, "test") + def on_train_epoch_end(self) -> None: + """Callback triggered when the training epoch ends. + + Synchronizes streams to ensure all async operations are complete. 
+ """ + if self.enable_async_streams and self._augmentation_stream is not None: + device_type = self.device.type if hasattr(self.device, 'type') else str(self.device) + + # Ensure all stream operations are complete + if device_type == "cuda" and hasattr(self._augmentation_stream, 'synchronize'): + self._augmentation_stream.synchronize() + elif device_type == "xpu" and hasattr(self._augmentation_stream, 'synchronize'): + self._augmentation_stream.synchronize() + def setup(self, stage: str) -> None: """Lightning hook called at the beginning of fit, validate, test, or predict stages. @@ -327,9 +354,14 @@ def setup(self, stage: str) -> None: stage: The current stage, either "fit", "validate", "test", or "predict". Note: + Initializes CUDA/XPU streams for asynchronous operations during training. When torch_compile is enabled and stage is "fit", compiles the model for optimized performance with appropriate logging level adjustments. """ + # Initialize streams for training + if self.enable_async_streams and stage == "fit": + self._initialize_streams() + if self.torch_compile and stage == "fit": # Set the log_level of this to error due to the numerous warning messages from compile. torch._logging.set_logs(dynamo=logging.ERROR) # noqa: SLF001 @@ -342,6 +374,100 @@ def setup(self, stage: str) -> None: stacklevel=1, ) + def _initialize_streams(self) -> None: + """Initialize CUDA or XPU streams for asynchronous operations. + + Creates separate streams for: + - Augmentation operations (GPU-based transforms) + - Model computation (forward/backward pass) + + This enables overlapping of augmentation and computation for better GPU utilization. + """ + device_type = self.device.type if hasattr(self.device, 'type') else str(self.device) + + try: + if device_type == "cuda" and torch.cuda.is_available(): + # Create CUDA streams + self._augmentation_stream = torch.cuda.Stream() + self._compute_stream = torch.cuda.default_stream() + self._augmentation_event = torch.cuda.Event() + logger.info("CUDA streams initialized for asynchronous augmentation and computation") + + elif device_type == "xpu" and hasattr(torch, 'xpu') and torch.xpu.is_available(): + # Create XPU streams + self._augmentation_stream = torch.xpu.Stream() + self._compute_stream = torch.xpu.default_stream() + # XPU events might have different API, adjust as needed + if hasattr(torch.xpu, 'Event'): + self._augmentation_event = torch.xpu.Event() + logger.info("XPU streams initialized for asynchronous augmentation and computation") + + else: + logger.info(f"Async streams not available for device type: {device_type}. Using default execution.") + self.enable_async_streams = False + + except Exception as e: + logger.warning(f"Failed to initialize async streams: {e}. Falling back to default execution.") + self.enable_async_streams = False + self._augmentation_stream = None + self._compute_stream = None + self._augmentation_event = None + + @contextmanager + def _augmentation_stream_context(self): + """Context manager for augmentation stream. + + Yields: + Stream context for GPU-based augmentations. 
+ + Example: + ```python + with self._augmentation_stream_context(): + augmented_batch = self.apply_batch_transforms(batch) + ``` + """ + if self._augmentation_stream is not None: + device_type = self.device.type if hasattr(self.device, 'type') else str(self.device) + + if device_type == "cuda": + with torch.cuda.stream(self._augmentation_stream): + yield + elif device_type == "xpu" and hasattr(torch.xpu, 'stream'): + with torch.xpu.stream(self._augmentation_stream): + yield + else: + yield + else: + yield + + def _wait_for_augmentation(self) -> None: + """Wait for augmentation stream to complete before model forward. + + Ensures that augmentation operations are finished before the main + computation stream proceeds with the forward pass. + """ + if self._compute_stream is not None and self._augmentation_stream is not None: + device_type = self.device.type if hasattr(self.device, 'type') else str(self.device) + + if device_type == "cuda": + # Record event in augmentation stream + if self._augmentation_event is not None: + self._augmentation_event.record(self._augmentation_stream) + # Make compute stream wait for augmentation + self._compute_stream.wait_event(self._augmentation_event) + else: + # Fallback: synchronize streams + self._compute_stream.wait_stream(self._augmentation_stream) + + elif device_type == "xpu" and hasattr(torch.xpu, 'current_stream'): + # XPU stream synchronization + if hasattr(self._compute_stream, 'wait_stream'): + self._compute_stream.wait_stream(self._augmentation_stream) + else: + # Fallback: synchronize augmentation stream + if hasattr(self._augmentation_stream, 'synchronize'): + self._augmentation_stream.synchronize() + def configure_optimizers(self) -> OptimizerLRScheduler: """Configure an optimizer and learning-rate schedulers. @@ -596,7 +722,33 @@ def forward( self, inputs: OTXDataBatch, ) -> OTXPredBatch | OTXBatchLossEntity: - """Model forward function.""" + """Model forward function with asynchronous stream support. + + When async streams are enabled, this method: + 1. Applies batch augmentations on a separate stream + 2. Waits for augmentation to complete + 3. Runs model forward on the main compute stream + + This overlaps augmentation with previous batch's backward pass for better GPU utilization. + + Args: + inputs: Batch of input data. + + Returns: + Model predictions or loss entity depending on training mode. + """ + # Apply batch augmentations asynchronously + if self.enable_async_streams and self.training and self._augmentation_stream is not None: + # Run augmentations on separate stream + with self._augmentation_stream_context(): + self.apply_batch_transforms(inputs) + + # Wait for augmentations to complete before forward pass + self._wait_for_augmentation() + else: + # Synchronous execution (validation/test or when streams disabled) + self.apply_batch_transforms(inputs) + # If customize_inputs is overridden if isinstance(inputs, OTXTileBatchDataEntity): return self.forward_tiles(inputs) @@ -637,6 +789,15 @@ def forward_tiles( """Model forward function for tile task.""" raise NotImplementedError + def apply_batch_transforms(self, inputs: OTXDataBatch) -> None: + """Apply kornia batch transforms to OTXDataBatch.""" + self.transforms(inputs.images) + + @property + def transforms(self) -> AugmentationSequential: + """Kornia Image Transforms.""" + return AugmentationSequential() + def register_load_state_dict_pre_hook(self, model_classes: list[str], ckpt_classes: list[str]) -> None: """Register load_state_dict_pre_hook. 
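Editor's note: the base.py hunks above wire `apply_batch_transforms` into `forward()` and expose an empty `AugmentationSequential` as the default `transforms` property. A minimal sketch of how a subclass is expected to plug in a Kornia pipeline; the class name and augmentation choices here are illustrative, not part of the patch:

```python
import kornia.augmentation as K
from kornia.augmentation.container import AugmentationSequential

from otx.backend.native.models.base import OTXModel


class MyOTXModel(OTXModel):  # hypothetical subclass, for illustration only
    @property
    def transforms(self) -> AugmentationSequential:
        # forward() calls self.apply_batch_transforms(inputs), which runs this
        # pipeline on the whole batch, on whatever device the batch lives on.
        return AugmentationSequential(
            K.RandomHorizontalFlip(p=0.5),
            K.Normalize(
                mean=self.data_input_params.mean,
                std=self.data_input_params.std,
            ),
            data_keys=["input"],
            same_on_batch=False,
        )
```

Because the pipeline is a property of the model rather than of the dataloader, the CPU workers only need to decode and stack images; all per-batch augmentation work moves onto the accelerator.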
diff --git a/library/src/otx/backend/native/models/detection/atss.py b/library/src/otx/backend/native/models/detection/atss.py
index 9f83b3e6a9a..6583c2167cd 100644
--- a/library/src/otx/backend/native/models/detection/atss.py
+++ b/library/src/otx/backend/native/models/detection/atss.py
@@ -6,6 +6,7 @@
 from __future__ import annotations

 from typing import TYPE_CHECKING, ClassVar, Literal
+import kornia as K

 from otx.backend.native.exporter.base import OTXModelExporter
 from otx.backend.native.exporter.native import OTXNativeModelExporter
@@ -204,3 +205,13 @@ def _exporter(self) -> OTXModelExporter:
     def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
         """Load the previous OTX ckpt according to OTX2.0."""
         return OTXv1Helper.load_det_ckpt(state_dict, add_prefix)
+
+    # @property
+    # def transforms(self):
+    #     return K.augmentation.AugmentationSequential(
+    #         K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
+    #         K.augmentation.RandomAffine(360, [0.1, 0.1], [0.7, 1.2], [30.0, 50.0], p=1.0),
+    #         K.augmentation.RandomPerspective(0.5, p=1.0),
+    #         data_keys=["images", "bbox"],
+    #         same_on_batch=False,
+    #     )
diff --git a/library/src/otx/backend/native/models/detection/base.py b/library/src/otx/backend/native/models/detection/base.py
index b23b08886a6..7bbfdc3a341 100644
--- a/library/src/otx/backend/native/models/detection/base.py
+++ b/library/src/otx/backend/native/models/detection/base.py
@@ -15,6 +15,8 @@
 import torch
 from torchmetrics import Metric, MetricCollection
 from torchvision import tv_tensors
+import kornia
+from kornia.geometry.boxes import Boxes

 from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel
 from otx.backend.native.models.utils.utils import InstanceData
@@ -175,7 +177,6 @@ def _customize_inputs(

         inputs["entity"] = entity
         inputs["mode"] = "loss" if self.training else "predict"
-
         return inputs

     def _customize_outputs(
@@ -252,6 +253,28 @@
             labels=labels,
         )

+    def apply_batch_transforms(self, inputs: OTXDataBatch) -> None:
+        """Apply batch augmentations for object detection."""
+        # Convert bounding boxes to Kornia Boxes (corner format, [N, 4, 2])
+        kornia_boxes = Boxes.from_tensor(inputs.bboxes, mode="xyxy")
+        # AugmentationSequential returns new tensors rather than mutating its
+        # arguments, so both outputs must be written back onto the batch.
+        inputs.images, kornia_boxes = self.transforms(inputs.images, kornia_boxes)
+        inputs.bboxes = kornia_boxes.to_tensor(mode="xyxy")
+
+    @property
+    def transforms(self) -> kornia.augmentation.AugmentationSequential:
+        """Kornia batch transforms: flip plus normalize in training, normalize only otherwise."""
+        if self.training:
+            return kornia.augmentation.AugmentationSequential(
+                kornia.augmentation.RandomHorizontalFlip(),
+                kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std),
+                data_keys=["input", "bbox"],
+                same_on_batch=False,
+            )
+        return kornia.augmentation.AugmentationSequential(
+            kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std),
+            data_keys=["input", "bbox"],
+        )
+
     def forward_tiles(self, inputs: OTXTileBatchDataEntity) -> OTXPredBatch:
         """Unpack detection tiles.
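Editor's note: the detection override above round-trips boxes through `kornia.geometry.boxes.Boxes` so that geometric augmentations move images and their boxes together. A standalone sketch of that round-trip, assuming a recent Kornia where `AugmentationSequential` accepts a `Boxes` object for the `"bbox"` key; shapes and values are illustrative:

```python
import torch
import kornia.augmentation as K
from kornia.augmentation.container import AugmentationSequential
from kornia.geometry.boxes import Boxes

images = torch.rand(2, 3, 256, 256)  # [B, C, H, W]
xyxy = torch.tensor(
    [[[10.0, 10.0, 60.0, 80.0]], [[5.0, 5.0, 40.0, 40.0]]]
)  # [B, N, 4] in xyxy format

boxes = Boxes.from_tensor(xyxy, mode="xyxy")  # internally [B, N, 4, 2] corners

aug = AugmentationSequential(
    K.RandomHorizontalFlip(p=1.0),
    data_keys=["input", "bbox"],
)

# The container returns new tensors; nothing is modified in place, so both
# outputs must be assigned back onto the batch entity, as done in the hunk above.
out_images, out_boxes = aug(images, boxes)
flipped_xyxy = out_boxes.to_tensor(mode="xyxy")  # back to [B, N, 4]
```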
diff --git a/library/src/otx/backend/native/models/detection/d_fine.py b/library/src/otx/backend/native/models/detection/d_fine.py index 2c01388f798..5a0f9338550 100644 --- a/library/src/otx/backend/native/models/detection/d_fine.py +++ b/library/src/otx/backend/native/models/detection/d_fine.py @@ -17,6 +17,7 @@ from otx.backend.native.models.utils.utils import load_checkpoint from otx.config.data import TileConfig from otx.metrics.fmeasure import MeanAveragePrecisionFMeasureCallable +import kornia if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -172,3 +173,17 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None: ckpt.pop("model.decoder.anchors") ckpt.pop("model.decoder.valid_mask") return super().load_state_dict(ckpt, *args, strict=False, **kwargs) + + @property + def transforms(self): + if self.training: + return kornia.augmentation.AugmentationSequential( + kornia.augmentation.RandomHorizontalFlip(), + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), + data_keys=["input", "bbox"], + same_on_batch=False + ) + return kornia.augmentation.AugmentationSequential( + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), + data_keys=["input", "bbox"], + ) diff --git a/library/src/otx/backend/native/models/detection/yolox.py b/library/src/otx/backend/native/models/detection/yolox.py index e9a2a0ca745..cb926d97beb 100644 --- a/library/src/otx/backend/native/models/detection/yolox.py +++ b/library/src/otx/backend/native/models/detection/yolox.py @@ -25,6 +25,7 @@ from otx.metrics.fmeasure import MeanAveragePrecisionFMeasureCallable from otx.types.export import OTXExportFormatType from otx.types.precision import OTXPrecisionType +import kornia if TYPE_CHECKING: from pathlib import Path @@ -197,3 +198,21 @@ def export( def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" return OTXv1Helper.load_det_ckpt(state_dict, add_prefix) + + @property + def transforms(self): + if self.training: + return kornia.augmentation.AugmentationSequential( + # kornia.augmentation.RandomMixUpV2(data_keys=["input", "bbox"], p=0.5, lambda_val=[0.5, 0.5]), + kornia.augmentation.RandomMosaic(self.data_input_params.input_size, data_keys=["input", "bbox"]), + kornia.augmentation.RandomAffine(degrees=10.0, translate=[0.1, 0.1], scale=[0.5,1.5], shear=2.0), + kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1), + kornia.augmentation.RandomHorizontalFlip(), + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), + data_keys=["input", "bbox"], + same_on_batch=False + ) + return kornia.augmentation.AugmentationSequential( + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), + data_keys=["input", "bbox"], + ) diff --git a/library/src/otx/data/entity/torch/torch.py b/library/src/otx/data/entity/torch/torch.py index ff85c15c514..2d8a2c6cf0c 100644 --- a/library/src/otx/data/entity/torch/torch.py +++ b/library/src/otx/data/entity/torch/torch.py @@ -62,12 +62,7 @@ def collate_fn(items: list[OTXDataItem]) -> OTXDataBatch: Returns: Batched TorchDataItems with stacked tensors """ - # Check if all images have the same size. TODO(kprokofi): remove this check once OV IR models are moved. 
-    if all(item.image.shape == items[0].image.shape for item in items):
-        images = torch.stack([item.image for item in items])
-    else:
-        # we need this only in case of OV inference, where no resize
-        images = [item.image for item in items]
+    images = torch.stack([item.image for item in items])

     return OTXDataBatch(
         batch_size=len(items),
diff --git a/library/src/otx/data/entity/torch/validations.py b/library/src/otx/data/entity/torch/validations.py
index 9b9d89cb508..47c0b1c5e6f 100644
--- a/library/src/otx/data/entity/torch/validations.py
+++ b/library/src/otx/data/entity/torch/validations.py
@@ -9,6 +9,7 @@

 import numpy as np
 import torch
+from kornia.geometry.boxes import Boxes
 from datumaro import Polygon
 from torchvision.tv_tensors import BoundingBoxes, Mask
diff --git a/library/src/otx/recipe/_base_/data/detection.yaml b/library/src/otx/recipe/_base_/data/detection.yaml
index 503a22bd5b4..da59453a832 100644
--- a/library/src/otx/recipe/_base_/data/detection.yaml
+++ b/library/src/otx/recipe/_base_/data/detection.yaml
@@ -9,69 +9,21 @@ train_subset:
   subset_name: train
   transform_lib_type: TORCHVISION
   batch_size: 1
-  num_workers: 2
-  to_tv_image: false
+  num_workers: 4
+  to_tv_image: true
   transforms:
-    - class_path: otx.data.transform_libs.torchvision.MinIoURandomCrop
-      enable: true
-    - class_path: otx.data.transform_libs.torchvision.Resize
-      init_args:
-        scale: $(input_size)
-        transform_bbox: true
-    - class_path: torchvision.transforms.v2.RandomPhotometricDistort
-      enable: false
+    - class_path: torchvision.transforms.v2.RandomIoUCrop
+    - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes
       init_args:
-        brightness:
-          - 0.875
-          - 1.125
-        contrast:
-          - 0.5
-          - 1.5
-        saturation:
-          - 0.5
-          - 1.5
-        hue:
-          - -0.05
-          - 0.05
-        p: 0.5
-    - class_path: otx.data.transform_libs.torchvision.RandomAffine
-      enable: false
+        min_size: 1
+    - class_path: torchvision.transforms.v2.Resize
       init_args:
-        max_rotate_degree: 10.0
-        max_translate_ratio: 0.1
-        scaling_ratio_range:
-          - 0.5
-          - 1.5
-        max_shear_degree: 2.0
-    - class_path: otx.data.transform_libs.torchvision.RandomFlip
-      enable: true
-      init_args:
-        probability: 0.5
-    - class_path: torchvision.transforms.v2.RandomVerticalFlip
-      enable: false
-      init_args:
-        p: 0.5
-    - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
-      enable: false
-      init_args:
-        kernel_size: 5
-        sigma:
-          - 0.1
-          - 2.0
-        probability: 0.5
+        size: $(input_size)
     - class_path: torchvision.transforms.v2.ToDtype
       init_args:
         dtype: ${as_torch_dtype:torch.float32}
-    - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
-      enable: false
-      init_args:
-        mean: 0.0
-        sigma: 0.1
-        probability: 0.5
-    - class_path: torchvision.transforms.v2.Normalize
-      init_args:
-        mean: [0.0, 0.0, 0.0]
-        std: [255.0, 255.0, 255.0]
+        scale: true
+
   sampler:
     class_path: torch.utils.data.RandomSampler
     init_args: null
 val_subset:
   subset_name: val
   transform_lib_type: TORCHVISION
   batch_size: 1
-  num_workers: 2
+  num_workers: 4
   to_tv_image: false
   transforms:
-    - class_path: otx.data.transform_libs.torchvision.Resize
+    - class_path: torchvision.transforms.v2.Resize
       init_args:
-        scale: $(input_size)
+        size: $(input_size)
     - class_path: torchvision.transforms.v2.ToDtype
init_args: dtype: ${as_torch_dtype:torch.float32} @@ -101,7 +53,7 @@ test_subset: subset_name: test transform_lib_type: TORCHVISION batch_size: 1 - num_workers: 2 + num_workers: 4 to_tv_image: false transforms: - class_path: otx.data.transform_libs.torchvision.Resize diff --git a/library/src/otx/recipe/detection/atss_mobilenetv2.yaml b/library/src/otx/recipe/detection/atss_mobilenetv2.yaml index b24d4acb32e..272b8035783 100644 --- a/library/src/otx/recipe/detection/atss_mobilenetv2.yaml +++ b/library/src/otx/recipe/detection/atss_mobilenetv2.yaml @@ -4,6 +4,10 @@ model: init_args: model_name: atss_mobilenetv2 label_info: 80 + data_input_params: + input_size: [800, 992] + mean: [0.0, 0.0, 0.0] + std: [255.0, 255.0, 255.0] optimizer: class_path: torch.optim.SGD diff --git a/library/src/otx/recipe/detection/yolox_x.yaml b/library/src/otx/recipe/detection/yolox_x.yaml index 30206207646..1594d572653 100644 --- a/library/src/otx/recipe/detection/yolox_x.yaml +++ b/library/src/otx/recipe/detection/yolox_x.yaml @@ -69,75 +69,74 @@ overrides: - class_path: otx.data.transform_libs.torchvision.Resize init_args: scale: $(input_size) - keep_ratio: true - transform_bbox: true - - class_path: otx.data.transform_libs.torchvision.CachedMosaic - init_args: - random_pop: false - max_cached_images: 20 - img_scale: $(input_size) # (H, W) - - class_path: otx.data.transform_libs.torchvision.RandomAffine - init_args: - border: $(input_size) * -0.5 - - class_path: otx.data.transform_libs.torchvision.CachedMixUp - init_args: - img_scale: $(input_size) # (H, W) - ratio_range: - - 1.0 - - 1.0 - probability: 0.5 - random_pop: false - max_cached_images: 10 - - class_path: torchvision.transforms.v2.RandomPhotometricDistort - enable: false - init_args: - brightness: - - 0.875 - - 1.125 - contrast: - - 0.5 - - 1.5 - saturation: - - 0.5 - - 1.5 - hue: - - -0.05 - - 0.05 - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.YOLOXHSVRandomAug - - class_path: otx.data.transform_libs.torchvision.RandomFlip - init_args: - probability: 0.5 - is_numpy_to_tvtensor: false - - class_path: otx.data.transform_libs.torchvision.Pad - init_args: - pad_to_square: true - pad_val: 114 - - class_path: torchvision.transforms.v2.RandomVerticalFlip - enable: false - init_args: - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur - enable: false - init_args: - kernel_size: 5 - sigma: - - 0.1 - - 2.0 - probability: 0.5 + # - class_path: otx.data.transform_libs.torchvision.CachedMosaic + # init_args: + # random_pop: false + # max_cached_images: 20 + # img_scale: $(input_size) # (H, W) + # - class_path: otx.data.transform_libs.torchvision.RandomAffine + # init_args: + # border: $(input_size) * -0.5 + # - class_path: otx.data.transform_libs.torchvision.CachedMixUp + # init_args: + # img_scale: $(input_size) # (H, W) + # ratio_range: + # - 1.0 + # - 1.0 + # probability: 0.5 + # random_pop: false + # max_cached_images: 10 + # - class_path: torchvision.transforms.v2.RandomPhotometricDistort + # enable: false + # init_args: + # brightness: + # - 0.875 + # - 1.125 + # contrast: + # - 0.5 + # - 1.5 + # saturation: + # - 0.5 + # - 1.5 + # hue: + # - -0.05 + # - 0.05 + # p: 0.5 + # - class_path: otx.data.transform_libs.torchvision.YOLOXHSVRandomAug + # - class_path: otx.data.transform_libs.torchvision.RandomFlip + # init_args: + # probability: 0.5 + # is_numpy_to_tvtensor: false + # - class_path: otx.data.transform_libs.torchvision.Pad + # init_args: + # pad_to_square: true + # pad_val: 114 + # - class_path: 
torchvision.transforms.v2.RandomVerticalFlip + # enable: false + # init_args: + # p: 0.5 + # - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur + # enable: false + # init_args: + # kernel_size: 5 + # sigma: + # - 0.1 + # - 2.0 + # probability: 0.5 - class_path: torchvision.transforms.v2.ToDtype init_args: dtype: ${as_torch_dtype:torch.float32} - - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise - enable: false - init_args: - mean: 0.0 - sigma: 0.1 - probability: 0.5 - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [0.0, 0.0, 0.0] - std: [1.0, 1.0, 1.0] + scale: true + # - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise + # enable: false + # init_args: + # mean: 0.0 + # sigma: 0.1 + # probability: 0.5 + # - class_path: torchvision.transforms.v2.Normalize + # init_args: + # mean: [0.0, 0.0, 0.0] + # std: [1.0, 1.0, 1.0] sampler: class_path: otx.data.samplers.balanced_sampler.BalancedSampler From 0c44aef9d640c2b627e50fc356e51b573b55fd61 Mon Sep 17 00:00:00 2001 From: kprokofi Date: Wed, 12 Nov 2025 23:59:39 +0900 Subject: [PATCH 2/6] minor change --- library/src/otx/backend/native/models/base.py | 1 + library/src/otx/backend/native/models/detection/base.py | 1 + library/src/otx/data/module.py | 1 + 3 files changed, 3 insertions(+) diff --git a/library/src/otx/backend/native/models/base.py b/library/src/otx/backend/native/models/base.py index 9a25f5b3e28..4cf6d039687 100644 --- a/library/src/otx/backend/native/models/base.py +++ b/library/src/otx/backend/native/models/base.py @@ -789,6 +789,7 @@ def forward_tiles( """Model forward function for tile task.""" raise NotImplementedError + @torch.no_grad() def apply_batch_transforms(self, inputs: OTXDataBatch) -> None: """Apply kornia batch transforms to OTXDataBatch.""" self.transforms(inputs.images) diff --git a/library/src/otx/backend/native/models/detection/base.py b/library/src/otx/backend/native/models/detection/base.py index 7bbfdc3a341..7cbe5f7df89 100644 --- a/library/src/otx/backend/native/models/detection/base.py +++ b/library/src/otx/backend/native/models/detection/base.py @@ -253,6 +253,7 @@ def _customize_outputs( labels=labels, ) + @torch.no_grad() def apply_batch_transforms(self, inputs: OTXDataBatch) -> types.NoneType: """Apply batch augmentations to Object Detection.""" # Convert bounding boxes to Kornia Boxes [N, 4, 2] diff --git a/library/src/otx/data/module.py b/library/src/otx/data/module.py index 419fa51d058..be6d0d13dcc 100644 --- a/library/src/otx/data/module.py +++ b/library/src/otx/data/module.py @@ -445,6 +445,7 @@ def train_dataloader(self) -> DataLoader: "persistent_workers": config.num_workers > 0, "sampler": sampler, "shuffle": sampler is None, + "prefetch_factor": 2, } tile_config = self.tile_config From 8d3460177ad3d601f395a17455cbbe015451bf44 Mon Sep 17 00:00:00 2001 From: kprokofi Date: Thu, 13 Nov 2025 19:06:02 +0900 Subject: [PATCH 3/6] update YOLOX interface --- .../multiclass_models/efficientnet.py | 19 +++ .../recipe/_base_/data/classification.yaml | 102 +++++++-------- .../multi_class_cls/efficientnet_b0.yaml | 119 +++++++++--------- 3 files changed, 130 insertions(+), 110 deletions(-) diff --git a/library/src/otx/backend/native/models/classification/multiclass_models/efficientnet.py b/library/src/otx/backend/native/models/classification/multiclass_models/efficientnet.py index 0d68a8e18ef..02ef2d88b19 100644 --- a/library/src/otx/backend/native/models/classification/multiclass_models/efficientnet.py +++ 
b/library/src/otx/backend/native/models/classification/multiclass_models/efficientnet.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Literal from torch import Tensor, nn +import kornia from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable from otx.backend.native.models.classification.backbones.efficientnet import EfficientNetBackbone @@ -101,3 +102,21 @@ def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: return self.model(images=image, mode="explain") return self.model(images=image, mode="tensor") + + @property + def transforms(self): + if self.training: + return kornia.augmentation.AugmentationSequential( + # kornia.augmentation.RandomResizedCrop(self.data_input_params.input_size, scale=(0.08, 1.0)), + kornia.augmentation.RandomAffine(degrees=10.0, translate=[0.1, 0.1], scale=[0.5,1.5], shear=2.0), + kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1), + kornia.augmentation.RandomHorizontalFlip(), + kornia.augmentation.RandomGaussianBlur(5, (0.1, 2.0)), + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), + data_keys=["input"], + same_on_batch=False + ) + return kornia.augmentation.AugmentationSequential( + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), + data_keys=["input"], + ) diff --git a/library/src/otx/recipe/_base_/data/classification.yaml b/library/src/otx/recipe/_base_/data/classification.yaml index 8bd2149f66d..5017fa97868 100644 --- a/library/src/otx/recipe/_base_/data/classification.yaml +++ b/library/src/otx/recipe/_base_/data/classification.yaml @@ -21,61 +21,61 @@ train_subset: aspect_ratio_range: - 0.75 - 1.34 - - class_path: torchvision.transforms.v2.RandomPhotometricDistort - enable: false - init_args: - brightness: - - 0.875 - - 1.125 - contrast: - - 0.5 - - 1.5 - saturation: - - 0.5 - - 1.5 - hue: - - -0.05 - - 0.05 - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomAffine - enable: false - init_args: - max_rotate_degree: 10.0 - max_translate_ratio: 0.1 - scaling_ratio_range: - - 0.5 - - 1.5 - max_shear_degree: 2.0 - - class_path: otx.data.transform_libs.torchvision.RandomFlip - enable: true - init_args: - probability: 0.5 - - class_path: torchvision.transforms.v2.RandomVerticalFlip - enable: false - init_args: - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur - enable: false - init_args: - kernel_size: 5 - sigma: - - 0.1 - - 2.0 - probability: 0.5 + # - class_path: torchvision.transforms.v2.RandomPhotometricDistort + # enable: false + # init_args: + # brightness: + # - 0.875 + # - 1.125 + # contrast: + # - 0.5 + # - 1.5 + # saturation: + # - 0.5 + # - 1.5 + # hue: + # - -0.05 + # - 0.05 + # p: 0.5 + # - class_path: otx.data.transform_libs.torchvision.RandomAffine + # enable: false + # init_args: + # max_rotate_degree: 10.0 + # max_translate_ratio: 0.1 + # scaling_ratio_range: + # - 0.5 + # - 1.5 + # max_shear_degree: 2.0 + # - class_path: otx.data.transform_libs.torchvision.RandomFlip + # enable: true + # init_args: + # probability: 0.5 + # - class_path: torchvision.transforms.v2.RandomVerticalFlip + # enable: false + # init_args: + # p: 0.5 + # - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur + # enable: false + # init_args: + # kernel_size: 5 + # sigma: + # - 0.1 + # - 2.0 + # probability: 0.5 - class_path: torchvision.transforms.v2.ToDtype init_args: dtype: ${as_torch_dtype:torch.float32} scale: false - - class_path: 
otx.data.transform_libs.torchvision.RandomGaussianNoise - enable: false - init_args: - mean: 0.0 - sigma: 0.1 - probability: 0.5 - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + # - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise + # enable: false + # init_args: + # mean: 0.0 + # sigma: 0.1 + # probability: 0.5 + # - class_path: torchvision.transforms.v2.Normalize + # init_args: + # mean: [123.675, 116.28, 103.53] + # std: [58.395, 57.12, 57.375] sampler: class_path: otx.data.samplers.balanced_sampler.BalancedSampler init_args: null diff --git a/library/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml b/library/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml index 893ee13fd82..d70a22cfce1 100644 --- a/library/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml +++ b/library/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml @@ -52,67 +52,68 @@ overrides: data: train_subset: + to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.EfficientNetRandomCrop + - class_path: torchvision.transforms.v2.RandomResizedCrop init_args: - scale: $(input_size) - crop_ratio_range: - - 0.08 - - 1.0 - aspect_ratio_range: - - 0.75 - - 1.34 - - class_path: torchvision.transforms.v2.RandomPhotometricDistort - enable: false - init_args: - brightness: - - 0.875 - - 1.125 - contrast: - - 0.5 - - 1.5 - saturation: - - 0.5 - - 1.5 - hue: - - -0.05 - - 0.05 - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomAffine - enable: false - init_args: - max_rotate_degree: 10.0 - max_translate_ratio: 0.1 - scaling_ratio_range: - - 0.5 - - 1.5 - max_shear_degree: 2.0 - - class_path: otx.data.transform_libs.torchvision.RandomFlip - enable: true - init_args: - probability: 0.5 - - class_path: torchvision.transforms.v2.RandomVerticalFlip - enable: false - init_args: - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur - enable: false - init_args: - kernel_size: 5 - sigma: - - 0.1 - - 2.0 + size: $(input_size) + # crop_ratio_range: + # - 0.08 + # - 1.0 + # aspect_ratio_range: + # - 0.75 + # - 1.34 + # - class_path: torchvision.transforms.v2.RandomPhotometricDistort + # enable: false + # init_args: + # brightness: + # - 0.875 + # - 1.125 + # contrast: + # - 0.5 + # - 1.5 + # saturation: + # - 0.5 + # - 1.5 + # hue: + # - -0.05 + # - 0.05 + # p: 0.5 + # - class_path: otx.data.transform_libs.torchvision.RandomAffine + # enable: false + # init_args: + # max_rotate_degree: 10.0 + # max_translate_ratio: 0.1 + # scaling_ratio_range: + # - 0.5 + # - 1.5 + # max_shear_degree: 2.0 + # - class_path: otx.data.transform_libs.torchvision.RandomFlip + # enable: true + # init_args: + # probability: 0.5 + # - class_path: torchvision.transforms.v2.RandomVerticalFlip + # enable: false + # init_args: + # p: 0.5 + # - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur + # enable: false + # init_args: + # kernel_size: 5 + # sigma: + # - 0.1 + # - 2.0 - class_path: torchvision.transforms.v2.ToDtype init_args: dtype: ${as_torch_dtype:torch.float32} - scale: false - - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise - enable: false - init_args: - mean: 0.0 - sigma: 0.1 - probability: 0.5 - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + scale: true + # - class_path: 
otx.data.transform_libs.torchvision.RandomGaussianNoise + # enable: false + # init_args: + # mean: 0.0 + # sigma: 0.1 + # probability: 0.5 + # - class_path: torchvision.transforms.v2.Normalize + # init_args: + # mean: [123.675, 116.28, 103.53] + # std: [58.395, 57.12, 57.375] From 185128b53df7a411e06f298410f4975de7d5e172 Mon Sep 17 00:00:00 2001 From: kprokofi Date: Tue, 18 Nov 2025 22:35:24 +0900 Subject: [PATCH 4/6] update augmentations, delete cuda streams --- library/src/otx/backend/native/models/base.py | 246 +++++++----------- .../classification/multiclass_models/base.py | 13 +- 2 files changed, 103 insertions(+), 156 deletions(-) diff --git a/library/src/otx/backend/native/models/base.py b/library/src/otx/backend/native/models/base.py index 4cf6d039687..cc33adf8c1f 100644 --- a/library/src/otx/backend/native/models/base.py +++ b/library/src/otx/backend/native/models/base.py @@ -9,10 +9,8 @@ import inspect import logging -import threading import warnings from abc import abstractmethod -from contextlib import contextmanager from dataclasses import asdict, dataclass from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence @@ -24,6 +22,7 @@ from torch.optim.sgd import SGD from torchmetrics import Metric, MetricCollection from kornia.augmentation.container import AugmentationSequential +from torchvision.transforms.v2 import Compose from otx import __version__ from otx.backend.native.optimizers.callable import OptimizerCallableSupportAdaptiveBS @@ -122,14 +121,15 @@ def __init__( self, label_info: LabelInfoTypes | int | Sequence, data_input_params: DataInputParams | dict, - task: OTXTaskType | None = None, model_name: str = "OTXModel", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = NullMetricCallable, torch_compile: bool = False, tile_config: TileConfig | dict = TileConfig(enable_tiler=False), - enable_async_streams: bool = False, ) -> None: """Initialize the base model with the given parameters. @@ -139,17 +139,19 @@ def __init__( if `Sequence` is given, label info will be constructed from the sequence of label names. data_input_params (DataInputParams | dict): Parameters of the input data such as input size, mean, and std. model_name (str, optional): Name of the model. Defaults to "OTXModel". + apply_gpu_transforms (bool, optional): Flag to indicate whether to apply GPU transforms. + It is recommended to use GPU transforms. Defaults to True. + batch_train_transforms (AugmentationSequential | Compose | None): GPU transforms for training applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. + batch_val_transforms (AugmentationSequential | Compose | None): GPU transforms for validation / testing applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. Typically just normalization. optimizer (OptimizerCallable, optional): Callable for the optimizer. Defaults to DefaultOptimizerCallable. scheduler (LRSchedulerCallable | LRSchedulerListCallable): Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable. metric (MetricCallable, optional): Callable for the metric. Defaults to NullMetricCallable. 
torch_compile (bool, optional): Flag to indicate if torch.compile should be used. Defaults to False. tile_config (TileConfig, optional): Configuration for tiling. Defaults to TileConfig(enable_tiler=False). - enable_async_streams (bool, optional): Enable asynchronous CUDA/XPU streams for concurrent operations. - Defaults to True. - Returns: - None """ super().__init__() @@ -163,7 +165,6 @@ def __init__( self.optimizer_callable = ensure_callable(optimizer) self.scheduler_callable = ensure_callable(scheduler) self.metric_callable = ensure_callable(metric) - self._task = task self.torch_compile = torch_compile self._explain_mode = False @@ -172,12 +173,9 @@ def __init__( tile_config = TileConfig(**tile_config) self._tile_config = tile_config.clone() - # Stream support for concurrent operations - self.enable_async_streams = enable_async_streams - self._augmentation_stream: torch.cuda.Stream | Any | None = None - self._compute_stream: torch.cuda.Stream | Any | None = None - self._stream_lock = threading.Lock() - self._augmentation_event: torch.cuda.Event | Any | None = None + # Augmentation configuration + if apply_gpu_transforms: + self.batch_train_transforms, self.batch_val_transforms = self._configure_batch_augmentation(batch_train_transforms, batch_val_transforms) self.save_hyperparameters( logger=False, @@ -186,6 +184,7 @@ def __init__( def training_step(self, batch: OTXDataBatch, batch_idx: int) -> Tensor: """Step for model training.""" + self._apply_batch_augmentations(self.batch_train_transforms, batch) train_loss = self.forward(inputs=batch) if train_loss is None: msg = "Loss is None." @@ -240,6 +239,7 @@ def validation_step(self, batch: OTXDataBatch, batch_idx: int) -> OTXPredBatch: Updates test metrics based on the prediction results and batch data. Handles both single dictionary and list of dictionaries for metric inputs. """ + self._apply_batch_augmentations(self.batch_val_transforms, batch) preds = self.forward(inputs=batch) if isinstance(preds, OTXBatchLossEntity): @@ -271,6 +271,7 @@ def test_step(self, batch: OTXDataBatch, batch_idx: int) -> OTXPredBatch: When torch_compile is enabled and stage is "fit", compiles the model for optimized performance with appropriate logging level adjustments. """ + self._apply_batch_augmentations(self.batch_val_transforms, batch) preds = self.forward(inputs=batch) if isinstance(preds, OTXBatchLossEntity): @@ -330,20 +331,6 @@ def on_test_epoch_end(self) -> None: """Callback triggered when the test epoch ends.""" self._log_metrics(self.metric, "test") - def on_train_epoch_end(self) -> None: - """Callback triggered when the training epoch ends. - - Synchronizes streams to ensure all async operations are complete. - """ - if self.enable_async_streams and self._augmentation_stream is not None: - device_type = self.device.type if hasattr(self.device, 'type') else str(self.device) - - # Ensure all stream operations are complete - if device_type == "cuda" and hasattr(self._augmentation_stream, 'synchronize'): - self._augmentation_stream.synchronize() - elif device_type == "xpu" and hasattr(self._augmentation_stream, 'synchronize'): - self._augmentation_stream.synchronize() - def setup(self, stage: str) -> None: """Lightning hook called at the beginning of fit, validate, test, or predict stages. @@ -354,14 +341,9 @@ def setup(self, stage: str) -> None: stage: The current stage, either "fit", "validate", "test", or "predict". Note: - Initializes CUDA/XPU streams for asynchronous operations during training. 
When torch_compile is enabled and stage is "fit", compiles the model for optimized performance with appropriate logging level adjustments. """ - # Initialize streams for training - if self.enable_async_streams and stage == "fit": - self._initialize_streams() - if self.torch_compile and stage == "fit": # Set the log_level of this to error due to the numerous warning messages from compile. torch._logging.set_logs(dynamo=logging.ERROR) # noqa: SLF001 @@ -374,100 +356,6 @@ def setup(self, stage: str) -> None: stacklevel=1, ) - def _initialize_streams(self) -> None: - """Initialize CUDA or XPU streams for asynchronous operations. - - Creates separate streams for: - - Augmentation operations (GPU-based transforms) - - Model computation (forward/backward pass) - - This enables overlapping of augmentation and computation for better GPU utilization. - """ - device_type = self.device.type if hasattr(self.device, 'type') else str(self.device) - - try: - if device_type == "cuda" and torch.cuda.is_available(): - # Create CUDA streams - self._augmentation_stream = torch.cuda.Stream() - self._compute_stream = torch.cuda.default_stream() - self._augmentation_event = torch.cuda.Event() - logger.info("CUDA streams initialized for asynchronous augmentation and computation") - - elif device_type == "xpu" and hasattr(torch, 'xpu') and torch.xpu.is_available(): - # Create XPU streams - self._augmentation_stream = torch.xpu.Stream() - self._compute_stream = torch.xpu.default_stream() - # XPU events might have different API, adjust as needed - if hasattr(torch.xpu, 'Event'): - self._augmentation_event = torch.xpu.Event() - logger.info("XPU streams initialized for asynchronous augmentation and computation") - - else: - logger.info(f"Async streams not available for device type: {device_type}. Using default execution.") - self.enable_async_streams = False - - except Exception as e: - logger.warning(f"Failed to initialize async streams: {e}. Falling back to default execution.") - self.enable_async_streams = False - self._augmentation_stream = None - self._compute_stream = None - self._augmentation_event = None - - @contextmanager - def _augmentation_stream_context(self): - """Context manager for augmentation stream. - - Yields: - Stream context for GPU-based augmentations. - - Example: - ```python - with self._augmentation_stream_context(): - augmented_batch = self.apply_batch_transforms(batch) - ``` - """ - if self._augmentation_stream is not None: - device_type = self.device.type if hasattr(self.device, 'type') else str(self.device) - - if device_type == "cuda": - with torch.cuda.stream(self._augmentation_stream): - yield - elif device_type == "xpu" and hasattr(torch.xpu, 'stream'): - with torch.xpu.stream(self._augmentation_stream): - yield - else: - yield - else: - yield - - def _wait_for_augmentation(self) -> None: - """Wait for augmentation stream to complete before model forward. - - Ensures that augmentation operations are finished before the main - computation stream proceeds with the forward pass. 
- """ - if self._compute_stream is not None and self._augmentation_stream is not None: - device_type = self.device.type if hasattr(self.device, 'type') else str(self.device) - - if device_type == "cuda": - # Record event in augmentation stream - if self._augmentation_event is not None: - self._augmentation_event.record(self._augmentation_stream) - # Make compute stream wait for augmentation - self._compute_stream.wait_event(self._augmentation_event) - else: - # Fallback: synchronize streams - self._compute_stream.wait_stream(self._augmentation_stream) - - elif device_type == "xpu" and hasattr(torch.xpu, 'current_stream'): - # XPU stream synchronization - if hasattr(self._compute_stream, 'wait_stream'): - self._compute_stream.wait_stream(self._augmentation_stream) - else: - # Fallback: synchronize augmentation stream - if hasattr(self._augmentation_stream, 'synchronize'): - self._augmentation_stream.synchronize() - def configure_optimizers(self) -> OptimizerLRScheduler: """Configure an optimizer and learning-rate schedulers. @@ -510,6 +398,53 @@ def configure_metric(self) -> None: self._metric = metric.to(self.device) + def _configure_batch_augmentation( + self, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, + ) -> tuple[AugmentationSequential, AugmentationSequential]: + """Configure batch augmentation. + + Args: + batch_train_transforms (AugmentationSequential, optional): batch train transforms + batch_val_transforms (AugmentationSequential, optional): batch val transforms + + Returns: + GPU augmentation pipeline or None + """ + if batch_train_transforms is not None: + if not ensure_callable(batch_train_transforms): + msg = "Batch train transforms should be callable. " \ + "Please use kornia AugmentationSequential or torchvision Compose" + + raise TypeError(msg) + + if batch_val_transforms is not None: + if not ensure_callable(batch_val_transforms): + msg = "Batch val transforms should be callable. " \ + "Please use kornia AugmentationSequential or torchvision Compose" + + raise TypeError(msg) + + train_aug_pipeline = batch_train_transforms if batch_train_transforms is not None else self._default_train_transforms + val_aug_pipeline = batch_val_transforms if batch_val_transforms is not None else self._default_val_transforms + + return train_aug_pipeline, val_aug_pipeline + + @torch.no_grad() + @staticmethod + def _apply_batch_augmentations(augmentations_pipeline: AugmentationSequential | Compose | None, batch: OTXDataBatch) -> None: + if augmentations_pipeline is not None: + batch.images = augmentations_pipeline(batch.images) + + @property + def _default_train_transforms(self): + return AugmentationSequential() + + @property + def _default_val_transforms(self): + return AugmentationSequential() + @property def metric(self) -> Metric | MetricCollection: """Metric module for this OTX model.""" @@ -722,14 +657,7 @@ def forward( self, inputs: OTXDataBatch, ) -> OTXPredBatch | OTXBatchLossEntity: - """Model forward function with asynchronous stream support. - - When async streams are enabled, this method: - 1. Applies batch augmentations on a separate stream - 2. Waits for augmentation to complete - 3. Runs model forward on the main compute stream - - This overlaps augmentation with previous batch's backward pass for better GPU utilization. + """Model forward function. Args: inputs: Batch of input data. 
@@ -737,17 +665,8 @@ def forward( Returns: Model predictions or loss entity depending on training mode. """ - # Apply batch augmentations asynchronously - if self.enable_async_streams and self.training and self._augmentation_stream is not None: - # Run augmentations on separate stream - with self._augmentation_stream_context(): - self.apply_batch_transforms(inputs) - - # Wait for augmentations to complete before forward pass - self._wait_for_augmentation() - else: - # Synchronous execution (validation/test or when streams disabled) - self.apply_batch_transforms(inputs) + # Apply batch augmentations + self.apply_batch_transforms(inputs) # If customize_inputs is overridden if isinstance(inputs, OTXTileBatchDataEntity): @@ -791,12 +710,32 @@ def forward_tiles( @torch.no_grad() def apply_batch_transforms(self, inputs: OTXDataBatch) -> None: - """Apply kornia batch transforms to OTXDataBatch.""" - self.transforms(inputs.images) + """Apply GPU batch transforms to OTXDataBatch. + + Args: + inputs: Batch data to transform + """ + # Use configured GPU pipeline if available + if self._gpu_augmentation_pipeline is not None: + inputs.images = self._gpu_augmentation_pipeline(inputs.images) + return + + # Fall back to deprecated transforms property + if hasattr(self, 'batch_augmentation') and self.batch_augmentation is not None: + inputs.images = self.batch_augmentation(inputs.images) + return + + # Use legacy transforms property + pipeline = self.transforms + if pipeline and len(pipeline) > 0: + inputs.images = pipeline(inputs.images) @property def transforms(self) -> AugmentationSequential: - """Kornia Image Transforms.""" + """Kornia Image Transforms (DEPRECATED). + + Use augmentation_config instead for new code. + """ return AugmentationSequential() def register_load_state_dict_pre_hook(self, model_classes: list[str], ckpt_classes: list[str]) -> None: @@ -1042,6 +981,7 @@ def tile_config(self, tile_config: TileConfig) -> None: self._tile_config = tile_config + @abstractmethod def get_dummy_input(self, batch_size: int = 1) -> OTXDataBatch: """Generates a dummy input, suitable for launching forward() on it. @@ -1051,7 +991,6 @@ def get_dummy_input(self, batch_size: int = 1) -> OTXDataBatch: Returns: TorchDataBatch: A batch containing randomly generated inference data. """ - raise NotImplementedError @staticmethod def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo: @@ -1107,9 +1046,6 @@ def _check_preprocessing_params(self, preprocessing_params: DataInputParams | No raise ValueError(msg) @property + @abstractmethod def task(self) -> OTXTaskType: """Get task type.""" - if self._task is None: - msg = "Task type is not set. Please set the task type before using this model." 
- raise ValueError(msg) - return self._task diff --git a/library/src/otx/backend/native/models/classification/multiclass_models/base.py b/library/src/otx/backend/native/models/classification/multiclass_models/base.py index 634c9f16ea5..a8d757ef91b 100644 --- a/library/src/otx/backend/native/models/classification/multiclass_models/base.py +++ b/library/src/otx/backend/native/models/classification/multiclass_models/base.py @@ -23,6 +23,8 @@ from otx.types.export import TaskLevelExportParameters from otx.types.label import LabelInfoTypes from otx.types.task import OTXTaskType +from kornia.augmentation.container import AugmentationSequential +from torchvision.transforms.v2 import Compose if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -51,6 +53,9 @@ def __init__( label_info: LabelInfoTypes | int | Sequence, data_input_params: DataInputParams, model_name: str = "multiclass_classification_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, freeze_backbone: bool = False, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, @@ -60,8 +65,10 @@ def __init__( super().__init__( label_info=label_info, data_input_params=data_input_params, - task=OTXTaskType.MULTI_CLASS_CLS, model_name=model_name, + apply_gpu_transforms=apply_gpu_transforms, + batch_train_transforms=batch_train_transforms, + batch_val_transforms=batch_val_transforms, optimizer=optimizer, scheduler=scheduler, metric=metric, @@ -183,3 +190,7 @@ def forward_explain(self, inputs: OTXDataBatch) -> OTXPredBatch: saliency_map=[saliency_map.to(torch.float32) for saliency_map in outputs["saliency_map"]], feature_vector=[feature_vector.unsqueeze(0) for feature_vector in outputs["feature_vector"]], ) + + @property + def task(self) -> OTXTaskType: + return OTXTaskType.MULTI_CLASS_CLS From 5b20a1a8f338d8c8e3c72e8900f7880be6f666b6 Mon Sep 17 00:00:00 2001 From: kprokofi Date: Wed, 19 Nov 2025 01:50:47 +0900 Subject: [PATCH 5/6] apply kornia pipeline to classificaiton and OD --- library/src/otx/backend/native/models/base.py | 130 ++++++++++-------- .../classification/multiclass_models/base.py | 11 +- .../backend/native/models/detection/atss.py | 2 +- .../backend/native/models/detection/base.py | 48 +++---- .../backend/native/models/detection/yolox.py | 1 - .../src/otx/data/dataset/classification.py | 11 +- library/src/otx/data/dataset/detection.py | 2 + library/src/otx/data/module.py | 6 +- .../recipe/_base_/data/classification.yaml | 99 ++----------- .../src/otx/recipe/_base_/data/detection.yaml | 36 ++--- .../multi_class_cls/efficientnet_b0.yaml | 70 ---------- .../recipe/detection/atss_mobilenetv2.yaml | 5 - 12 files changed, 143 insertions(+), 278 deletions(-) diff --git a/library/src/otx/backend/native/models/base.py b/library/src/otx/backend/native/models/base.py index 425a80e7664..c452ae19ccb 100644 --- a/library/src/otx/backend/native/models/base.py +++ b/library/src/otx/backend/native/models/base.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence import torch +import kornia from datumaro import LabelCategories from lightning import LightningModule, Trainer from torch import Tensor, nn @@ -68,9 +69,9 @@ class DataInputParams: """Parameters of the input data such as input size, mean, and std.""" - input_size: tuple[int, int] - mean: 
tuple[float, float, float] - std: tuple[float, float, float] + input_size: tuple[int, int] | None = None + mean: tuple[float, float, float] | None = None + std: tuple[float, float, float] | None = None def as_dict(self) -> dict[str, Any]: """Convert to dictionary.""" @@ -78,7 +79,11 @@ def as_dict(self) -> dict[str, Any]: def as_ncwh(self, batch_size: int = 1) -> tuple[int, int, int, int]: """Convert input_size to NCWH format.""" - return (batch_size, 3, *self.input_size) + if self.input_size is not None: + return (batch_size, 3, *self.input_size) + + msg = "input_size should not be None." + raise ValueError(msg) def _default_optimizer_callable(params: params_t) -> Optimizer: @@ -175,16 +180,7 @@ def __init__( self._label_info = self._dispatch_label_info(label_info) self.model_name = model_name - if isinstance(data_input_params, dict): - data_input_params = DataInputParams(**data_input_params) - elif data_input_params is None: - data_input_params = ( - self._default_preprocessing_params[self.model_name] - if isinstance(self._default_preprocessing_params, dict) - else self._default_preprocessing_params - ) - self._check_preprocessing_params(data_input_params) - self.data_input_params = data_input_params + self.data_input_params = self._configure_preprocessing_params(data_input_params) self.model = self._create_model() self.optimizer_callable = ensure_callable(optimizer) self.scheduler_callable = ensure_callable(scheduler) @@ -200,6 +196,11 @@ def __init__( # Augmentation configuration if apply_gpu_transforms: self.batch_train_transforms, self.batch_val_transforms = self._configure_batch_augmentation(batch_train_transforms, batch_val_transforms) + self.batch_train_transforms.to(self.device) + self.batch_val_transforms.to(self.device) + else: + self.batch_train_transforms = None + self.batch_val_transforms = None self.save_hyperparameters( logger=False, @@ -455,19 +456,42 @@ def _configure_batch_augmentation( return train_aug_pipeline, val_aug_pipeline - @torch.no_grad() @staticmethod + @torch.no_grad() def _apply_batch_augmentations(augmentations_pipeline: AugmentationSequential | Compose | None, batch: OTXDataBatch) -> None: if augmentations_pipeline is not None: batch.images = augmentations_pipeline(batch.images) @property def _default_train_transforms(self): - return AugmentationSequential() + if self.task == OTXTaskType.DETECTION: + data_keys = ["input", "bbox"] + elif self.task == OTXTaskType.SEMANTIC_SEGMENTATION: + data_keys = ["input", "mask"] + elif self.task == OTXTaskType.INSTANCE_SEGMENTATION: + data_keys = ["input", "bbox", "mask"] + elif self.task == OTXTaskType.KEYPOINT_DETECTION: + data_keys = ["input", "keypoints"] + else: + data_keys = ["input"] + + return AugmentationSequential(kornia.augmentation.RandomHorizontalFlip(), + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), data_keys=data_keys) @property def _default_val_transforms(self): - return AugmentationSequential() + if self.task == OTXTaskType.DETECTION: + data_keys = ["input", "bbox"] + elif self.task == OTXTaskType.SEMANTIC_SEGMENTATION: + data_keys = ["input", "mask"] + elif self.task == OTXTaskType.INSTANCE_SEGMENTATION: + data_keys = ["input", "bbox", "mask"] + elif self.task == OTXTaskType.KEYPOINT_DETECTION: + data_keys = ["input", "keypoints"] + else: + data_keys = ["input"] + + return AugmentationSequential(kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), data_keys=data_keys) @property def metric(self) -> Metric | MetricCollection: 
@@ -730,28 +754,6 @@ def forward_tiles( """Model forward function for tile task.""" raise NotImplementedError - @torch.no_grad() - def apply_batch_transforms(self, inputs: OTXDataBatch) -> None: - """Apply GPU batch transforms to OTXDataBatch. - - Args: - inputs: Batch data to transform - """ - # Use configured GPU pipeline if available - if self._gpu_augmentation_pipeline is not None: - inputs.images = self._gpu_augmentation_pipeline(inputs.images) - return - - # Fall back to deprecated transforms property - if hasattr(self, 'batch_augmentation') and self.batch_augmentation is not None: - inputs.images = self.batch_augmentation(inputs.images) - return - - # Use legacy transforms property - pipeline = self.transforms - if pipeline and len(pipeline) > 0: - inputs.images = pipeline(inputs.images) - @property def transforms(self) -> AugmentationSequential: """Kornia Image Transforms (DEPRECATED). @@ -1034,36 +1036,50 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo: raise TypeError(label_info) - def _check_preprocessing_params(self, preprocessing_params: DataInputParams | None) -> None: + def _configure_preprocessing_params(self, preprocessing_params: DataInputParams | None) -> DataInputParams: """Check the validity of the preprocessing parameters.""" - if preprocessing_params is None: - msg = "Data input parameters should not be None." - raise ValueError(msg) + if isinstance(preprocessing_params, dict): + data_input_params = DataInputParams(**preprocessing_params) + elif isinstance(preprocessing_params, DataInputParams): + data_input_params = preprocessing_params + else: + # `preprocessing_params` is None + data_input_params = DataInputParams() - input_size = preprocessing_params.input_size - mean = preprocessing_params.mean - std = preprocessing_params.std + default_data_input_params = ( + self._default_preprocessing_params[self.model_name] + if isinstance(self._default_preprocessing_params, dict) + else self._default_preprocessing_params + ) - if not (len(mean) == 3 and all(isinstance(m, float) for m in mean)): - msg = f"Mean should be a tuple of 3 float values, but got {mean} instead." + # Assign default values if not given in `preprocessing_params` + data_input_params.input_size = data_input_params.input_size or default_data_input_params.input_size + data_input_params.mean = data_input_params.mean or default_data_input_params.mean + data_input_params.std = data_input_params.std or default_data_input_params.std + + # Validate + if not (len(data_input_params.mean) == 3 and all(isinstance(m, float) for m in data_input_params.mean)): + msg = f"Mean should be a tuple of 3 float values, but got {data_input_params.mean} instead." raise ValueError(msg) - if not (len(std) == 3 and all(isinstance(s, float) for s in std)): - msg = f"Std should be a tuple of 3 float values, but got {std} instead." + if not (len(data_input_params.std) == 3 and all(isinstance(s, float) for s in data_input_params.std)): + msg = f"Std should be a tuple of 3 float values, but got {data_input_params.std} instead." raise ValueError(msg) - if not all(0 <= m <= 255 for m in mean): - msg = f"Mean values should be in the range [0, 255], but got {mean} instead." + if not all(0 <= m <= 255 for m in data_input_params.mean): + msg = f"Mean values should be in the range [0, 255], but got {data_input_params.mean} instead." raise ValueError(msg) - if not all(0 <= s <= 255 for s in std): - msg = f"Std values should be in the range [0, 255], but got {std} instead." 
+ if not all(0 <= s <= 255 for s in data_input_params.std): + msg = f"Std values should be in the range [0, 255], but got {data_input_params.std} instead." raise ValueError(msg) - if input_size is not None and ( - input_size[0] % self.input_size_multiplier != 0 or input_size[1] % self.input_size_multiplier != 0 + if data_input_params.input_size is not None and ( + data_input_params.input_size[0] % self.input_size_multiplier != 0 or data_input_params.input_size[1] % self.input_size_multiplier != 0 ): - msg = f"Input size should be a multiple of {self.input_size_multiplier}, but got {input_size} instead." + msg = f"Input size should be a multiple of {self.input_size_multiplier}, but got {data_input_params.input_size} instead." raise ValueError(msg) + return data_input_params + @property @abstractmethod def task(self) -> OTXTaskType: diff --git a/library/src/otx/backend/native/models/classification/multiclass_models/base.py b/library/src/otx/backend/native/models/classification/multiclass_models/base.py index 0caaacc5b88..42899833e3d 100644 --- a/library/src/otx/backend/native/models/classification/multiclass_models/base.py +++ b/library/src/otx/backend/native/models/classification/multiclass_models/base.py @@ -9,6 +9,7 @@ import torch from torch import Tensor +import kornia from otx.backend.native.exporter.base import OTXModelExporter from otx.backend.native.exporter.native import OTXNativeModelExporter @@ -24,6 +25,8 @@ from otx.types.label import LabelInfoTypes from otx.types.task import OTXTaskType from kornia.augmentation.container import AugmentationSequential +from kornia.augmentation import Normalize +from kornia.augmentation.auto import AutoAugment from torchvision.transforms.v2 import Compose if TYPE_CHECKING: @@ -128,6 +131,12 @@ def _customize_outputs( scores=list(scores), ) + @property + def _default_train_transforms(self): + return AugmentationSequential(kornia.augmentation.RandomHorizontalFlip(), + kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1), + Normalize(self.data_input_params.mean, self.data_input_params.std)) + @property def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" @@ -198,4 +207,4 @@ def task(self) -> OTXTaskType: @property def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]: - return DataInputParams(input_size=(224, 224), mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375)) + return DataInputParams(input_size=(224, 224), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) diff --git a/library/src/otx/backend/native/models/detection/atss.py b/library/src/otx/backend/native/models/detection/atss.py index c8de1396a22..9a5fc7df222 100644 --- a/library/src/otx/backend/native/models/detection/atss.py +++ b/library/src/otx/backend/native/models/detection/atss.py @@ -204,4 +204,4 @@ def _exporter(self) -> OTXModelExporter: @property def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]: - return DataInputParams(input_size=(800, 992), mean=(0.0, 0.0, 0.0), std=(255.0, 255.0, 255.0)) + return DataInputParams(input_size=(800, 992), mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)) diff --git a/library/src/otx/backend/native/models/detection/base.py b/library/src/otx/backend/native/models/detection/base.py index 233d2306d70..f4c565fc89f 100644 --- a/library/src/otx/backend/native/models/detection/base.py +++ b/library/src/otx/backend/native/models/detection/base.py @@ -17,6 +17,7 @@ from torchvision import 
tv_tensors import kornia from kornia.geometry.boxes import Boxes +from kornia.augmentation.container import AugmentationSequential from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel from otx.backend.native.models.utils.utils import InstanceData @@ -73,6 +74,9 @@ def __init__( label_info: LabelInfoTypes | int | Sequence, data_input_params: DataInputParams | dict | None = None, model_name: str = "otx_detection_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MeanAveragePrecisionFMeasureCallable, @@ -82,9 +86,11 @@ def __init__( ) -> None: super().__init__( label_info=label_info, - model_name=model_name, - task=OTXTaskType.DETECTION, data_input_params=data_input_params, + model_name=model_name, + apply_gpu_transforms=apply_gpu_transforms, + batch_train_transforms=batch_train_transforms, + batch_val_transforms=batch_val_transforms, optimizer=optimizer, scheduler=scheduler, metric=metric, @@ -265,29 +271,6 @@ def _customize_outputs( labels=labels, ) - @torch.no_grad() - def apply_batch_transforms(self, inputs: OTXDataBatch) -> types.NoneType: - """Apply batch augmentations to Object Detection.""" - # Convert bounding boxes to Kornia Boxes [N, 4, 2] - kornia_boxes = Boxes.from_tensor(inputs.bboxes, mode='xyxy') - self.transforms(inputs.images, kornia_boxes) - inputs.bboxes = kornia_boxes.to_tensor(mode='xyxy') - - @property - def transforms(self): - if self.training: - return kornia.augmentation.AugmentationSequential( - kornia.augmentation.RandomHorizontalFlip(), - kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), - data_keys=["input", "bbox"], - same_on_batch=False - ) - return kornia.augmentation.AugmentationSequential( - kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), - data_keys=["input", "bbox"], - ) - - def forward_tiles(self, inputs: OTXTileBatchDataEntity) -> OTXPredBatch: """Unpack detection tiles. 
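With the constructor widened this way, a caller can swap in a custom GPU pipeline per model. A hypothetical usage sketch (it assumes the `ATSS` class from atss.py forwards these arguments unchanged to this base class):

```python
import kornia.augmentation as K
from kornia.augmentation.container import AugmentationSequential
from otx.backend.native.models.detection.atss import ATSS

# Custom batch-level training pipeline; validation falls back to the default
# (normalize-only) pipeline because batch_val_transforms is left as None.
train_tf = AugmentationSequential(
    K.RandomHorizontalFlip(p=0.5),
    K.ColorJiggle(0.1, 0.1, 0.1, 0.1),
    K.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)),
    data_keys=["input", "bbox"],
)

model = ATSS(
    label_info=80,
    apply_gpu_transforms=True,
    batch_train_transforms=train_tf,
)
```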
@@ -570,6 +553,19 @@ def get_num_anchors(self) -> list[int]: return [1] * 10 + @staticmethod + @torch.no_grad() + def _apply_batch_augmentations(augmentations_pipeline: AugmentationSequential | Compose | None, batch: OTXDataBatch) -> None: + if augmentations_pipeline is not None: + # Convert bounding boxes to Kornia Boxes [N, 4, 2] + kornia_boxes = Boxes.from_tensor(batch.bboxes, mode='xyxy') + batch.images, kornia_boxes = augmentations_pipeline(batch.images, kornia_boxes) + batch.bboxes = kornia_boxes.to_tensor(mode='xyxy') + @property def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]: - return DataInputParams(input_size=(640, 640), mean=(0.0, 0.0, 0.0), std=(255.0, 255.0, 255.0)) + return DataInputParams(input_size=(640, 640), mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)) + + @property + def task(self) -> OTXTaskType: + return OTXTaskType.DETECTION diff --git a/library/src/otx/backend/native/models/detection/yolox.py b/library/src/otx/backend/native/models/detection/yolox.py index 33ebf2ababc..5448f079954 100644 --- a/library/src/otx/backend/native/models/detection/yolox.py +++ b/library/src/otx/backend/native/models/detection/yolox.py @@ -24,7 +24,6 @@ from otx.metrics.fmeasure import MeanAveragePrecisionFMeasureCallable from otx.types.export import OTXExportFormatType from otx.types.precision import OTXPrecisionType -import kornia if TYPE_CHECKING: from pathlib import Path diff --git a/library/src/otx/data/dataset/classification.py b/library/src/otx/data/dataset/classification.py index 9a9fb8cc439..fea85890d86 100644 --- a/library/src/otx/data/dataset/classification.py +++ b/library/src/otx/data/dataset/classification.py @@ -90,7 +90,9 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: img = item.media_as(Image) roi = item.attributes.get("roi", None) img_data, img_shape, _ = self._get_img_data_and_shape(img, roi) - image = to_dtype(to_image(img_data), dtype=torch.float32) + image = to_dtype(to_image(img_data), scale=True, dtype=torch.float32) + image.clamp_(0, 1) + if roi: # extract labels from ROI labels_ids = [ @@ -118,7 +120,12 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: image_color_channel=self.image_color_channel, ), ) - return self._apply_transforms(entity) + entity = self._apply_transforms(entity) + + image = entity.image + image.clamp_(0, 1) + + return entity @property def task_type(self) -> OTXTaskType: diff --git a/library/src/otx/data/dataset/detection.py b/library/src/otx/data/dataset/detection.py index 0a546b9438a..fd07f17e835 100644 --- a/library/src/otx/data/dataset/detection.py +++ b/library/src/otx/data/dataset/detection.py @@ -16,6 +16,7 @@ from otx.data.entity.torch import OTXDataItem from otx.types import OTXTaskType from otx.types.image import ImageColorChannel +from torchvision.transforms.v2.functional import to_dtype, to_image from .base import OTXDataset, Transforms from .mixins import DataAugSwitchMixin @@ -84,6 +85,7 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: img = item.media_as(Image) ignored_labels: list[int] = [] # This should be assigned form item img_data, img_shape, _ = self._get_img_data_and_shape(img) + img_data = to_dtype(to_image(img_data), scale=True, dtype=torch.float32) bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)] diff --git a/library/src/otx/data/module.py b/library/src/otx/data/module.py index 47a5184c8d4..71d86531433 100644 --- a/library/src/otx/data/module.py +++ b/library/src/otx/data/module.py @@ -193,7 +193,7 @@ def 
_setup_otx_dataset(self, dataset: DmDataset) -> None: def extract_normalization_params( self, transforms_source: list | None - ) -> tuple[tuple[float, float, float], tuple[float, float, float]]: + ) -> tuple[tuple[float, float, float] | None, tuple[float, float, float] | None]: """Extract mean and std from transforms. Args: @@ -202,8 +202,8 @@ def extract_normalization_params( Returns: Tuple of (mean, std) tuples. """ - mean = (0.0, 0.0, 0.0) - std = (1.0, 1.0, 1.0) + mean = None + std = None if transforms_source is not None: for transform in transforms_source: diff --git a/library/src/otx/recipe/_base_/data/classification.yaml b/library/src/otx/recipe/_base_/data/classification.yaml index 5017fa97868..c4104ee4b97 100644 --- a/library/src/otx/recipe/_base_/data/classification.yaml +++ b/library/src/otx/recipe/_base_/data/classification.yaml @@ -9,73 +9,18 @@ train_subset: subset_name: train transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 2 - to_tv_image: false + num_workers: 4 + to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.RandomResizedCrop + - class_path: torchvision.transforms.v2.RandomResizedCrop init_args: - scale: $(input_size) - crop_ratio_range: + size: $(input_size) + scale: - 0.08 - 1.0 - aspect_ratio_range: + ratio: - 0.75 - 1.34 - # - class_path: torchvision.transforms.v2.RandomPhotometricDistort - # enable: false - # init_args: - # brightness: - # - 0.875 - # - 1.125 - # contrast: - # - 0.5 - # - 1.5 - # saturation: - # - 0.5 - # - 1.5 - # hue: - # - -0.05 - # - 0.05 - # p: 0.5 - # - class_path: otx.data.transform_libs.torchvision.RandomAffine - # enable: false - # init_args: - # max_rotate_degree: 10.0 - # max_translate_ratio: 0.1 - # scaling_ratio_range: - # - 0.5 - # - 1.5 - # max_shear_degree: 2.0 - # - class_path: otx.data.transform_libs.torchvision.RandomFlip - # enable: true - # init_args: - # probability: 0.5 - # - class_path: torchvision.transforms.v2.RandomVerticalFlip - # enable: false - # init_args: - # p: 0.5 - # - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur - # enable: false - # init_args: - # kernel_size: 5 - # sigma: - # - 0.1 - # - 2.0 - # probability: 0.5 - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - scale: false - # - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise - # enable: false - # init_args: - # mean: 0.0 - # sigma: 0.1 - # probability: 0.5 - # - class_path: torchvision.transforms.v2.Normalize - # init_args: - # mean: [123.675, 116.28, 103.53] - # std: [58.395, 57.12, 57.375] sampler: class_path: otx.data.samplers.balanced_sampler.BalancedSampler init_args: null @@ -84,20 +29,12 @@ val_subset: subset_name: val transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 2 - to_tv_image: false + num_workers: 4 + to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.Resize - init_args: - scale: $(input_size) - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - scale: false - - class_path: torchvision.transforms.v2.Normalize + - class_path: torchvision.transforms.v2.Resize init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + size: $(input_size) sampler: class_path: torch.utils.data.RandomSampler init_args: null @@ -106,20 +43,12 @@ test_subset: subset_name: test transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 2 - to_tv_image: false + num_workers: 4 + to_tv_image: true transforms: - 
class_path: otx.data.transform_libs.torchvision.Resize - init_args: - scale: $(input_size) - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - scale: false - - class_path: torchvision.transforms.v2.Normalize + - class_path: torchvision.transforms.v2.Resize init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + size: $(input_size) sampler: class_path: torch.utils.data.RandomSampler init_args: null diff --git a/library/src/otx/recipe/_base_/data/detection.yaml b/library/src/otx/recipe/_base_/data/detection.yaml index da59453a832..3eac60cc796 100644 --- a/library/src/otx/recipe/_base_/data/detection.yaml +++ b/library/src/otx/recipe/_base_/data/detection.yaml @@ -8,9 +8,9 @@ unannotated_items_ratio: 0.0 train_subset: subset_name: train transform_lib_type: TORCHVISION - batch_size: 1 + batch_size: 8 num_workers: 4 - to_tv_image: true + to_tv_image: false transforms: - class_path: torchvision.transforms.v2.RandomIoUCrop - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes @@ -19,11 +19,6 @@ train_subset: - class_path: torchvision.transforms.v2.Resize init_args: size: $(input_size) - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - scale: true - sampler: class_path: torch.utils.data.RandomSampler init_args: null @@ -31,20 +26,13 @@ train_subset: val_subset: subset_name: val transform_lib_type: TORCHVISION - batch_size: 1 + batch_size: 8 num_workers: 4 - to_tv_image: false + to_tv_image: true transforms: - class_path: torchvision.transforms.v2.Resize init_args: size: $(input_size) - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [0.0, 0.0, 0.0] - std: [255.0, 255.0, 255.0] sampler: class_path: torch.utils.data.RandomSampler init_args: null @@ -52,20 +40,14 @@ val_subset: test_subset: subset_name: test transform_lib_type: TORCHVISION - batch_size: 1 + batch_size: 8 num_workers: 4 - to_tv_image: false + to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.Resize - init_args: - scale: $(input_size) - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: torchvision.transforms.v2.Normalize + - class_path: torchvision.transforms.v2.Resize init_args: - mean: [0.0, 0.0, 0.0] - std: [255.0, 255.0, 255.0] + size: $(input_size) + sampler: class_path: torch.utils.data.RandomSampler init_args: null diff --git a/library/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml b/library/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml index d70a22cfce1..4e1706c2b23 100644 --- a/library/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml +++ b/library/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml @@ -46,74 +46,4 @@ callbacks: filename: "checkpoints/epoch_{epoch:03d}" overrides: - reset: - - data.train_subset.transforms max_epochs: 90 - - data: - train_subset: - to_tv_image: true - transforms: - - class_path: torchvision.transforms.v2.RandomResizedCrop - init_args: - size: $(input_size) - # crop_ratio_range: - # - 0.08 - # - 1.0 - # aspect_ratio_range: - # - 0.75 - # - 1.34 - # - class_path: torchvision.transforms.v2.RandomPhotometricDistort - # enable: false - # init_args: - # brightness: - # - 0.875 - # - 1.125 - # contrast: - # - 0.5 - # - 1.5 - # saturation: - # - 0.5 - # - 1.5 - # hue: - # 
- -0.05 - # - 0.05 - # p: 0.5 - # - class_path: otx.data.transform_libs.torchvision.RandomAffine - # enable: false - # init_args: - # max_rotate_degree: 10.0 - # max_translate_ratio: 0.1 - # scaling_ratio_range: - # - 0.5 - # - 1.5 - # max_shear_degree: 2.0 - # - class_path: otx.data.transform_libs.torchvision.RandomFlip - # enable: true - # init_args: - # probability: 0.5 - # - class_path: torchvision.transforms.v2.RandomVerticalFlip - # enable: false - # init_args: - # p: 0.5 - # - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur - # enable: false - # init_args: - # kernel_size: 5 - # sigma: - # - 0.1 - # - 2.0 - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - scale: true - # - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise - # enable: false - # init_args: - # mean: 0.0 - # sigma: 0.1 - # probability: 0.5 - # - class_path: torchvision.transforms.v2.Normalize - # init_args: - # mean: [123.675, 116.28, 103.53] - # std: [58.395, 57.12, 57.375] diff --git a/library/src/otx/recipe/detection/atss_mobilenetv2.yaml b/library/src/otx/recipe/detection/atss_mobilenetv2.yaml index 272b8035783..1ebc4c87d32 100644 --- a/library/src/otx/recipe/detection/atss_mobilenetv2.yaml +++ b/library/src/otx/recipe/detection/atss_mobilenetv2.yaml @@ -4,11 +4,6 @@ model: init_args: model_name: atss_mobilenetv2 label_info: 80 - data_input_params: - input_size: [800, 992] - mean: [0.0, 0.0, 0.0] - std: [255.0, 255.0, 255.0] - optimizer: class_path: torch.optim.SGD init_args: From b89f7dc0c1c76b1028e21b54162590dd2e2328f0 Mon Sep 17 00:00:00 2001 From: kprokofi Date: Thu, 20 Nov 2025 03:25:21 +0900 Subject: [PATCH 6/6] added pipeline to all base models --- library/src/otx/backend/native/models/base.py | 4 +- .../classification/hlabel_models/base.py | 9 ++ .../classification/multiclass_models/base.py | 6 + .../classification/multilabel_models/base.py | 13 +- .../models/instance_segmentation/base.py | 33 +++++- .../native/models/keypoint_detection/base.py | 34 +++++- .../models/keypoint_detection/rtmpose.py | 3 - .../native/models/segmentation/base.py | 28 ++++- library/src/otx/data/dataset/base.py | 3 + .../src/otx/data/dataset/classification.py | 2 - library/src/otx/data/dataset/detection.py | 2 - .../otx/data/dataset/instance_segmentation.py | 2 +- .../otx/data/dataset/keypoint_detection.py | 2 +- library/src/otx/data/dataset/segmentation.py | 2 +- .../otx/data/transform_libs/torchvision.py | 3 +- .../_base_/data/instance_segmentation.yaml | 111 ++---------------- .../_base_/data/keypoint_detection.yaml | 62 ++-------- .../_base_/data/semantic_segmentation.yaml | 65 ---------- 18 files changed, 149 insertions(+), 235 deletions(-) diff --git a/library/src/otx/backend/native/models/base.py b/library/src/otx/backend/native/models/base.py index c452ae19ccb..9f5e9a8f5ac 100644 --- a/library/src/otx/backend/native/models/base.py +++ b/library/src/otx/backend/native/models/base.py @@ -476,7 +476,7 @@ def _default_train_transforms(self): data_keys = ["input"] return AugmentationSequential(kornia.augmentation.RandomHorizontalFlip(), - kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), data_keys=data_keys) + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), data_keys=data_keys, keepdim=True) @property def _default_val_transforms(self): @@ -491,7 +491,7 @@ def _default_val_transforms(self): else: data_keys = ["input"] - return 
AugmentationSequential(kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), data_keys=data_keys) + return AugmentationSequential(kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), data_keys=data_keys, keepdim=True) @property def metric(self) -> Metric | MetricCollection: diff --git a/library/src/otx/backend/native/models/classification/hlabel_models/base.py b/library/src/otx/backend/native/models/classification/hlabel_models/base.py index 1c8a0e8829d..7986bccc8f0 100644 --- a/library/src/otx/backend/native/models/classification/hlabel_models/base.py +++ b/library/src/otx/backend/native/models/classification/hlabel_models/base.py @@ -41,6 +41,12 @@ class OTXHlabelClsModel(OTXModel): data_input_params (DataInputParams | None, optional): Parameters for image data preprocessing. If None is given, default parameters for the specific model will be used. model_name (str, optional): Name of the model. Defaults to "hlabel_classification_model". + apply_gpu_transforms (bool, optional): Flag to indicate whether to apply GPU transforms. + It is recommended to use GPU transforms. Defaults to True. + batch_train_transforms (AugmentationSequential | Compose | None): GPU transforms for training applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. + batch_val_transforms (AugmentationSequential | Compose | None): GPU transforms for validation / testing applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. Typically just normalization. optimizer (OptimizerCallable, optional): Callable for the optimizer. Defaults to DefaultOptimizerCallable. scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable. @@ -55,6 +61,9 @@ def __init__( label_info: HLabelInfo, data_input_params: DataInputParams | None = None, model_name: str = "hlabel_classification_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, freeze_backbone: bool = False, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, diff --git a/library/src/otx/backend/native/models/classification/multiclass_models/base.py b/library/src/otx/backend/native/models/classification/multiclass_models/base.py index 42899833e3d..db0ea4f8025 100644 --- a/library/src/otx/backend/native/models/classification/multiclass_models/base.py +++ b/library/src/otx/backend/native/models/classification/multiclass_models/base.py @@ -45,6 +45,12 @@ class OTXMulticlassClsModel(OTXModel): data_input_params (DataInputParams | None, optional): Parameters for the image data preprocessing. If None is given, default parameters for the specific model will be used. model_name (str, optional): Name of the model. Defaults to "multiclass_classification_model". + apply_gpu_transforms (bool, optional): Flag to indicate whether to apply GPU transforms. + It is recommended to use GPU transforms. Defaults to True. + batch_train_transforms (AugmentationSequential | Compose | None): GPU transforms for training applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. 
+ batch_val_transforms (AugmentationSequential | Compose | None): GPU transforms for validation / testing applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. Typically just normalization. optimizer (OptimizerCallable, optional): Callable for the optimizer. Defaults to DefaultOptimizerCallable. scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable. diff --git a/library/src/otx/backend/native/models/classification/multilabel_models/base.py b/library/src/otx/backend/native/models/classification/multilabel_models/base.py index 3406588fd8b..92f60f3b3a9 100644 --- a/library/src/otx/backend/native/models/classification/multilabel_models/base.py +++ b/library/src/otx/backend/native/models/classification/multilabel_models/base.py @@ -38,6 +38,12 @@ class OTXMultilabelClsModel(OTXModel): if `Sequence` is given, label info will be constructed from the sequence of label names. data_input_params (DataInputParams | None, optional): Parameters for the image data preprocessing. model_name (str, optional): Name of the model. Defaults to "multilabel_classification_model". + apply_gpu_transforms (bool, optional): Flag to indicate whether to apply GPU transforms. + It is recommended to use GPU transforms. Defaults to True. + batch_train_transforms (AugmentationSequential | Compose | None): GPU transforms for training applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. + batch_val_transforms (AugmentationSequential | Compose | None): GPU transforms for validation / testing applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. Typically just normalization. optimizer (OptimizerCallable, optional): Callable for the optimizer. Defaults to DefaultOptimizerCallable. scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable. 
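The fallback these docstrings describe is the same for every task: an explicit pipeline wins, and `None` selects the model's task-specific default. A plausible distillation (the actual `_configure_batch_augmentation` body is not shown in these hunks):

```python
import kornia.augmentation as K
from kornia.augmentation.container import AugmentationSequential

def resolve_pipeline(user_pipeline, default_pipeline):
    # An explicit pipeline wins; None selects the model's task-specific default.
    return user_pipeline if user_pipeline is not None else default_pipeline

default_val = AugmentationSequential(
    K.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
)
assert resolve_pipeline(None, default_val) is default_val
```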
@@ -50,6 +56,9 @@ def __init__( label_info: LabelInfoTypes | Sequence, data_input_params: DataInputParams | None = None, model_name: str = "multiclass_classification_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, freeze_backbone: bool = False, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, @@ -59,8 +68,10 @@ def __init__( super().__init__( label_info=label_info, data_input_params=data_input_params, - task=OTXTaskType.MULTI_LABEL_CLS, model_name=model_name, + apply_gpu_transforms=apply_gpu_transforms, + batch_train_transforms=batch_train_transforms, + batch_val_transforms=batch_val_transforms, optimizer=optimizer, scheduler=scheduler, metric=metric, diff --git a/library/src/otx/backend/native/models/instance_segmentation/base.py b/library/src/otx/backend/native/models/instance_segmentation/base.py index 805e8538351..a3ddc5d010e 100644 --- a/library/src/otx/backend/native/models/instance_segmentation/base.py +++ b/library/src/otx/backend/native/models/instance_segmentation/base.py @@ -18,6 +18,8 @@ from torchmetrics import Metric, MetricCollection from torchvision import tv_tensors from torchvision.models.detection.image_list import ImageList +from kornia.geometry.boxes import Boxes +from kornia.augmentation.container import AugmentationSequential from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel from otx.backend.native.models.instance_segmentation.segmentors.maskrcnn_tv import MaskRCNN @@ -61,6 +63,12 @@ class OTXInstanceSegModel(OTXModel): data_input_params (DataInputParams | None, optional): Parameters for the image data preprocessing. If None is given, default parameters for the specific model will be used. model_name (str, optional): Name of the model. Defaults to "inst_segm_model". + apply_gpu_transforms (bool, optional): Flag to indicate whether to apply GPU transforms. + It is recommended to use GPU transforms. Defaults to True. + batch_train_transforms (AugmentationSequential | Compose | None): GPU transforms for training applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. + batch_val_transforms (AugmentationSequential | Compose | None): GPU transforms for validation / testing applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. Typically just normalization. optimizer (OptimizerCallable, optional): Optimizer for the model. Defaults to DefaultOptimizerCallable. scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Scheduler for the model. Defaults to DefaultSchedulerCallable. 
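Instance segmentation threads masks through the same joint pipeline as images and boxes, so geometric transforms stay consistent across all three. A minimal sketch with illustrative shapes of what the override added below does per batch:

```python
import torch
import kornia.augmentation as K
from kornia.augmentation.container import AugmentationSequential
from kornia.geometry.boxes import Boxes

# One sampled flip is applied identically to the image, its boxes, and its masks.
pipeline = AugmentationSequential(
    K.RandomHorizontalFlip(p=1.0),
    data_keys=["input", "bbox", "mask"],
)
images = torch.rand(2, 3, 64, 64)
boxes = Boxes.from_tensor(torch.tensor([[[4.0, 4.0, 20.0, 20.0]]]).repeat(2, 1, 1), mode="xyxy")
masks = torch.zeros(2, 1, 64, 64)
out_images, out_boxes, out_masks = pipeline(images, boxes, masks)
print(out_boxes.to_tensor(mode="xyxy").shape, out_masks.shape)
```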
@@ -76,6 +84,9 @@ def __init__( label_info: LabelInfoTypes | int | Sequence, data_input_params: DataInputParams | dict | None = None, model_name: str = "inst_segm_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MaskRLEMeanAPFMeasureCallable, @@ -85,8 +96,10 @@ def __init__( super().__init__( label_info=label_info, data_input_params=data_input_params, - task=OTXTaskType.INSTANCE_SEGMENTATION, model_name=model_name, + apply_gpu_transforms=apply_gpu_transforms, + batch_train_transforms=batch_train_transforms, + batch_val_transforms=batch_val_transforms, optimizer=optimizer, scheduler=scheduler, metric=metric, @@ -616,6 +629,20 @@ def _restore_model_forward(self) -> None: self.model.forward = func_type(self.original_model_forward, self.model) self.original_model_forward = None + @staticmethod + @torch.no_grad() + def _apply_batch_augmentations(augmentations_pipeline: AugmentationSequential | Compose | None, batch: OTXDataBatch) -> None: + if augmentations_pipeline is not None: + # Convert bounding boxes to Kornia Boxes [N, 4, 2] + kornia_boxes = Boxes.from_tensor(batch.bboxes, mode='xyxy') + batch.images, kornia_boxes, masks = augmentations_pipeline(batch.images, kornia_boxes, batch.masks) + batch.bboxes = kornia_boxes.to_tensor(mode='xyxy') + batch.masks = masks + @property def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]: - return DataInputParams(input_size=(1024, 1024), mean=(103.53, 116.28, 123.675), std=(57.375, 57.12, 58.395)) + return DataInputParams(input_size=(1024, 1024), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) + + @property + def task(self) -> OTXTaskType: + return OTXTaskType.INSTANCE_SEGMENTATION diff --git a/library/src/otx/backend/native/models/keypoint_detection/base.py b/library/src/otx/backend/native/models/keypoint_detection/base.py index 1f9cffccb16..983be62e7b7 100644 --- a/library/src/otx/backend/native/models/keypoint_detection/base.py +++ b/library/src/otx/backend/native/models/keypoint_detection/base.py @@ -10,7 +10,9 @@ from typing import TYPE_CHECKING, Any, Sequence import torch +import kornia +from kornia.augmentation.container import AugmentationSequential from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel from otx.backend.native.schedulers import LRSchedulerListCallable from otx.data.entity.base import ImageInfo, OTXBatchLossEntity @@ -48,6 +50,9 @@ def __init__( label_info: LabelInfoTypes | int | Sequence, data_input_params: DataInputParams | None = None, model_name: str = "keypoint_detection_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = PCKMeasureCallable, @@ -56,8 +61,10 @@ def __init__( super().__init__( label_info=label_info, data_input_params=data_input_params, - task=OTXTaskType.KEYPOINT_DETECTION, model_name=model_name, + apply_gpu_transforms=apply_gpu_transforms, + 
batch_train_transforms=batch_train_transforms, + batch_val_transforms=batch_val_transforms, optimizer=optimizer, scheduler=scheduler, metric=metric, @@ -209,6 +216,29 @@ def _export_parameters(self) -> TaskLevelExportParameters: confidence_threshold=self.hparams.get("best_confidence_threshold", None), ) + @staticmethod + @torch.no_grad() + def _apply_batch_augmentations(augmentations_pipeline: AugmentationSequential | Compose | None, batch: OTXDataBatch) -> None: + if augmentations_pipeline is not None: + stacked_kps = torch.stack(batch.keypoints) + # Apply augmentations + batch.images, augmented_kps = augmentations_pipeline(batch.images, stacked_kps[:, :, :2]) + stacked_kps[:, :, :2] = augmented_kps + h, w = batch.images.shape[-2:] + # Compute visible mask. Keypoints should be visible if they are inside the image (>=0, x<=w, y<=h) + visible_mask = (augmented_kps >= 0).all(axis=2) * (augmented_kps[:, :, 0] <= w) * (augmented_kps[:, :, 1] <= h) + stacked_kps[:, :, 2] = stacked_kps[:, :, 2] * visible_mask + # Update visible keypoints with augmented values + batch.keypoints = [kps for kps in stacked_kps] + + @property + def _default_train_transforms(self): + return AugmentationSequential(kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), data_keys=["input", "keypoints"], keepdim=True) + @property def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]: - return DataInputParams(input_size=(640, 640), mean=(0.0, 0.0, 0.0), std=(255.0, 255.0, 255.0)) + return DataInputParams(input_size=(512, 512), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) + + @property + def task(self) -> OTXTaskType: + return OTXTaskType.KEYPOINT_DETECTION diff --git a/library/src/otx/backend/native/models/keypoint_detection/rtmpose.py b/library/src/otx/backend/native/models/keypoint_detection/rtmpose.py index 7a96218b91c..d3f359a3a24 100644 --- a/library/src/otx/backend/native/models/keypoint_detection/rtmpose.py +++ b/library/src/otx/backend/native/models/keypoint_detection/rtmpose.py @@ -35,9 +35,6 @@ class RTMPose(OTXKeypointDetectionModel): "rtmpose_tiny": "https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/cspnext-tiny_udp-aic-coco_210e-256x192-cbed682d_20230130.pth", } - _default_preprocessing_params: ClassVar[dict[str, DataInputParams] | DataInputParams] = { - "rtmpose_tiny": DataInputParams(input_size=(640, 640), mean=(0.0, 0.0, 0.0), std=(255.0, 255.0, 255.0)), - } def __init__( self, diff --git a/library/src/otx/backend/native/models/segmentation/base.py b/library/src/otx/backend/native/models/segmentation/base.py index a5e5577b22c..9689bafbd6b 100644 --- a/library/src/otx/backend/native/models/segmentation/base.py +++ b/library/src/otx/backend/native/models/segmentation/base.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any import torch +import kornia import torch.nn.functional as f from torchvision import tv_tensors @@ -20,6 +21,7 @@ from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel from otx.backend.native.schedulers import LRSchedulerListCallable from otx.backend.native.tools.tile_merge import SegmentationTileMerge +from kornia.augmentation.container import AugmentationSequential from otx.config.data import TileConfig from otx.data.entity.base import ImageInfo, OTXBatchLossEntity from otx.data.entity.tile import OTXTileBatchDataEntity @@ -59,6 +61,9 @@ def __init__( label_info: LabelInfoTypes | int | Sequence, data_input_params: DataInputParams | 
None = None, model_name: str = "otx_segmentation_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = SegmCallable, # type: ignore[assignment] @@ -68,8 +73,10 @@ def __init__( super().__init__( label_info=label_info, data_input_params=data_input_params, - task=OTXTaskType.SEMANTIC_SEGMENTATION, model_name=model_name, + apply_gpu_transforms=apply_gpu_transforms, + batch_train_transforms=batch_train_transforms, + batch_val_transforms=batch_val_transforms, optimizer=optimizer, scheduler=scheduler, metric=metric, @@ -297,6 +304,23 @@ def get_dummy_input(self, batch_size: int = 1) -> OTXDataBatch: # type: ignore[ ) return OTXDataBatch(batch_size, images, imgs_info=infos, masks=[]) # type: ignore[arg-type] + @staticmethod + @torch.no_grad() + def _apply_batch_augmentations(augmentations_pipeline: AugmentationSequential | Compose | None, batch: OTXDataBatch) -> None: + if augmentations_pipeline is not None: + batch.images, batch.masks = augmentations_pipeline(batch.images, batch.masks) + + @property + def _default_train_transforms(self): + return AugmentationSequential( + kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1), + kornia.augmentation.RandomHorizontalFlip(), + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), data_keys=["input", "mask"], keepdim=True) + @property def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]: - return DataInputParams(input_size=(512, 512), mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375)) + return DataInputParams(input_size=(512, 512), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) + + @property + def task(self) -> OTXTaskType: + return OTXTaskType.SEMANTIC_SEGMENTATION diff --git a/library/src/otx/data/dataset/base.py b/library/src/otx/data/dataset/base.py index d8c3ad736ed..3893466b51c 100644 --- a/library/src/otx/data/dataset/base.py +++ b/library/src/otx/data/dataset/base.py @@ -9,9 +9,11 @@ from collections.abc import Iterable from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Union +from torchvision.transforms.v2.functional import to_dtype, to_image import cv2 import numpy as np +import torch from datumaro.components.annotation import AnnotationType from datumaro.util.image import IMAGE_BACKEND, IMAGE_COLOR_CHANNEL, ImageBackend from datumaro.util.image import ImageColorChannel as DatumaroImageColorChannel @@ -194,6 +196,7 @@ def _get_img_data_and_shape( img_data = img_data[y1:y2, x1:x2] roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "orig_image_shape": (h, w)} + img_data = to_dtype(to_image(img_data), scale=True, dtype=torch.float32).clamp_(0, 1) return img_data, img_data.shape[:2], roi_meta @abstractmethod diff --git a/library/src/otx/data/dataset/classification.py b/library/src/otx/data/dataset/classification.py index fea85890d86..5d911b545d6 100644 --- a/library/src/otx/data/dataset/classification.py +++ b/library/src/otx/data/dataset/classification.py @@ -90,8 +90,6 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: img = item.media_as(Image) roi = item.attributes.get("roi", None) img_data, img_shape, _ = self._get_img_data_and_shape(img, roi) - image = to_dtype(to_image(img_data), scale=True, 
dtype=torch.float32) - image.clamp_(0, 1) if roi: # extract labels from ROI diff --git a/library/src/otx/data/dataset/detection.py b/library/src/otx/data/dataset/detection.py index fd07f17e835..0a546b9438a 100644 --- a/library/src/otx/data/dataset/detection.py +++ b/library/src/otx/data/dataset/detection.py @@ -16,7 +16,6 @@ from otx.data.entity.torch import OTXDataItem from otx.types import OTXTaskType from otx.types.image import ImageColorChannel -from torchvision.transforms.v2.functional import to_dtype, to_image from .base import OTXDataset, Transforms from .mixins import DataAugSwitchMixin @@ -85,7 +84,6 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: img = item.media_as(Image) ignored_labels: list[int] = [] # This should be assigned form item img_data, img_shape, _ = self._get_img_data_and_shape(img) - img_data = to_dtype(to_image(img_data), scale=True, dtype=torch.float32) bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)] diff --git a/library/src/otx/data/dataset/instance_segmentation.py b/library/src/otx/data/dataset/instance_segmentation.py index afb25d4b4bc..4bc33c2df73 100644 --- a/library/src/otx/data/dataset/instance_segmentation.py +++ b/library/src/otx/data/dataset/instance_segmentation.py @@ -119,8 +119,8 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: warnings.warn(f"No valid annotations found for image {item.id}!", stacklevel=2) bboxes = np.stack(gt_bboxes, dtype=np.float32, axis=0) if gt_bboxes else np.empty((0, 4)) + # TODO(@kprokofi): masks can be missing here; revisit mask extraction before relying on them downstream masks = np.stack(gt_masks, axis=0) if gt_masks else np.zeros((0, *img_shape), dtype=bool) - labels = np.array(gt_labels, dtype=np.int64) entity = OTXDataItem( diff --git a/library/src/otx/data/dataset/keypoint_detection.py b/library/src/otx/data/dataset/keypoint_detection.py index 0589d53dbfd..9713fc14e7a 100644 --- a/library/src/otx/data/dataset/keypoint_detection.py +++ b/library/src/otx/data/dataset/keypoint_detection.py @@ -147,7 +147,7 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: keypoints = np.hstack((keypoints, keypoints_visible.reshape(-1, 1))) entity = OTXDataItem( - image=to_dtype(to_image(img_data), torch.float32), + image=img_data, img_info=ImageInfo( img_idx=index, img_shape=img_shape, diff --git a/library/src/otx/data/dataset/segmentation.py b/library/src/otx/data/dataset/segmentation.py index e25c9be61f1..01958158cf8 100644 --- a/library/src/otx/data/dataset/segmentation.py +++ b/library/src/otx/data/dataset/segmentation.py @@ -245,7 +245,7 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: masks = tv_tensors.Mask(extracted_mask[None], dtype=torch.long) entity = OTXDataItem( - image=to_dtype(to_image(img_data), dtype=torch.float32), + image=img_data, img_info=ImageInfo( img_idx=index, img_shape=img_shape, diff --git a/library/src/otx/data/transform_libs/torchvision.py b/library/src/otx/data/transform_libs/torchvision.py index cb56b2ab871..42d9f8f3b69 100644 --- a/library/src/otx/data/transform_libs/torchvision.py +++ b/library/src/otx/data/transform_libs/torchvision.py @@ -2969,7 +2969,8 @@ def __call__(self, *_inputs: OTXDataItem) -> OTXDataItem | None: inputs.keypoints = torch.zeros([]) else: # update keypoints_visible after affine transforms - inputs.keypoints[:, 2] = inputs.keypoints[:, 2] * (inputs.keypoints[:, :2] > 0).all(axis=1) + # update keypoints_visible. 
Keypoints should be visible if they are inside the image (>=0, x<=w, y<=h) + inputs.keypoints[:, 2] = inputs.keypoints[:, 2] * (inputs.keypoints[:, :2] >= 0).all(axis=1) * (inputs.keypoints[:, 0] <= w) * (inputs.keypoints[:, 1] <= h) return self.convert(inputs) diff --git a/library/src/otx/recipe/_base_/data/instance_segmentation.yaml b/library/src/otx/recipe/_base_/data/instance_segmentation.yaml index 73b2d608d62..3167d83887e 100644 --- a/library/src/otx/recipe/_base_/data/instance_segmentation.yaml +++ b/library/src/otx/recipe/_base_/data/instance_segmentation.yaml @@ -9,76 +9,13 @@ unannotated_items_ratio: 0.0 train_subset: subset_name: train transform_lib_type: TORCHVISION - batch_size: 1 - num_workers: 2 + batch_size: 8 + num_workers: 4 to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.Resize + - class_path: torchvision.transforms.v2.Resize init_args: - keep_ratio: true - transform_bbox: true - transform_mask: true - scale: $(input_size) - - class_path: torchvision.transforms.v2.RandomPhotometricDistort - enable: false - init_args: - brightness: - - 0.875 - - 1.125 - contrast: - - 0.5 - - 1.5 - saturation: - - 0.5 - - 1.5 - hue: - - -0.05 - - 0.05 - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomAffine - enable: false - init_args: - max_rotate_degree: 10.0 - max_translate_ratio: 0.1 - scaling_ratio_range: - - 0.5 - - 1.5 - max_shear_degree: 2.0 - is_numpy_to_tvtensor: false - - class_path: otx.data.transform_libs.torchvision.Pad - enable: true - init_args: - pad_to_square: true - transform_mask: true - - class_path: otx.data.transform_libs.torchvision.RandomFlip - enable: true - init_args: - probability: 0.5 - - class_path: torchvision.transforms.v2.RandomVerticalFlip - enable: false - init_args: - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur - enable: false - init_args: - kernel_size: 5 - sigma: - - 0.1 - - 2.0 - probability: 0.5 - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise - enable: false - init_args: - mean: 0.0 - sigma: 0.1 - probability: 0.5 - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + size: $(input_size) sampler: class_path: torch.utils.data.RandomSampler init_args: null @@ -86,25 +23,13 @@ train_subset: val_subset: subset_name: val transform_lib_type: TORCHVISION - batch_size: 1 - num_workers: 2 + batch_size: 8 + num_workers: 4 to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.Resize - init_args: - keep_ratio: true - scale: $(input_size) - is_numpy_to_tvtensor: false - - class_path: otx.data.transform_libs.torchvision.Pad - init_args: - pad_to_square: true - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: torchvision.transforms.v2.Normalize + - class_path: torchvision.transforms.v2.Resize init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + size: $(input_size) sampler: class_path: torch.utils.data.RandomSampler init_args: null @@ -112,25 +37,13 @@ val_subset: test_subset: subset_name: test transform_lib_type: TORCHVISION - batch_size: 1 - num_workers: 2 + batch_size: 8 + num_workers: 4 to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.Resize - init_args: - keep_ratio: true - scale: $(input_size) - is_numpy_to_tvtensor: false 
- - class_path: otx.data.transform_libs.torchvision.Pad - init_args: - pad_to_square: true - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: torchvision.transforms.v2.Normalize + - class_path: torchvision.transforms.v2.Resize init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + size: $(input_size) sampler: class_path: torch.utils.data.RandomSampler init_args: null diff --git a/library/src/otx/recipe/_base_/data/keypoint_detection.yaml b/library/src/otx/recipe/_base_/data/keypoint_detection.yaml index e0ed303562f..022311c7022 100644 --- a/library/src/otx/recipe/_base_/data/keypoint_detection.yaml +++ b/library/src/otx/recipe/_base_/data/keypoint_detection.yaml @@ -11,47 +11,14 @@ train_subset: num_workers: 2 to_tv_image: true transforms: - - class_path: torchvision.transforms.v2.RandomPhotometricDistort - enable: false - init_args: - brightness: - - 0.875 - - 1.125 - contrast: - - 0.5 - - 1.5 - saturation: - - 0.5 - - 1.5 - hue: - - -0.05 - - 0.05 - p: 0.5 - class_path: otx.data.transform_libs.torchvision.TopdownAffine init_args: input_size: $(input_size) probability: 1.0 - - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur - enable: false - init_args: - kernel_size: 5 - sigma: - - 0.1 - - 2.0 - probability: 0.5 - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise - enable: false - init_args: - mean: 0.0 - sigma: 0.1 - probability: 0.5 - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + sampler: + class_path: torch.utils.data.RandomSampler + init_args: null + val_subset: subset_name: val batch_size: 32 @@ -67,13 +34,10 @@ val_subset: init_args: size: $(input_size) pad_val: 0 - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + sampler: + class_path: torch.utils.data.RandomSampler + init_args: null + test_subset: subset_name: test batch_size: 32 @@ -89,10 +53,6 @@ test_subset: init_args: size: $(input_size) pad_val: 0 - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + sampler: + class_path: torch.utils.data.RandomSampler + init_args: null diff --git a/library/src/otx/recipe/_base_/data/semantic_segmentation.yaml b/library/src/otx/recipe/_base_/data/semantic_segmentation.yaml index ede61422dad..9577bb9566f 100644 --- a/library/src/otx/recipe/_base_/data/semantic_segmentation.yaml +++ b/library/src/otx/recipe/_base_/data/semantic_segmentation.yaml @@ -24,57 +24,6 @@ train_subset: - 0.5 - 2.0 transform_mask: true - - class_path: otx.data.transform_libs.torchvision.PhotoMetricDistortion - enable: true - init_args: - brightness_delta: 32 - contrast: - - 0.5 - - 1.5 - saturation: - - 0.5 - - 1.5 - hue_delta: 18 - probability: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomAffine - enable: false - init_args: - max_rotate_degree: 10.0 - max_translate_ratio: 0.1 - scaling_ratio_range: - - 0.5 - - 1.5 - max_shear_degree: 2.0 - - class_path: otx.data.transform_libs.torchvision.RandomFlip - enable: true - init_args: - 
probability: 0.5 - - class_path: torchvision.transforms.v2.RandomVerticalFlip - enable: false - init_args: - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur - enable: false - init_args: - kernel_size: 5 - sigma: - - 0.1 - - 2.0 - probability: 0.5 - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - scale: false - - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise - enable: false - init_args: - mean: 0.0 - sigma: 0.1 - probability: 0.5 - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] sampler: class_path: torch.utils.data.RandomSampler init_args: null @@ -90,13 +39,6 @@ val_subset: init_args: scale: $(input_size) transform_mask: true - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] sampler: class_path: torch.utils.data.RandomSampler init_args: null @@ -112,13 +54,6 @@ test_subset: init_args: scale: $(input_size) transform_mask: true - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] sampler: class_path: torch.utils.data.RandomSampler init_args: null
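The keypoint-visibility rule now applied by both the CPU path (torchvision.py) and the GPU path (keypoint_detection/base.py) can be distilled as follows; `update_visibility` is a hypothetical helper, not part of this patch:

```python
import torch

def update_visibility(keypoints: torch.Tensor, h: int, w: int) -> torch.Tensor:
    # keypoints: (N, 3) rows of (x, y, visible); a keypoint stays visible only
    # if it lands inside the image bounds after a geometric transform.
    inside = (
        (keypoints[:, :2] >= 0).all(dim=1)
        & (keypoints[:, 0] <= w)
        & (keypoints[:, 1] <= h)
    )
    keypoints[:, 2] = keypoints[:, 2] * inside
    return keypoints

kps = torch.tensor([[10.0, 10.0, 1.0], [-5.0, 10.0, 1.0], [70.0, 10.0, 1.0]])
print(update_visibility(kps, h=64, w=64)[:, 2])  # tensor([1., 0., 0.])
```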