diff --git a/library/src/otx/backend/native/models/base.py b/library/src/otx/backend/native/models/base.py index 093cc239121..9f5e9a8f5ac 100644 --- a/library/src/otx/backend/native/models/base.py +++ b/library/src/otx/backend/native/models/base.py @@ -15,12 +15,15 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence import torch +import kornia from datumaro import LabelCategories from lightning import LightningModule, Trainer from torch import Tensor, nn from torch.optim.lr_scheduler import ConstantLR from torch.optim.sgd import SGD from torchmetrics import Metric, MetricCollection +from kornia.augmentation.container import AugmentationSequential +from torchvision.transforms.v2 import Compose from otx import __version__ from otx.backend.native.optimizers.callable import OptimizerCallableSupportAdaptiveBS @@ -66,9 +69,9 @@ class DataInputParams: """Parameters of the input data such as input size, mean, and std.""" - input_size: tuple[int, int] - mean: tuple[float, float, float] - std: tuple[float, float, float] + input_size: tuple[int, int] | None = None + mean: tuple[float, float, float] | None = None + std: tuple[float, float, float] | None = None def as_dict(self) -> dict[str, Any]: """Convert to dictionary.""" @@ -76,7 +79,11 @@ def as_dict(self) -> dict[str, Any]: def as_ncwh(self, batch_size: int = 1) -> tuple[int, int, int, int]: """Convert input_size to NCWH format.""" - return (batch_size, 3, *self.input_size) + if self.input_size is not None: + return (batch_size, 3, *self.input_size) + + msg = "input_size should not be None." + raise ValueError(msg) def _default_optimizer_callable(params: params_t) -> Optimizer: @@ -134,8 +141,10 @@ def __init__( self, label_info: LabelInfoTypes | int | Sequence, data_input_params: DataInputParams | dict | None = None, - task: OTXTaskType | None = None, model_name: str = "OTXModel", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = NullMetricCallable, @@ -153,6 +162,12 @@ def __init__( If None is given, default parameters for the specific model will be used. Defaults to None. model_name (str, optional): Name of the model. Defaults to "OTXModel". + apply_gpu_transforms (bool, optional): Flag to indicate whether to apply GPU transforms. + It is recommended to use GPU transforms. Defaults to True. + batch_train_transforms (AugmentationSequential | Compose | None): GPU transforms for training applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. + batch_val_transforms (AugmentationSequential | Compose | None): GPU transforms for validation / testing applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. Typically just normalization. optimizer (OptimizerCallable, optional): Callable for the optimizer. Defaults to DefaultOptimizerCallable. scheduler (LRSchedulerCallable | LRSchedulerListCallable): Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable. @@ -160,29 +175,16 @@ def __init__( torch_compile (bool, optional): Flag to indicate if torch.compile should be used. Defaults to False. tile_config (TileConfig, optional): Configuration for tiling. Defaults to TileConfig(enable_tiler=False). 
- Returns: - None """ super().__init__() self._label_info = self._dispatch_label_info(label_info) self.model_name = model_name - if isinstance(data_input_params, dict): - data_input_params = DataInputParams(**data_input_params) - elif data_input_params is None: - data_input_params = ( - self._default_preprocessing_params[self.model_name] - if isinstance(self._default_preprocessing_params, dict) - else self._default_preprocessing_params - ) - self._check_preprocessing_params(data_input_params) - self.data_input_params = data_input_params + self.data_input_params = self._configure_preprocessing_params(data_input_params) self.model = self._create_model() self.optimizer_callable = ensure_callable(optimizer) self.scheduler_callable = ensure_callable(scheduler) self.metric_callable = ensure_callable(metric) - self._task = task - self.torch_compile = torch_compile self._explain_mode = False @@ -190,6 +192,16 @@ def __init__( if isinstance(tile_config, dict): tile_config = TileConfig(**tile_config) self._tile_config = tile_config.clone() + + # Augmentation configuration + if apply_gpu_transforms: + self.batch_train_transforms, self.batch_val_transforms = self._configure_batch_augmentation(batch_train_transforms, batch_val_transforms) + self.batch_train_transforms.to(self.device) + self.batch_val_transforms.to(self.device) + else: + self.batch_train_transforms = None + self.batch_val_transforms = None + self.save_hyperparameters( logger=False, ignore=["optimizer", "scheduler", "metric", "label_info", "tile_config", "data_input_params"], @@ -197,6 +209,7 @@ def __init__( def training_step(self, batch: OTXDataBatch, batch_idx: int) -> Tensor: """Step for model training.""" + self._apply_batch_augmentations(self.batch_train_transforms, batch) train_loss = self.forward(inputs=batch) if train_loss is None: msg = "Loss is None." @@ -251,6 +264,7 @@ def validation_step(self, batch: OTXDataBatch, batch_idx: int) -> OTXPredBatch: Updates test metrics based on the prediction results and batch data. Handles both single dictionary and list of dictionaries for metric inputs. """ + self._apply_batch_augmentations(self.batch_val_transforms, batch) preds = self.forward(inputs=batch) if isinstance(preds, OTXBatchLossEntity): @@ -282,6 +296,7 @@ def test_step(self, batch: OTXDataBatch, batch_idx: int) -> OTXPredBatch: When torch_compile is enabled and stage is "fit", compiles the model for optimized performance with appropriate logging level adjustments. """ + self._apply_batch_augmentations(self.batch_val_transforms, batch) preds = self.forward(inputs=batch) if isinstance(preds, OTXBatchLossEntity): @@ -408,6 +423,76 @@ def configure_metric(self) -> None: self._metric = metric.to(self.device) + def _configure_batch_augmentation( + self, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, + ) -> tuple[AugmentationSequential, AugmentationSequential]: + """Configure batch augmentation. + + Args: + batch_train_transforms (AugmentationSequential | Compose | None): batch train transforms + batch_val_transforms (AugmentationSequential | Compose | None): batch val transforms + + Returns: + Tuple of train and validation batch augmentation pipelines. + """ + if batch_train_transforms is not None: + if not callable(batch_train_transforms): + msg = "Batch train transforms should be callable. " \ + "Please use kornia AugmentationSequential or torchvision Compose" + + raise TypeError(msg) + + if batch_val_transforms is not None: + if not callable(batch_val_transforms): + msg = "Batch val transforms should be callable. " \ + "Please use kornia AugmentationSequential or torchvision Compose" + + raise TypeError(msg) + + train_aug_pipeline = batch_train_transforms if batch_train_transforms is not None else self._default_train_transforms + val_aug_pipeline = batch_val_transforms if batch_val_transforms is not None else self._default_val_transforms + + return train_aug_pipeline, val_aug_pipeline + + @staticmethod + @torch.no_grad() + def _apply_batch_augmentations(augmentations_pipeline: AugmentationSequential | Compose | None, batch: OTXDataBatch) -> None: + if augmentations_pipeline is not None: + batch.images = augmentations_pipeline(batch.images) + + @property + def _default_train_transforms(self): + if self.task == OTXTaskType.DETECTION: + data_keys = ["input", "bbox"] + elif self.task == OTXTaskType.SEMANTIC_SEGMENTATION: + data_keys = ["input", "mask"] + elif self.task == OTXTaskType.INSTANCE_SEGMENTATION: + data_keys = ["input", "bbox", "mask"] + elif self.task == OTXTaskType.KEYPOINT_DETECTION: + data_keys = ["input", "keypoints"] + else: + data_keys = ["input"] + + return AugmentationSequential(kornia.augmentation.RandomHorizontalFlip(), + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), data_keys=data_keys, keepdim=True) + + @property + def _default_val_transforms(self): + if self.task == OTXTaskType.DETECTION: + data_keys = ["input", "bbox"] + elif self.task == OTXTaskType.SEMANTIC_SEGMENTATION: + data_keys = ["input", "mask"] + elif self.task == OTXTaskType.INSTANCE_SEGMENTATION: + data_keys = ["input", "bbox", "mask"] + elif self.task == OTXTaskType.KEYPOINT_DETECTION: + data_keys = ["input", "keypoints"] + else: + data_keys = ["input"] + + return AugmentationSequential(kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), data_keys=data_keys, keepdim=True) @property def metric(self) -> Metric | MetricCollection: """Metric module for this OTX model.""" @@ -615,9 +700,16 @@ def _customize_outputs( def forward( self, - inputs: OTXDataBatch | Tensor, + inputs: OTXDataBatch, ) -> OTXPredBatch | OTXBatchLossEntity | Tensor: - """Model forward function.""" + """Model forward function. + + Args: + inputs: Batch of input data. + + Returns: + Model predictions or loss entity depending on training mode. + """ # Simple forward if isinstance(inputs, Tensor): return self.forward_for_tracing(inputs) @@ -662,6 +754,14 @@ def forward_tiles( """Model forward function for tile task.""" raise NotImplementedError + @property + def transforms(self) -> AugmentationSequential: + """Kornia Image Transforms (DEPRECATED). + + Use batch_train_transforms / batch_val_transforms instead for new code. + """ + return AugmentationSequential() def register_load_state_dict_pre_hook(self, model_classes: list[str], ckpt_classes: list[str]) -> None: """Register load_state_dict_pre_hook. @@ -902,6 +1002,7 @@ def tile_config(self, tile_config: TileConfig) -> None: self._tile_config = tile_config + @abstractmethod def get_dummy_input(self, batch_size: int = 1) -> OTXDataBatch: """Generates a dummy input, suitable for launching forward() on it. @@ -911,7 +1012,6 @@ def get_dummy_input(self, batch_size: int = 1) -> OTXDataBatch: Returns: TorchDataBatch: A batch containing randomly generated inference data.
""" - raise NotImplementedError @staticmethod def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo: @@ -936,40 +1036,51 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo: raise TypeError(label_info) - def _check_preprocessing_params(self, preprocessing_params: DataInputParams | None) -> None: + def _configure_preprocessing_params(self, preprocessing_params: DataInputParams | None) -> DataInputParams: """Check the validity of the preprocessing parameters.""" - if preprocessing_params is None: - msg = "Data input parameters should not be None." - raise ValueError(msg) + if isinstance(preprocessing_params, dict): + data_input_params = DataInputParams(**preprocessing_params) + elif isinstance(preprocessing_params, DataInputParams): + data_input_params = preprocessing_params + else: + # `preprocessing_params` is None + data_input_params = DataInputParams() + + default_data_input_params = ( + self._default_preprocessing_params[self.model_name] + if isinstance(self._default_preprocessing_params, dict) + else self._default_preprocessing_params + ) - input_size = preprocessing_params.input_size - mean = preprocessing_params.mean - std = preprocessing_params.std + # Assign default values if not given in `preprocessing_params` + data_input_params.input_size = data_input_params.input_size or default_data_input_params.input_size + data_input_params.mean = data_input_params.mean or default_data_input_params.mean + data_input_params.std = data_input_params.std or default_data_input_params.std - if not (len(mean) == 3 and all(isinstance(m, float) for m in mean)): - msg = f"Mean should be a tuple of 3 float values, but got {mean} instead." + # Validate + if not (len(data_input_params.mean) == 3 and all(isinstance(m, float) for m in data_input_params.mean)): + msg = f"Mean should be a tuple of 3 float values, but got {data_input_params.mean} instead." raise ValueError(msg) - if not (len(std) == 3 and all(isinstance(s, float) for s in std)): - msg = f"Std should be a tuple of 3 float values, but got {std} instead." + if not (len(data_input_params.std) == 3 and all(isinstance(s, float) for s in data_input_params.std)): + msg = f"Std should be a tuple of 3 float values, but got {data_input_params.std} instead." raise ValueError(msg) - if not all(0 <= m <= 255 for m in mean): - msg = f"Mean values should be in the range [0, 255], but got {mean} instead." + if not all(0 <= m <= 255 for m in data_input_params.mean): + msg = f"Mean values should be in the range [0, 255], but got {data_input_params.mean} instead." raise ValueError(msg) - if not all(0 <= s <= 255 for s in std): - msg = f"Std values should be in the range [0, 255], but got {std} instead." + if not all(0 <= s <= 255 for s in data_input_params.std): + msg = f"Std values should be in the range [0, 255], but got {data_input_params.std} instead." raise ValueError(msg) - if input_size is not None and ( - input_size[0] % self.input_size_multiplier != 0 or input_size[1] % self.input_size_multiplier != 0 + if data_input_params.input_size is not None and ( + data_input_params.input_size[0] % self.input_size_multiplier != 0 or data_input_params.input_size[1] % self.input_size_multiplier != 0 ): - msg = f"Input size should be a multiple of {self.input_size_multiplier}, but got {input_size} instead." + msg = f"Input size should be a multiple of {self.input_size_multiplier}, but got {data_input_params.input_size} instead." 
raise ValueError(msg) + return data_input_params + @property + @abstractmethod def task(self) -> OTXTaskType: """Get task type.""" - if self._task is None: - msg = "Task type is not set. Please set the task type before using this model." - raise ValueError(msg) - return self._task diff --git a/library/src/otx/backend/native/models/classification/hlabel_models/base.py b/library/src/otx/backend/native/models/classification/hlabel_models/base.py index 1c8a0e8829d..7986bccc8f0 100644 --- a/library/src/otx/backend/native/models/classification/hlabel_models/base.py +++ b/library/src/otx/backend/native/models/classification/hlabel_models/base.py @@ -41,6 +41,12 @@ class OTXHlabelClsModel(OTXModel): data_input_params (DataInputParams | None, optional): Parameters for image data preprocessing. If None is given, default parameters for the specific model will be used. model_name (str, optional): Name of the model. Defaults to "hlabel_classification_model". + apply_gpu_transforms (bool, optional): Flag to indicate whether to apply GPU transforms. + It is recommended to use GPU transforms. Defaults to True. + batch_train_transforms (AugmentationSequential | Compose | None): GPU transforms for training applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. + batch_val_transforms (AugmentationSequential | Compose | None): GPU transforms for validation / testing applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. Typically just normalization. optimizer (OptimizerCallable, optional): Callable for the optimizer. Defaults to DefaultOptimizerCallable. scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable. @@ -55,6 +61,9 @@ def __init__( label_info: HLabelInfo, data_input_params: DataInputParams | None = None, model_name: str = "hlabel_classification_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, freeze_backbone: bool = False, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, diff --git a/library/src/otx/backend/native/models/classification/multiclass_models/base.py b/library/src/otx/backend/native/models/classification/multiclass_models/base.py index f672485db2c..db0ea4f8025 100644 --- a/library/src/otx/backend/native/models/classification/multiclass_models/base.py +++ b/library/src/otx/backend/native/models/classification/multiclass_models/base.py @@ -9,6 +9,7 @@ import torch from torch import Tensor +import kornia from otx.backend.native.exporter.base import OTXModelExporter from otx.backend.native.exporter.native import OTXNativeModelExporter @@ -23,6 +24,10 @@ from otx.types.export import TaskLevelExportParameters from otx.types.label import LabelInfoTypes from otx.types.task import OTXTaskType +from kornia.augmentation.container import AugmentationSequential +from kornia.augmentation import Normalize +from kornia.augmentation.auto import AutoAugment +from torchvision.transforms.v2 import Compose if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -40,6 +45,12 @@ class OTXMulticlassClsModel(OTXModel): data_input_params (DataInputParams | None, optional): Parameters for the image data preprocessing. 
If None is given, default parameters for the specific model will be used. model_name (str, optional): Name of the model. Defaults to "multiclass_classification_model". + apply_gpu_transforms (bool, optional): Flag to indicate whether to apply GPU transforms. + It is recommended to use GPU transforms. Defaults to True. + batch_train_transforms (AugmentationSequential | Compose | None): GPU transforms for training applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. + batch_val_transforms (AugmentationSequential | Compose | None): GPU transforms for validation / testing applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. Typically just normalization. optimizer (OptimizerCallable, optional): Callable for the optimizer. Defaults to DefaultOptimizerCallable. scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable. @@ -52,6 +63,9 @@ def __init__( label_info: LabelInfoTypes | int | Sequence, data_input_params: DataInputParams | None = None, model_name: str = "multiclass_classification_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, freeze_backbone: bool = False, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, @@ -61,8 +75,10 @@ def __init__( super().__init__( label_info=label_info, data_input_params=data_input_params, - task=OTXTaskType.MULTI_CLASS_CLS, model_name=model_name, + apply_gpu_transforms=apply_gpu_transforms, + batch_train_transforms=batch_train_transforms, + batch_val_transforms=batch_val_transforms, optimizer=optimizer, scheduler=scheduler, metric=metric, @@ -121,6 +137,12 @@ def _customize_outputs( scores=list(scores), ) + @property + def _default_train_transforms(self): + return AugmentationSequential(kornia.augmentation.RandomHorizontalFlip(), + kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1), + Normalize(self.data_input_params.mean, self.data_input_params.std)) + @property def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" @@ -185,6 +207,10 @@ def forward_explain(self, inputs: OTXDataBatch) -> OTXPredBatch: feature_vector=[feature_vector.unsqueeze(0) for feature_vector in outputs["feature_vector"]], ) + @property + def task(self) -> OTXTaskType: + return OTXTaskType.MULTI_CLASS_CLS + @property def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]: - return DataInputParams(input_size=(224, 224), mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375)) + return DataInputParams(input_size=(224, 224), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) diff --git a/library/src/otx/backend/native/models/classification/multiclass_models/efficientnet.py b/library/src/otx/backend/native/models/classification/multiclass_models/efficientnet.py index a6a767ab26d..526e96bfbb5 100644 --- a/library/src/otx/backend/native/models/classification/multiclass_models/efficientnet.py +++ b/library/src/otx/backend/native/models/classification/multiclass_models/efficientnet.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Literal from torch import Tensor, nn +import kornia from otx.backend.native.models.base import DataInputParams, 
DefaultOptimizerCallable, DefaultSchedulerCallable from otx.backend.native.models.classification.backbones.efficientnet import EfficientNetBackbone @@ -96,3 +97,21 @@ def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: return self.model(images=image, mode="explain") return self.model(images=image, mode="tensor") + + @property + def transforms(self): + """Kornia image transforms for this model.""" + if self.training: + return kornia.augmentation.AugmentationSequential( + kornia.augmentation.RandomAffine(degrees=10.0, translate=[0.1, 0.1], scale=[0.5, 1.5], shear=2.0), + kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1), + kornia.augmentation.RandomHorizontalFlip(), + kornia.augmentation.RandomGaussianBlur(5, (0.1, 2.0)), + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), + data_keys=["input"], + same_on_batch=False, + ) + return kornia.augmentation.AugmentationSequential( + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), + data_keys=["input"], + ) diff --git a/library/src/otx/backend/native/models/classification/multilabel_models/base.py b/library/src/otx/backend/native/models/classification/multilabel_models/base.py index 3406588fd8b..92f60f3b3a9 100644 --- a/library/src/otx/backend/native/models/classification/multilabel_models/base.py +++ b/library/src/otx/backend/native/models/classification/multilabel_models/base.py @@ -38,6 +38,12 @@ class OTXMultilabelClsModel(OTXModel): if `Sequence` is given, label info will be constructed from the sequence of label names. data_input_params (DataInputParams | None, optional): Parameters for the image data preprocessing. model_name (str, optional): Name of the model. Defaults to "multilabel_classification_model". + apply_gpu_transforms (bool, optional): Flag to indicate whether to apply GPU transforms. + It is recommended to use GPU transforms. Defaults to True. + batch_train_transforms (AugmentationSequential | Compose | None): GPU transforms for training applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. + batch_val_transforms (AugmentationSequential | Compose | None): GPU transforms for validation / testing applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. Typically just normalization. optimizer (OptimizerCallable, optional): Callable for the optimizer. Defaults to DefaultOptimizerCallable. scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.
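[Reviewer note, not part of the patch] A minimal usage sketch for the new batch-transform parameters added to the classification models above. The constructor keyword arguments follow this diff; the import path, the concrete kornia augmentations, and the label count are illustrative assumptions.

import kornia
from kornia.augmentation.container import AugmentationSequential
from otx.backend.native.models.classification.multiclass_models.base import OTXMulticlassClsModel

# A custom GPU pipeline; when passed in, it replaces the model's
# _default_train_transforms for training batches.
custom_train_tf = AugmentationSequential(
    kornia.augmentation.RandomHorizontalFlip(p=0.5),
    kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1),
    kornia.augmentation.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    data_keys=["input"],
)

# The pipeline runs on-device inside training_step via _apply_batch_augmentations.
# batch_val_transforms is left as None, so validation keeps the default
# pipeline (normalization only).
model = OTXMulticlassClsModel(
    label_info=10,  # illustrative label count
    batch_train_transforms=custom_train_tf,
)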
@@ -50,6 +56,9 @@ def __init__( label_info: LabelInfoTypes | Sequence, data_input_params: DataInputParams | None = None, model_name: str = "multiclass_classification_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, freeze_backbone: bool = False, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, @@ -59,8 +68,10 @@ def __init__( super().__init__( label_info=label_info, data_input_params=data_input_params, - task=OTXTaskType.MULTI_LABEL_CLS, model_name=model_name, + apply_gpu_transforms=apply_gpu_transforms, + batch_train_transforms=batch_train_transforms, + batch_val_transforms=batch_val_transforms, optimizer=optimizer, scheduler=scheduler, metric=metric, diff --git a/library/src/otx/backend/native/models/detection/atss.py b/library/src/otx/backend/native/models/detection/atss.py index c48f0b9ed5e..9a5fc7df222 100644 --- a/library/src/otx/backend/native/models/detection/atss.py +++ b/library/src/otx/backend/native/models/detection/atss.py @@ -6,6 +6,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, ClassVar, Literal +import kornia as K from otx.backend.native.exporter.base import OTXModelExporter from otx.backend.native.exporter.native import OTXNativeModelExporter @@ -203,4 +204,4 @@ def _exporter(self) -> OTXModelExporter: @property def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]: - return DataInputParams(input_size=(800, 992), mean=(0.0, 0.0, 0.0), std=(255.0, 255.0, 255.0)) + return DataInputParams(input_size=(800, 992), mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)) diff --git a/library/src/otx/backend/native/models/detection/base.py b/library/src/otx/backend/native/models/detection/base.py index 5d729d4bc2e..f4c565fc89f 100644 --- a/library/src/otx/backend/native/models/detection/base.py +++ b/library/src/otx/backend/native/models/detection/base.py @@ -15,6 +15,9 @@ import torch from torchmetrics import Metric, MetricCollection from torchvision import tv_tensors +import kornia +from kornia.geometry.boxes import Boxes +from kornia.augmentation.container import AugmentationSequential from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel from otx.backend.native.models.utils.utils import InstanceData @@ -71,6 +74,9 @@ def __init__( label_info: LabelInfoTypes | int | Sequence, data_input_params: DataInputParams | dict | None = None, model_name: str = "otx_detection_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MeanAveragePrecisionFMeasureCallable, @@ -80,9 +86,11 @@ def __init__( ) -> None: super().__init__( label_info=label_info, - model_name=model_name, - task=OTXTaskType.DETECTION, data_input_params=data_input_params, + model_name=model_name, + apply_gpu_transforms=apply_gpu_transforms, + batch_train_transforms=batch_train_transforms, + batch_val_transforms=batch_val_transforms, optimizer=optimizer, scheduler=scheduler, metric=metric, @@ -187,7 +195,6 @@ def _customize_inputs( inputs["entity"] = entity inputs["mode"] = "loss" 
if self.training else "predict" - return inputs def _customize_outputs( @@ -546,6 +553,19 @@ def get_num_anchors(self) -> list[int]: return [1] * 10 + @staticmethod + @torch.no_grad() + def _apply_batch_augmentations(augmentations_pipeline: AugmentationSequential | Compose | None, batch: OTXDataBatch) -> None: + if augmentations_pipeline is not None: + # Convert bounding boxes to Kornia Boxes [N, 4, 2] + kornia_boxes = Boxes.from_tensor(batch.bboxes, mode='xyxy') + batch.images, kornia_boxes = augmentations_pipeline(batch.images, kornia_boxes) + batch.bboxes = kornia_boxes.to_tensor(mode='xyxy') + @property def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]: - return DataInputParams(input_size=(640, 640), mean=(0.0, 0.0, 0.0), std=(255.0, 255.0, 255.0)) + return DataInputParams(input_size=(640, 640), mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)) + + @property + def task(self) -> OTXTaskType: + return OTXTaskType.DETECTION diff --git a/library/src/otx/backend/native/models/detection/d_fine.py b/library/src/otx/backend/native/models/detection/d_fine.py index 9906d07a090..fe43a532b0f 100644 --- a/library/src/otx/backend/native/models/detection/d_fine.py +++ b/library/src/otx/backend/native/models/detection/d_fine.py @@ -17,6 +17,7 @@ from otx.backend.native.models.utils.utils import load_checkpoint from otx.config.data import TileConfig from otx.metrics.fmeasure import MeanAveragePrecisionFMeasureCallable +import kornia if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -174,3 +175,17 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None: ckpt.pop("model.decoder.anchors") ckpt.pop("model.decoder.valid_mask") return super().load_state_dict(ckpt, *args, strict=False, **kwargs) + + @property + def transforms(self): + if self.training: + return kornia.augmentation.AugmentationSequential( + kornia.augmentation.RandomHorizontalFlip(), + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), + data_keys=["input", "bbox"], + same_on_batch=False + ) + return kornia.augmentation.AugmentationSequential( + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), + data_keys=["input", "bbox"], + ) diff --git a/library/src/otx/backend/native/models/instance_segmentation/base.py b/library/src/otx/backend/native/models/instance_segmentation/base.py index 805e8538351..a3ddc5d010e 100644 --- a/library/src/otx/backend/native/models/instance_segmentation/base.py +++ b/library/src/otx/backend/native/models/instance_segmentation/base.py @@ -18,6 +18,8 @@ from torchmetrics import Metric, MetricCollection from torchvision import tv_tensors from torchvision.models.detection.image_list import ImageList +from kornia.geometry.boxes import Boxes +from kornia.augmentation.container import AugmentationSequential from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel from otx.backend.native.models.instance_segmentation.segmentors.maskrcnn_tv import MaskRCNN @@ -61,6 +63,12 @@ class OTXInstanceSegModel(OTXModel): data_input_params (DataInputParams | None, optional): Parameters for the image data preprocessing. If None is given, default parameters for the specific model will be used. model_name (str, optional): Name of the model. Defaults to "inst_segm_model". + apply_gpu_transforms (bool, optional): Flag to indicate whether to apply GPU transforms. + It is recommended to use GPU transforms. 
Defaults to True. + batch_train_transforms (AugmentationSequential | Compose | None): GPU transforms for training applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. + batch_val_transforms (AugmentationSequential | Compose | None): GPU transforms for validation / testing applied directly to the batch. + If None is given, default augmentation pipeline for the model will be used. Typically just normalization. optimizer (OptimizerCallable, optional): Optimizer for the model. Defaults to DefaultOptimizerCallable. scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Scheduler for the model. Defaults to DefaultSchedulerCallable. @@ -76,6 +84,9 @@ def __init__( self, label_info: LabelInfoTypes | int | Sequence, data_input_params: DataInputParams | None = None, model_name: str = "inst_segm_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MaskRLEMeanAPFMeasureCallable, @@ -85,8 +96,10 @@ def __init__( super().__init__( label_info=label_info, data_input_params=data_input_params, - task=OTXTaskType.INSTANCE_SEGMENTATION, model_name=model_name, + apply_gpu_transforms=apply_gpu_transforms, + batch_train_transforms=batch_train_transforms, + batch_val_transforms=batch_val_transforms, optimizer=optimizer, scheduler=scheduler, metric=metric, @@ -616,6 +629,22 @@ def _restore_model_forward(self) -> None: self.model.forward = func_type(self.original_model_forward, self.model) self.original_model_forward = None + @staticmethod + @torch.no_grad() + def _apply_batch_augmentations(augmentations_pipeline: AugmentationSequential | Compose | None, batch: OTXDataBatch) -> None: + if augmentations_pipeline is not None: + # Convert bounding boxes to Kornia Boxes [N, 4, 2] + kornia_boxes = Boxes.from_tensor(batch.bboxes, mode="xyxy") + batch.images, kornia_boxes, masks = augmentations_pipeline(batch.images, kornia_boxes, batch.masks) + batch.bboxes = kornia_boxes.to_tensor(mode="xyxy") + batch.masks = masks @property def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]: - return DataInputParams(input_size=(1024, 1024), mean=(103.53, 116.28, 123.675), std=(57.375, 57.12, 58.395)) + return DataInputParams(input_size=(1024, 1024), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) + + @property + def task(self) -> OTXTaskType: + return OTXTaskType.INSTANCE_SEGMENTATION diff --git a/library/src/otx/backend/native/models/keypoint_detection/base.py b/library/src/otx/backend/native/models/keypoint_detection/base.py index 1f9cffccb16..983be62e7b7 100644 --- a/library/src/otx/backend/native/models/keypoint_detection/base.py +++ b/library/src/otx/backend/native/models/keypoint_detection/base.py @@ -10,7 +10,9 @@ from typing import TYPE_CHECKING, Any, Sequence import torch +import kornia +from kornia.augmentation.container import AugmentationSequential from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel from otx.backend.native.schedulers import LRSchedulerListCallable from otx.data.entity.base import ImageInfo, OTXBatchLossEntity @@ -48,6 +50,9 @@ def __init__( self, label_info: LabelInfoTypes | int | Sequence, data_input_params: 
DataInputParams | None = None, model_name: str = "keypoint_detection_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = PCKMeasureCallable, @@ -56,8 +61,10 @@ def __init__( super().__init__( label_info=label_info, data_input_params=data_input_params, - task=OTXTaskType.KEYPOINT_DETECTION, model_name=model_name, + apply_gpu_transforms=apply_gpu_transforms, + batch_train_transforms=batch_train_transforms, + batch_val_transforms=batch_val_transforms, optimizer=optimizer, scheduler=scheduler, metric=metric, @@ -209,6 +216,29 @@ def _export_parameters(self) -> TaskLevelExportParameters: confidence_threshold=self.hparams.get("best_confidence_threshold", None), ) + @staticmethod + @torch.no_grad() + def _apply_batch_augmentations(augmentations_pipeline: AugmentationSequential | Compose | None, batch: OTXDataBatch) -> None: + if augmentations_pipeline is not None: + stacked_kps = torch.stack(batch.keypoints) + # Apply augmentations + batch.images, augmented_kps = augmentations_pipeline(batch.images, stacked_kps[:, :, :2]) + stacked_kps[:, :, :2] = augmented_kps + h, w = batch.images.shape[-2:] + # Compute visible mask. Keypoints should be visible if they are inside the image (>=0, x<=w, y<=h) + visible_mask = (augmented_kps >= 0).all(dim=2) * (augmented_kps[:, :, 0] <= w) * (augmented_kps[:, :, 1] <= h) + stacked_kps[:, :, 2] = stacked_kps[:, :, 2] * visible_mask + # Update visible keypoints with augmented values + batch.keypoints = list(stacked_kps) + + @property + def _default_train_transforms(self): + return AugmentationSequential(kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), data_keys=["input", "keypoints"], keepdim=True) + @property def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]: - return DataInputParams(input_size=(640, 640), mean=(0.0, 0.0, 0.0), std=(255.0, 255.0, 255.0)) + return DataInputParams(input_size=(512, 512), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) + + @property + def task(self) -> OTXTaskType: + return OTXTaskType.KEYPOINT_DETECTION diff --git a/library/src/otx/backend/native/models/keypoint_detection/rtmpose.py b/library/src/otx/backend/native/models/keypoint_detection/rtmpose.py index 7a96218b91c..d3f359a3a24 100644 --- a/library/src/otx/backend/native/models/keypoint_detection/rtmpose.py +++ b/library/src/otx/backend/native/models/keypoint_detection/rtmpose.py @@ -35,9 +35,6 @@ class RTMPose(OTXKeypointDetectionModel): "rtmpose_tiny": "https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/cspnext-tiny_udp-aic-coco_210e-256x192-cbed682d_20230130.pth", } - _default_preprocessing_params: ClassVar[dict[str, DataInputParams] | DataInputParams] = { - "rtmpose_tiny": DataInputParams(input_size=(640, 640), mean=(0.0, 0.0, 0.0), std=(255.0, 255.0, 255.0)), - } def __init__( self, diff --git a/library/src/otx/backend/native/models/segmentation/base.py b/library/src/otx/backend/native/models/segmentation/base.py index a5e5577b22c..9689bafbd6b 100644 --- a/library/src/otx/backend/native/models/segmentation/base.py +++ b/library/src/otx/backend/native/models/segmentation/base.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any import torch +import kornia import 
torch.nn.functional as f from torchvision import tv_tensors @@ -20,6 +21,7 @@ from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel from otx.backend.native.schedulers import LRSchedulerListCallable from otx.backend.native.tools.tile_merge import SegmentationTileMerge +from kornia.augmentation.container import AugmentationSequential from otx.config.data import TileConfig from otx.data.entity.base import ImageInfo, OTXBatchLossEntity from otx.data.entity.tile import OTXTileBatchDataEntity @@ -59,6 +61,9 @@ def __init__( label_info: LabelInfoTypes | int | Sequence, data_input_params: DataInputParams | None = None, model_name: str = "otx_segmentation_model", + apply_gpu_transforms: bool = True, + batch_train_transforms: AugmentationSequential | Compose | None = None, + batch_val_transforms: AugmentationSequential | Compose | None = None, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = SegmCallable, # type: ignore[assignment] @@ -68,8 +73,10 @@ def __init__( super().__init__( label_info=label_info, data_input_params=data_input_params, - task=OTXTaskType.SEMANTIC_SEGMENTATION, model_name=model_name, + apply_gpu_transforms=apply_gpu_transforms, + batch_train_transforms=batch_train_transforms, + batch_val_transforms=batch_val_transforms, optimizer=optimizer, scheduler=scheduler, metric=metric, @@ -297,6 +304,23 @@ def get_dummy_input(self, batch_size: int = 1) -> OTXDataBatch: # type: ignore[ ) return OTXDataBatch(batch_size, images, imgs_info=infos, masks=[]) # type: ignore[arg-type] + @staticmethod + @torch.no_grad() + def _apply_batch_augmentations(augmentations_pipeline: AugmentationSequential | Compose | None, batch: OTXDataBatch) -> None: + if augmentations_pipeline is not None: + batch.images, batch.masks = augmentations_pipeline(batch.images, batch.masks) + + @property + def _default_train_transforms(self): + return AugmentationSequential( + kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1), + kornia.augmentation.RandomHorizontalFlip(), + kornia.augmentation.Normalize(self.data_input_params.mean, self.data_input_params.std), data_keys=["input", "mask"], keepdim=True) + @property def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]: - return DataInputParams(input_size=(512, 512), mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375)) + return DataInputParams(input_size=(512, 512), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) + + @property + def task(self) -> OTXTaskType: + return OTXTaskType.SEMANTIC_SEGMENTATION diff --git a/library/src/otx/data/dataset/base.py b/library/src/otx/data/dataset/base.py index d8c3ad736ed..3893466b51c 100644 --- a/library/src/otx/data/dataset/base.py +++ b/library/src/otx/data/dataset/base.py @@ -9,9 +9,11 @@ from collections.abc import Iterable from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Union +from torchvision.transforms.v2.functional import to_dtype, to_image import cv2 import numpy as np +import torch from datumaro.components.annotation import AnnotationType from datumaro.util.image import IMAGE_BACKEND, IMAGE_COLOR_CHANNEL, ImageBackend from datumaro.util.image import ImageColorChannel as DatumaroImageColorChannel @@ -194,6 +196,7 @@ def _get_img_data_and_shape( img_data = img_data[y1:y2, x1:x2] roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, 
"orig_image_shape": (h, w)} + img_data = to_dtype(to_image(img_data), scale=True, dtype=torch.float32).clamp_(0, 1) return img_data, img_data.shape[:2], roi_meta @abstractmethod diff --git a/library/src/otx/data/dataset/classification.py b/library/src/otx/data/dataset/classification.py index 9a9fb8cc439..5d911b545d6 100644 --- a/library/src/otx/data/dataset/classification.py +++ b/library/src/otx/data/dataset/classification.py @@ -90,7 +90,7 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: img = item.media_as(Image) roi = item.attributes.get("roi", None) img_data, img_shape, _ = self._get_img_data_and_shape(img, roi) - image = to_dtype(to_image(img_data), dtype=torch.float32) + if roi: # extract labels from ROI labels_ids = [ @@ -118,7 +118,12 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: image_color_channel=self.image_color_channel, ), ) - return self._apply_transforms(entity) + entity = self._apply_transforms(entity) + + image = entity.image + image.clamp_(0, 1) + + return entity @property def task_type(self) -> OTXTaskType: diff --git a/library/src/otx/data/dataset/instance_segmentation.py b/library/src/otx/data/dataset/instance_segmentation.py index afb25d4b4bc..4bc33c2df73 100644 --- a/library/src/otx/data/dataset/instance_segmentation.py +++ b/library/src/otx/data/dataset/instance_segmentation.py @@ -119,8 +119,8 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: warnings.warn(f"No valid annotations found for image {item.id}!", stacklevel=2) bboxes = np.stack(gt_bboxes, dtype=np.float32, axis=0) if gt_bboxes else np.empty((0, 4)) + # TODO(@kprokofi): NO MASKS!!! masks = np.stack(gt_masks, axis=0) if gt_masks else np.zeros((0, *img_shape), dtype=bool) - labels = np.array(gt_labels, dtype=np.int64) entity = OTXDataItem( diff --git a/library/src/otx/data/dataset/keypoint_detection.py b/library/src/otx/data/dataset/keypoint_detection.py index 0589d53dbfd..9713fc14e7a 100644 --- a/library/src/otx/data/dataset/keypoint_detection.py +++ b/library/src/otx/data/dataset/keypoint_detection.py @@ -147,7 +147,7 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: keypoints = np.hstack((keypoints, keypoints_visible.reshape(-1, 1))) entity = OTXDataItem( - image=to_dtype(to_image(img_data), torch.float32), + image=img_data, img_info=ImageInfo( img_idx=index, img_shape=img_shape, diff --git a/library/src/otx/data/dataset/segmentation.py b/library/src/otx/data/dataset/segmentation.py index e25c9be61f1..01958158cf8 100644 --- a/library/src/otx/data/dataset/segmentation.py +++ b/library/src/otx/data/dataset/segmentation.py @@ -245,7 +245,7 @@ def _get_item_impl(self, index: int) -> OTXDataItem | None: masks = tv_tensors.Mask(extracted_mask[None], dtype=torch.long) entity = OTXDataItem( - image=to_dtype(to_image(img_data), dtype=torch.float32), + image=img_data, img_info=ImageInfo( img_idx=index, img_shape=img_shape, diff --git a/library/src/otx/data/entity/torch/torch.py b/library/src/otx/data/entity/torch/torch.py index ff85c15c514..2d8a2c6cf0c 100644 --- a/library/src/otx/data/entity/torch/torch.py +++ b/library/src/otx/data/entity/torch/torch.py @@ -62,12 +62,7 @@ def collate_fn(items: list[OTXDataItem]) -> OTXDataBatch: Returns: Batched TorchDataItems with stacked tensors """ - # Check if all images have the same size. TODO(kprokofi): remove this check once OV IR models are moved. 
- if all(item.image.shape == items[0].image.shape for item in items): - images = torch.stack([item.image for item in items]) - else: - # we need this only in case of OV inference, where no resize - images = [item.image for item in items] + images = torch.stack([item.image for item in items]) return OTXDataBatch( batch_size=len(items), diff --git a/library/src/otx/data/entity/torch/validations.py b/library/src/otx/data/entity/torch/validations.py index 65bd63c2d24..051c7114957 100644 --- a/library/src/otx/data/entity/torch/validations.py +++ b/library/src/otx/data/entity/torch/validations.py @@ -9,6 +9,7 @@ import numpy as np import torch +from kornia.geometry.boxes import Boxes from datumaro import Polygon from torchvision.tv_tensors import BoundingBoxes, Mask @@ -61,6 +62,7 @@ def _label_validator(label: torch.Tensor) -> torch.Tensor: raise TypeError(msg) if label.dtype != torch.long: msg = f"Label must have dtype torch.long, but got {label.dtype}" + print(label) raise ValueError(msg) # detection tasks allow multiple labels so the shape is [B, N] if label.ndim > 2: diff --git a/library/src/otx/data/module.py b/library/src/otx/data/module.py index 7bf4031ca5a..71d86531433 100644 --- a/library/src/otx/data/module.py +++ b/library/src/otx/data/module.py @@ -193,7 +193,7 @@ def _setup_otx_dataset(self, dataset: DmDataset) -> None: def extract_normalization_params( self, transforms_source: list | None - ) -> tuple[tuple[float, float, float], tuple[float, float, float]]: + ) -> tuple[tuple[float, float, float] | None, tuple[float, float, float]] | None: """Extract mean and std from transforms. Args: @@ -202,8 +202,8 @@ def extract_normalization_params( Returns: Tuple of (mean, std) tuples. """ - mean = (0.0, 0.0, 0.0) - std = (1.0, 1.0, 1.0) + mean = None + std = None if transforms_source is not None: for transform in transforms_source: @@ -441,6 +441,7 @@ def train_dataloader(self) -> DataLoader: "persistent_workers": config.num_workers > 0, "sampler": sampler, "shuffle": sampler is None, + "prefetch_factor": 2, } tile_config = self.tile_config diff --git a/library/src/otx/data/transform_libs/torchvision.py b/library/src/otx/data/transform_libs/torchvision.py index cb56b2ab871..42d9f8f3b69 100644 --- a/library/src/otx/data/transform_libs/torchvision.py +++ b/library/src/otx/data/transform_libs/torchvision.py @@ -2969,7 +2969,8 @@ def __call__(self, *_inputs: OTXDataItem) -> OTXDataItem | None: inputs.keypoints = torch.zeros([]) else: # update keypoints_visible after affine transforms - inputs.keypoints[:, 2] = inputs.keypoints[:, 2] * (inputs.keypoints[:, :2] > 0).all(axis=1) + # update keypoints_visible. 
Keypoints should be visible if they are inside the image (>=0, x<=w, y<=h) + inputs.keypoints[:, 2] = inputs.keypoints[:, 2] * (inputs.keypoints[:, :2] >= 0).all(axis=1) * (inputs.keypoints[:, 0] <= w) * (inputs.keypoints[:, 1] <= h) return self.convert(inputs) diff --git a/library/src/otx/recipe/_base_/data/classification.yaml b/library/src/otx/recipe/_base_/data/classification.yaml index 8bd2149f66d..c4104ee4b97 100644 --- a/library/src/otx/recipe/_base_/data/classification.yaml +++ b/library/src/otx/recipe/_base_/data/classification.yaml @@ -9,73 +9,18 @@ train_subset: subset_name: train transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 2 - to_tv_image: false + num_workers: 4 + to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.RandomResizedCrop + - class_path: torchvision.transforms.v2.RandomResizedCrop init_args: - scale: $(input_size) - crop_ratio_range: + size: $(input_size) + scale: - 0.08 - 1.0 - aspect_ratio_range: + ratio: - 0.75 - 1.34 - - class_path: torchvision.transforms.v2.RandomPhotometricDistort - enable: false - init_args: - brightness: - - 0.875 - - 1.125 - contrast: - - 0.5 - - 1.5 - saturation: - - 0.5 - - 1.5 - hue: - - -0.05 - - 0.05 - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomAffine - enable: false - init_args: - max_rotate_degree: 10.0 - max_translate_ratio: 0.1 - scaling_ratio_range: - - 0.5 - - 1.5 - max_shear_degree: 2.0 - - class_path: otx.data.transform_libs.torchvision.RandomFlip - enable: true - init_args: - probability: 0.5 - - class_path: torchvision.transforms.v2.RandomVerticalFlip - enable: false - init_args: - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur - enable: false - init_args: - kernel_size: 5 - sigma: - - 0.1 - - 2.0 - probability: 0.5 - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - scale: false - - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise - enable: false - init_args: - mean: 0.0 - sigma: 0.1 - probability: 0.5 - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] sampler: class_path: otx.data.samplers.balanced_sampler.BalancedSampler init_args: null @@ -84,20 +29,12 @@ val_subset: subset_name: val transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 2 - to_tv_image: false + num_workers: 4 + to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.Resize - init_args: - scale: $(input_size) - - class_path: torchvision.transforms.v2.ToDtype + - class_path: torchvision.transforms.v2.Resize init_args: - dtype: ${as_torch_dtype:torch.float32} - scale: false - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + size: $(input_size) sampler: class_path: torch.utils.data.RandomSampler init_args: null @@ -106,20 +43,12 @@ test_subset: subset_name: test transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 2 - to_tv_image: false + num_workers: 4 + to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.Resize - init_args: - scale: $(input_size) - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - scale: false - - class_path: torchvision.transforms.v2.Normalize + - class_path: torchvision.transforms.v2.Resize init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + size: $(input_size) 
sampler: class_path: torch.utils.data.RandomSampler init_args: null diff --git a/library/src/otx/recipe/_base_/data/detection.yaml b/library/src/otx/recipe/_base_/data/detection.yaml index 503a22bd5b4..3eac60cc796 100644 --- a/library/src/otx/recipe/_base_/data/detection.yaml +++ b/library/src/otx/recipe/_base_/data/detection.yaml @@ -8,70 +8,17 @@ unannotated_items_ratio: 0.0 train_subset: subset_name: train transform_lib_type: TORCHVISION - batch_size: 1 - num_workers: 2 + batch_size: 8 + num_workers: 4 to_tv_image: false transforms: - - class_path: otx.data.transform_libs.torchvision.MinIoURandomCrop - enable: true - - class_path: otx.data.transform_libs.torchvision.Resize + - class_path: torchvision.transforms.v2.RandomIoUCrop + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes init_args: - scale: $(input_size) - transform_bbox: true - - class_path: torchvision.transforms.v2.RandomPhotometricDistort - enable: false + min_size: 1 + - class_path: torchvision.transforms.v2.Resize init_args: - brightness: - - 0.875 - - 1.125 - contrast: - - 0.5 - - 1.5 - saturation: - - 0.5 - - 1.5 - hue: - - -0.05 - - 0.05 - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomAffine - enable: false - init_args: - max_rotate_degree: 10.0 - max_translate_ratio: 0.1 - scaling_ratio_range: - - 0.5 - - 1.5 - max_shear_degree: 2.0 - - class_path: otx.data.transform_libs.torchvision.RandomFlip - enable: true - init_args: - probability: 0.5 - - class_path: torchvision.transforms.v2.RandomVerticalFlip - enable: false - init_args: - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur - enable: false - init_args: - kernel_size: 5 - sigma: - - 0.1 - - 2.0 - probability: 0.5 - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise - enable: false - init_args: - mean: 0.0 - sigma: 0.1 - probability: 0.5 - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [0.0, 0.0, 0.0] - std: [255.0, 255.0, 255.0] + size: $(input_size) sampler: class_path: torch.utils.data.RandomSampler init_args: null @@ -79,20 +26,13 @@ train_subset: val_subset: subset_name: val transform_lib_type: TORCHVISION - batch_size: 1 - num_workers: 2 - to_tv_image: false + batch_size: 8 + num_workers: 4 + to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.Resize - init_args: - scale: $(input_size) - - class_path: torchvision.transforms.v2.ToDtype + - class_path: torchvision.transforms.v2.Resize init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [0.0, 0.0, 0.0] - std: [255.0, 255.0, 255.0] + size: $(input_size) sampler: class_path: torch.utils.data.RandomSampler init_args: null @@ -100,20 +40,14 @@ val_subset: test_subset: subset_name: test transform_lib_type: TORCHVISION - batch_size: 1 - num_workers: 2 - to_tv_image: false + batch_size: 8 + num_workers: 4 + to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.Resize + - class_path: torchvision.transforms.v2.Resize init_args: - scale: $(input_size) - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [0.0, 0.0, 0.0] - std: [255.0, 255.0, 255.0] + size: $(input_size) + sampler: class_path: torch.utils.data.RandomSampler init_args: null diff --git 
a/library/src/otx/recipe/_base_/data/instance_segmentation.yaml b/library/src/otx/recipe/_base_/data/instance_segmentation.yaml index 73b2d608d62..3167d83887e 100644 --- a/library/src/otx/recipe/_base_/data/instance_segmentation.yaml +++ b/library/src/otx/recipe/_base_/data/instance_segmentation.yaml @@ -9,76 +9,13 @@ unannotated_items_ratio: 0.0 train_subset: subset_name: train transform_lib_type: TORCHVISION - batch_size: 1 - num_workers: 2 + batch_size: 8 + num_workers: 4 to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.Resize + - class_path: torchvision.transforms.v2.Resize init_args: - keep_ratio: true - transform_bbox: true - transform_mask: true - scale: $(input_size) - - class_path: torchvision.transforms.v2.RandomPhotometricDistort - enable: false - init_args: - brightness: - - 0.875 - - 1.125 - contrast: - - 0.5 - - 1.5 - saturation: - - 0.5 - - 1.5 - hue: - - -0.05 - - 0.05 - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomAffine - enable: false - init_args: - max_rotate_degree: 10.0 - max_translate_ratio: 0.1 - scaling_ratio_range: - - 0.5 - - 1.5 - max_shear_degree: 2.0 - is_numpy_to_tvtensor: false - - class_path: otx.data.transform_libs.torchvision.Pad - enable: true - init_args: - pad_to_square: true - transform_mask: true - - class_path: otx.data.transform_libs.torchvision.RandomFlip - enable: true - init_args: - probability: 0.5 - - class_path: torchvision.transforms.v2.RandomVerticalFlip - enable: false - init_args: - p: 0.5 - - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur - enable: false - init_args: - kernel_size: 5 - sigma: - - 0.1 - - 2.0 - probability: 0.5 - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise - enable: false - init_args: - mean: 0.0 - sigma: 0.1 - probability: 0.5 - - class_path: torchvision.transforms.v2.Normalize - init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + size: $(input_size) sampler: class_path: torch.utils.data.RandomSampler init_args: null @@ -86,25 +23,13 @@ train_subset: val_subset: subset_name: val transform_lib_type: TORCHVISION - batch_size: 1 - num_workers: 2 + batch_size: 8 + num_workers: 4 to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.Resize - init_args: - keep_ratio: true - scale: $(input_size) - is_numpy_to_tvtensor: false - - class_path: otx.data.transform_libs.torchvision.Pad - init_args: - pad_to_square: true - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: torchvision.transforms.v2.Normalize + - class_path: torchvision.transforms.v2.Resize init_args: - mean: [123.675, 116.28, 103.53] - std: [58.395, 57.12, 57.375] + size: $(input_size) sampler: class_path: torch.utils.data.RandomSampler init_args: null @@ -112,25 +37,13 @@ val_subset: test_subset: subset_name: test transform_lib_type: TORCHVISION - batch_size: 1 - num_workers: 2 + batch_size: 8 + num_workers: 4 to_tv_image: true transforms: - - class_path: otx.data.transform_libs.torchvision.Resize - init_args: - keep_ratio: true - scale: $(input_size) - is_numpy_to_tvtensor: false - - class_path: otx.data.transform_libs.torchvision.Pad - init_args: - pad_to_square: true - - class_path: torchvision.transforms.v2.ToDtype - init_args: - dtype: ${as_torch_dtype:torch.float32} - - class_path: torchvision.transforms.v2.Normalize + - class_path: 
diff --git a/library/src/otx/recipe/_base_/data/keypoint_detection.yaml b/library/src/otx/recipe/_base_/data/keypoint_detection.yaml
index e0ed303562f..022311c7022 100644
--- a/library/src/otx/recipe/_base_/data/keypoint_detection.yaml
+++ b/library/src/otx/recipe/_base_/data/keypoint_detection.yaml
@@ -11,47 +11,14 @@ train_subset:
   num_workers: 2
   to_tv_image: true
   transforms:
-    - class_path: torchvision.transforms.v2.RandomPhotometricDistort
-      enable: false
-      init_args:
-        brightness:
-          - 0.875
-          - 1.125
-        contrast:
-          - 0.5
-          - 1.5
-        saturation:
-          - 0.5
-          - 1.5
-        hue:
-          - -0.05
-          - 0.05
-        p: 0.5
     - class_path: otx.data.transform_libs.torchvision.TopdownAffine
       init_args:
         input_size: $(input_size)
         probability: 1.0
-    - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
-      enable: false
-      init_args:
-        kernel_size: 5
-        sigma:
-          - 0.1
-          - 2.0
-        probability: 0.5
-    - class_path: torchvision.transforms.v2.ToDtype
-      init_args:
-        dtype: ${as_torch_dtype:torch.float32}
-    - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
-      enable: false
-      init_args:
-        mean: 0.0
-        sigma: 0.1
-        probability: 0.5
-    - class_path: torchvision.transforms.v2.Normalize
-      init_args:
-        mean: [123.675, 116.28, 103.53]
-        std: [58.395, 57.12, 57.375]
+  sampler:
+    class_path: torch.utils.data.RandomSampler
+    init_args: null
+
 val_subset:
   subset_name: val
   batch_size: 32
@@ -67,13 +34,10 @@ val_subset:
       init_args:
         size: $(input_size)
         pad_val: 0
-    - class_path: torchvision.transforms.v2.ToDtype
-      init_args:
-        dtype: ${as_torch_dtype:torch.float32}
-    - class_path: torchvision.transforms.v2.Normalize
-      init_args:
-        mean: [123.675, 116.28, 103.53]
-        std: [58.395, 57.12, 57.375]
+  sampler:
+    class_path: torch.utils.data.RandomSampler
+    init_args: null
+
 test_subset:
   subset_name: test
   batch_size: 32
@@ -89,10 +53,6 @@ test_subset:
       init_args:
         size: $(input_size)
         pad_val: 0
-    - class_path: torchvision.transforms.v2.ToDtype
-      init_args:
-        dtype: ${as_torch_dtype:torch.float32}
-    - class_path: torchvision.transforms.v2.Normalize
-      init_args:
-        mean: [123.675, 116.28, 103.53]
-        std: [58.395, 57.12, 57.375]
+  sampler:
+    class_path: torch.utils.data.RandomSampler
+    init_args: null
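Note: the keypoint recipes likewise drop ToDtype/Normalize from the dataloader. If that normalization moves to the GPU batch transforms, kornia can also carry keypoint coordinates through geometric augmentations via its "keypoints" data key; a sketch under that assumption (a normalization-only pipeline simply passes the coordinates through unchanged):

import torch
import kornia.augmentation as K
from kornia.augmentation.container import AugmentationSequential

# Illustrative GPU pipeline for keypoint batches; Normalize leaves the
# coordinates untouched, but any geometric ops added here later (e.g. flips)
# would warp them automatically thanks to the "keypoints" data key.
aug = AugmentationSequential(
    K.Normalize(
        mean=torch.tensor([123.675, 116.28, 103.53]),
        std=torch.tensor([58.395, 57.12, 57.375]),
    ),
    data_keys=["input", "keypoints"],
)

images = torch.rand(4, 3, 256, 192) * 255
keypoints = torch.rand(4, 17, 2) * 192  # (B, num_joints, xy)
images_out, keypoints_out = aug(images, keypoints)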
diff --git a/library/src/otx/recipe/_base_/data/semantic_segmentation.yaml b/library/src/otx/recipe/_base_/data/semantic_segmentation.yaml
index ede61422dad..9577bb9566f 100644
--- a/library/src/otx/recipe/_base_/data/semantic_segmentation.yaml
+++ b/library/src/otx/recipe/_base_/data/semantic_segmentation.yaml
@@ -24,57 +24,6 @@ train_subset:
           - 0.5
           - 2.0
         transform_mask: true
-    - class_path: otx.data.transform_libs.torchvision.PhotoMetricDistortion
-      enable: true
-      init_args:
-        brightness_delta: 32
-        contrast:
-          - 0.5
-          - 1.5
-        saturation:
-          - 0.5
-          - 1.5
-        hue_delta: 18
-        probability: 0.5
-    - class_path: otx.data.transform_libs.torchvision.RandomAffine
-      enable: false
-      init_args:
-        max_rotate_degree: 10.0
-        max_translate_ratio: 0.1
-        scaling_ratio_range:
-          - 0.5
-          - 1.5
-        max_shear_degree: 2.0
-    - class_path: otx.data.transform_libs.torchvision.RandomFlip
-      enable: true
-      init_args:
-        probability: 0.5
-    - class_path: torchvision.transforms.v2.RandomVerticalFlip
-      enable: false
-      init_args:
-        p: 0.5
-    - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
-      enable: false
-      init_args:
-        kernel_size: 5
-        sigma:
-          - 0.1
-          - 2.0
-        probability: 0.5
-    - class_path: torchvision.transforms.v2.ToDtype
-      init_args:
-        dtype: ${as_torch_dtype:torch.float32}
-        scale: false
-    - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
-      enable: false
-      init_args:
-        mean: 0.0
-        sigma: 0.1
-        probability: 0.5
-    - class_path: torchvision.transforms.v2.Normalize
-      init_args:
-        mean: [123.675, 116.28, 103.53]
-        std: [58.395, 57.12, 57.375]
   sampler:
     class_path: torch.utils.data.RandomSampler
     init_args: null
@@ -90,13 +39,6 @@ val_subset:
       init_args:
         scale: $(input_size)
        transform_mask: true
-    - class_path: torchvision.transforms.v2.ToDtype
-      init_args:
-        dtype: ${as_torch_dtype:torch.float32}
-    - class_path: torchvision.transforms.v2.Normalize
-      init_args:
-        mean: [123.675, 116.28, 103.53]
-        std: [58.395, 57.12, 57.375]
   sampler:
     class_path: torch.utils.data.RandomSampler
     init_args: null
@@ -112,13 +54,6 @@ test_subset:
       init_args:
         scale: $(input_size)
        transform_mask: true
-    - class_path: torchvision.transforms.v2.ToDtype
-      init_args:
-        dtype: ${as_torch_dtype:torch.float32}
-    - class_path: torchvision.transforms.v2.Normalize
-      init_args:
-        mean: [123.675, 116.28, 103.53]
-        std: [58.395, 57.12, 57.375]
   sampler:
     class_path: torch.utils.data.RandomSampler
     init_args: null
diff --git a/library/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml b/library/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml
index 893ee13fd82..4e1706c2b23 100644
--- a/library/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml
+++ b/library/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml
@@ -46,73 +46,4 @@ callbacks:
     filename: "checkpoints/epoch_{epoch:03d}"
 
 overrides:
-  reset:
-    - data.train_subset.transforms
   max_epochs: 90
-
-  data:
-    train_subset:
-      transforms:
-        - class_path: otx.data.transform_libs.torchvision.EfficientNetRandomCrop
-          init_args:
-            scale: $(input_size)
-            crop_ratio_range:
-              - 0.08
-              - 1.0
-            aspect_ratio_range:
-              - 0.75
-              - 1.34
-        - class_path: torchvision.transforms.v2.RandomPhotometricDistort
-          enable: false
-          init_args:
-            brightness:
-              - 0.875
-              - 1.125
-            contrast:
-              - 0.5
-              - 1.5
-            saturation:
-              - 0.5
-              - 1.5
-            hue:
-              - -0.05
-              - 0.05
-            p: 0.5
-        - class_path: otx.data.transform_libs.torchvision.RandomAffine
-          enable: false
-          init_args:
-            max_rotate_degree: 10.0
-            max_translate_ratio: 0.1
-            scaling_ratio_range:
-              - 0.5
-              - 1.5
-            max_shear_degree: 2.0
-        - class_path: otx.data.transform_libs.torchvision.RandomFlip
-          enable: true
-          init_args:
-            probability: 0.5
-        - class_path: torchvision.transforms.v2.RandomVerticalFlip
-          enable: false
-          init_args:
-            p: 0.5
-        - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
-          enable: false
-          init_args:
-            kernel_size: 5
-            sigma:
-              - 0.1
-              - 2.0
-        - class_path: torchvision.transforms.v2.ToDtype
-          init_args:
-            dtype: ${as_torch_dtype:torch.float32}
-            scale: false
-        - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
-          enable: false
-          init_args:
-            mean: 0.0
-            sigma: 0.1
-            probability: 0.5
-        - class_path: torchvision.transforms.v2.Normalize
-          init_args:
-            mean: [123.675, 116.28, 103.53]
-            std: [58.395, 57.12, 57.375]
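Note: base.py accepts a plain torchvision Compose as well as a kornia container, which is sufficient for classification since no geometric targets need to stay in sync. A sketch of the batch-level equivalent of the ToDtype + Normalize pair removed above (statistics copied from the old recipe; the variable name is illustrative):

import torch
from torchvision.transforms.v2 import Compose, Normalize, ToDtype

# Batch-level equivalent of the deleted per-sample ToDtype + Normalize;
# torchvision v2 transforms operate on batched tensors directly.
batch_val_transforms = Compose([
    ToDtype(torch.float32, scale=False),
    Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
])

batch = torch.randint(0, 256, (8, 3, 224, 224), dtype=torch.uint8)
normalized = batch_val_transforms(batch)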
diff --git a/library/src/otx/recipe/detection/atss_mobilenetv2.yaml b/library/src/otx/recipe/detection/atss_mobilenetv2.yaml
index b24d4acb32e..1ebc4c87d32 100644
--- a/library/src/otx/recipe/detection/atss_mobilenetv2.yaml
+++ b/library/src/otx/recipe/detection/atss_mobilenetv2.yaml
@@ -4,7 +4,6 @@ model:
   init_args:
     model_name: atss_mobilenetv2
     label_info: 80
-
     optimizer:
       class_path: torch.optim.SGD
       init_args:
diff --git a/library/src/otx/recipe/detection/yolox_x.yaml b/library/src/otx/recipe/detection/yolox_x.yaml
index 30206207646..1594d572653 100644
--- a/library/src/otx/recipe/detection/yolox_x.yaml
+++ b/library/src/otx/recipe/detection/yolox_x.yaml
@@ -69,75 +69,74 @@ overrides:
         - class_path: otx.data.transform_libs.torchvision.Resize
           init_args:
             scale: $(input_size)
-            keep_ratio: true
-            transform_bbox: true
-        - class_path: otx.data.transform_libs.torchvision.CachedMosaic
-          init_args:
-            random_pop: false
-            max_cached_images: 20
-            img_scale: $(input_size) # (H, W)
-        - class_path: otx.data.transform_libs.torchvision.RandomAffine
-          init_args:
-            border: $(input_size) * -0.5
-        - class_path: otx.data.transform_libs.torchvision.CachedMixUp
-          init_args:
-            img_scale: $(input_size) # (H, W)
-            ratio_range:
-              - 1.0
-              - 1.0
-            probability: 0.5
-            random_pop: false
-            max_cached_images: 10
-        - class_path: torchvision.transforms.v2.RandomPhotometricDistort
-          enable: false
-          init_args:
-            brightness:
-              - 0.875
-              - 1.125
-            contrast:
-              - 0.5
-              - 1.5
-            saturation:
-              - 0.5
-              - 1.5
-            hue:
-              - -0.05
-              - 0.05
-            p: 0.5
-        - class_path: otx.data.transform_libs.torchvision.YOLOXHSVRandomAug
-        - class_path: otx.data.transform_libs.torchvision.RandomFlip
-          init_args:
-            probability: 0.5
-            is_numpy_to_tvtensor: false
-        - class_path: otx.data.transform_libs.torchvision.Pad
-          init_args:
-            pad_to_square: true
-            pad_val: 114
-        - class_path: torchvision.transforms.v2.RandomVerticalFlip
-          enable: false
-          init_args:
-            p: 0.5
-        - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
-          enable: false
-          init_args:
-            kernel_size: 5
-            sigma:
-              - 0.1
-              - 2.0
-            probability: 0.5
+        # - class_path: otx.data.transform_libs.torchvision.CachedMosaic
+        #   init_args:
+        #     random_pop: false
+        #     max_cached_images: 20
+        #     img_scale: $(input_size) # (H, W)
+        # - class_path: otx.data.transform_libs.torchvision.RandomAffine
+        #   init_args:
+        #     border: $(input_size) * -0.5
+        # - class_path: otx.data.transform_libs.torchvision.CachedMixUp
+        #   init_args:
+        #     img_scale: $(input_size) # (H, W)
+        #     ratio_range:
+        #       - 1.0
+        #       - 1.0
+        #     probability: 0.5
+        #     random_pop: false
+        #     max_cached_images: 10
+        # - class_path: torchvision.transforms.v2.RandomPhotometricDistort
+        #   enable: false
+        #   init_args:
+        #     brightness:
+        #       - 0.875
+        #       - 1.125
+        #     contrast:
+        #       - 0.5
+        #       - 1.5
+        #     saturation:
+        #       - 0.5
+        #       - 1.5
+        #     hue:
+        #       - -0.05
+        #       - 0.05
+        #     p: 0.5
+        # - class_path: otx.data.transform_libs.torchvision.YOLOXHSVRandomAug
+        # - class_path: otx.data.transform_libs.torchvision.RandomFlip
+        #   init_args:
+        #     probability: 0.5
+        #     is_numpy_to_tvtensor: false
+        # - class_path: otx.data.transform_libs.torchvision.Pad
+        #   init_args:
+        #     pad_to_square: true
+        #     pad_val: 114
+        # - class_path: torchvision.transforms.v2.RandomVerticalFlip
+        #   enable: false
+        #   init_args:
+        #     p: 0.5
+        # - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
+        #   enable: false
+        #   init_args:
+        #     kernel_size: 5
+        #     sigma:
+        #       - 0.1
+        #       - 2.0
+        #     probability: 0.5
         - class_path: torchvision.transforms.v2.ToDtype
           init_args:
             dtype: ${as_torch_dtype:torch.float32}
-        - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
-          enable: false
-          init_args:
-            mean: 0.0
-            sigma: 0.1
-            probability: 0.5
-        - class_path: torchvision.transforms.v2.Normalize
-          init_args:
-            mean: [0.0, 0.0, 0.0]
-            std: [1.0, 1.0, 1.0]
+            scale: true
+        # - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
+        #   enable: false
+        #   init_args:
+        #     mean: 0.0
+        #     sigma: 0.1
+        #     probability: 0.5
+        # - class_path: torchvision.transforms.v2.Normalize
+        #   init_args:
+        #     mean: [0.0, 0.0, 0.0]
+        #     std: [1.0, 1.0, 1.0]
       sampler:
         class_path: otx.data.samplers.balanced_sampler.BalancedSampler
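Note on the surviving YOLOX pipeline: the old chain ended in ToDtype(float32) followed by an identity Normalize (mean 0, std 1), leaving pixel values in the 0-255 range, whereas the newly added scale: true rescales them to [0, 1]. A quick illustration of the torchvision semantics behind that flag:

import torch
from torchvision.transforms.v2 import ToDtype

# scale=True converts uint8 [0, 255] to float32 [0, 1]; scale=False only casts.
img = torch.tensor([[0, 128, 255]], dtype=torch.uint8)
print(ToDtype(torch.float32, scale=True)(img))   # tensor([[0.0000, 0.5020, 1.0000]])
print(ToDtype(torch.float32, scale=False)(img))  # tensor([[0., 128., 255.]])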