diff --git a/library/docs/source/guide/explanation/algorithms/object_detection/object_detection.rst b/library/docs/source/guide/explanation/algorithms/object_detection/object_detection.rst index e468a311a1e..c7cddb58cbb 100644 --- a/library/docs/source/guide/explanation/algorithms/object_detection/object_detection.rst +++ b/library/docs/source/guide/explanation/algorithms/object_detection/object_detection.rst @@ -2,28 +2,48 @@ Object Detection ================ Object detection is a computer vision task where it's needed to locate objects, finding their bounding boxes coordinates together with defining class. -The input is an image, and the output is a pair of coordinates for bouding box corners and a class number for each detected object. +The input is an image, and the output is a pair of coordinates for bounding box corners and a class number for each detected object. The common approach to building object detection architecture is to take a feature extractor (backbone), that can be inherited from the classification task. Then goes a head that calculates coordinates and class probabilities based on aggregated information from the image. Additionally, some architectures use `Feature Pyramid Network (FPN) `_ to transfer and process feature maps from backbone to head and called neck. -For the supervised training we use the following algorithms components: +******************* +Training Pipeline +******************* + +OTX supports various training configurations that can be customized per model. The default settings vary by model architecture +and are defined in the respective recipe files. To see the exact configuration for a specific model, run: + +.. code-block:: shell + + (otx) ...$ otx train --config --print_config .. _od_supervised_pipeline: -- ``Augmentations``: We use random crop and random rotate, simple bright and color distortions and multiscale training for the training pipeline. +Common training components include: + +- ``Augmentations``: Data augmentation strategies vary by model. Common techniques include random crop, rotation, affine transformations, color/brightness distortions, and advanced techniques like Mosaic and MixUp. + +- ``Optimizer``: Model-specific optimizers are used: + - **AdamW**: Used by transformer-based models (RT-DETR, D-FINE, DEIMv2) with learning rates typically in the range of 1e-4 to 5e-4. + - **SGD**: Used by CNN-based models (YOLOX, SSD, ATSS) with momentum ~0.9 and weight decay ~1e-4. -- ``Optimizer``: We use `SGD `_ optimizer with the weight decay set to **1e-4** and momentum set to **0.9**. +- ``Learning rate schedule``: `ReduceLROnPlateau `_ is commonly used for dataset-agnostic training. It reduces the learning rate when the validation metric plateaus. Many models also use warmup periods at the start of training. -- ``Learning rate schedule``: `ReduceLROnPlateau `_. This learning rate scheduler proved its efficiency in dataset-agnostic trainings, its logic is to drop LR after some time without improving the target accuracy metric. Also, we update it with ``iteration_patience`` parameter that ensures that a certain number of training iterations (steps through the dataset) were passed before dropping LR. +- ``Loss function``: Loss functions are architecture-specific: + - **Traditional detectors** (SSD, ATSS): `Generalized IoU Loss `_ for localization and `FocalLoss `_ for classification. + - **DETR-based models** (RT-DETR, D-FINE, DEIMv2): Hungarian matching with combined classification, L1 box, and GIoU losses. 
-- ``Loss function``: We use `Generalized IoU Loss `_ for localization loss to train the ability of the model to find the coordinates of the objects. For the classification head, we use a standard `FocalLoss `_. +- ``Additional training techniques``: + - ``Early stopping``: Prevents overfitting by stopping training when validation metrics stop improving. + - ``Backbone pretraining``: Most models use pretrained backbones (ImageNet, DINOv2/DINOv3) for better feature extraction. + - ``Multi-scale training``: Optional technique to improve robustness to different object sizes. -- ``Additional training techniques`` - - ``Early stopping``: To add adaptability to the training pipeline and prevent overfitting. - - `Anchor clustering for SSD `_: This model highly relies on predefined anchor boxes hyperparameter that impacts the size of objects, which can be detected. So before training, we collect object statistics within dataset, cluster them and modify anchor boxes sizes to fit the most for objects the model is going to detect. - - ``Backbone pretraining``: we pretrained MobileNetV2 backbone on large `ImageNet21k `_ dataset to improve feature extractor and learn better and faster. +.. note:: + + Training configurations are fully customizable. Override any setting via command line or by creating a custom recipe file. + See the :doc:`configuration guide <../../../tutorials/base/how_to_train/detection>` for details. ************** @@ -56,35 +76,82 @@ Models We support the following ready-to-use model recipes: -+------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ -| Recipe ID | Name | Complexity (GFLOPs) | Model size (MB) | -+============================================================================================================================================================+=====================+=====================+=================+ -| `Custom_Object_Detection_YOLOX `_ | YOLOX-TINY | 6.5 | 20.4 | -+------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ -| `Object_Detection_YOLOX_S `_ | YOLOX_S | 33.51 | 46.0 | -+------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ -| `Object_Detection_YOLOX_L `_ | YOLOX_L | 194.57 | 207.0 | -+------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ -| `Object_Detection_YOLOX_X `_ | YOLOX_X | 352.42 | 378.0 | -+------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ -| `Custom_Object_Detection_Gen3_SSD `_ | SSD | 9.4 | 7.6 | -+------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ -| `Custom_Object_Detection_Gen3_ATSS `_ | MobileNetV2-ATSS | 20.6 | 9.1 | 
-+------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ -| `Object_Detection_ResNeXt101_ATSS `_ | ResNeXt101-ATSS | 434.75 | 344.0 | -+------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ -| `D-Fine X Detection ` | D-Fine X | 202.486 | 240.0 | -+------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ - -Above table can be found using the following command - -.. code-block:: shell - - (otx) ...$ otx find --task DETECTION +.. note:: -`MobileNetV2-ATSS `_ is a good medium-range model that works well and fast in most cases. -`SSD `_ and `YOLOX `_ are light models, that a perfect for the fastest inference on low-power hardware. -YOLOX achieved the same accuracy as SSD, and even outperforms its inference on CPU 1.5 times, but requires 3 times more time for training due to `Mosaic augmentation `_, which is even more than for ATSS. -So if you have resources for a long training, you can pick the YOLOX model. -ATSS still shows good performance among `RetinaNet `_ based models. Therfore, We added ATSS with large scale backbone, ResNeXt101-ATSS. We integrated large ResNeXt101 backbone to our Custom ATSS head, and it shows good transfer learning performance. -In addition, we added a YOLOX variants to support users' diverse situations. + For the most up-to-date list of available models, run ``otx find --task DETECTION``. + +Transformer-based Models (DETR Family) +-------------------------------------- + +These models use the Detection Transformer (DETR) paradigm with end-to-end object detection. They eliminate the need for hand-designed components like anchor boxes and non-maximum suppression. 
+ ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| Recipe | Name | Complexity (GFLOPs) | Model size (MB) | ++===============================================================================================================================+=====================+=====================+=================+ +| `deimv2_s `_ | DEIMv2-S | ~15 | ~25 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `deimv2_m `_ | DEIMv2-M | ~25 | ~35 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `deimv2_l `_ | DEIMv2-L | ~50 | ~60 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `deimv2_x `_ | DEIMv2-X | ~80 | ~90 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `deim_dfine_m `_ | DEIM-DFine-M | ~34 | ~52 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `deim_dfine_l `_ | DEIM-DFine-L | ~91 | ~124 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `deim_dfine_x `_ | DEIM-DFine-X | ~202 | ~240 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `dfine_x `_ | D-Fine X | 202.5 | 240 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `rtdetr_18 `_ | RT-DETR-18 | ~60 | ~80 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `rtdetr_50 `_ | RT-DETR-50 | ~136 | ~170 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ + +**DEIM Family Models:** + +- `DEIM-DFine `_ (v1): Uses HGNetV2 backbone with D-FINE decoder. Available in M/L/X variants with increasing accuracy. +- `DEIMv2 `_: An improved version that combines DINOv3/ViT backbones with an efficient DETR decoder: + - **DEIMv2-S/M**: Use lightweight ViT-Tiny backbones, ideal for edge deployment. + - **DEIMv2-L/X**: Use DINOv3 (ViT-S) backbones with self-supervised pretraining for higher accuracy. + +CNN-based Models +---------------- + +Traditional CNN-based detectors with anchor-based or anchor-free designs. 
+ ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| Recipe | Name | Complexity (GFLOPs) | Model size (MB) | ++===============================================================================================================================+=====================+=====================+=================+ +| `yolox_tiny `_ | YOLOX-TINY | 6.5 | 20.4 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `yolox_s `_ | YOLOX-S | 33.5 | 46 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `yolox_l `_ | YOLOX-L | 194.6 | 207 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `yolox_x `_ | YOLOX-X | 352.4 | 378 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `ssd_mobilenetv2 `_ | SSD-MobileNetV2 | 9.4 | 7.6 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `atss_mobilenetv2 `_ | ATSS-MobileNetV2 | 20.6 | 9.1 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `atss_resnext101 `_ | ATSS-ResNeXt101 | 434.8 | 344 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ +| `rtmdet_tiny `_ | RTMDet-Tiny | ~8 | ~15 | ++-------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+ + +Model Selection Guide +--------------------- + +Choose a model based on your requirements: + +- **Best accuracy**: DEIMv2-X, DEIM-DFine-X, D-Fine X, or YOLOX-X +- **Best speed/accuracy trade-off**: DEIMv2-S, DEIMv2-M, DEIM-DFine-M, or YOLOX-L +- **Fastest inference**: YOLOX-TINY, YOLOX-S, SSD-MobileNetV2, or ATSS-MobileNetV2 + +**Recommendations:** + +- For **transformer-based models**, the DEIM family (DEIMv2 and DEIM-DFine) provides state-of-the-art accuracy with excellent inference speed. DEIMv2-S/M are particularly well-suited for real-time and edge deployment scenarios. +- For **CNN-based models**, `YOLOX `_ offers an excellent speed-accuracy trade-off and strong performance across different benchmark datasets. + `MobileNetV2-ATSS `_ and `SSD `_ are also good choices for resource-constrained environments. 
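
For convenience, here is a minimal end-to-end sketch of training one of the recipes listed above from the CLI. The recipe path ``src/otx/recipe/detection/yolox_tiny.yaml`` and the dataset location are illustrative placeholders, not paths confirmed by this PR; ``otx find`` prints the canonical recipe names, and ``--data_root`` follows the usual OTX CLI convention of pointing at a dataset in a supported detection format:

.. code-block:: shell

    # list the available detection recipes
    (otx) ...$ otx find --task DETECTION

    # inspect the full training configuration of a chosen recipe
    (otx) ...$ otx train --config src/otx/recipe/detection/yolox_tiny.yaml --print_config

    # train the recipe on a local dataset (e.g. COCO-format annotations)
    (otx) ...$ otx train --config src/otx/recipe/detection/yolox_tiny.yaml --data_root data/my_dataset
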
diff --git a/library/src/otx/backend/native/callbacks/__init__.py b/library/src/otx/backend/native/callbacks/__init__.py index 55fec9ad0be..21b60ff8985 100644 --- a/library/src/otx/backend/native/callbacks/__init__.py +++ b/library/src/otx/backend/native/callbacks/__init__.py @@ -4,5 +4,6 @@ """Module for OTX custom callbacks.""" from .batchsize_finder import BatchSizeFinder +from .cuda_cache_cleaner import CUDACacheCleaner -__all__ = ["BatchSizeFinder"] +__all__ = ["BatchSizeFinder", "CUDACacheCleaner"] diff --git a/library/src/otx/backend/native/callbacks/cuda_cache_cleaner.py b/library/src/otx/backend/native/callbacks/cuda_cache_cleaner.py new file mode 100644 index 00000000000..60d4cf2434c --- /dev/null +++ b/library/src/otx/backend/native/callbacks/cuda_cache_cleaner.py @@ -0,0 +1,114 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""CUDA Cache Cleaner callback for memory management during training.""" + +from __future__ import annotations + +import gc +import logging + +import torch +from lightning import Callback, LightningModule, Trainer + +from otx.data.entity import OTXDataBatch + +logger = logging.getLogger(__name__) + + +class CUDACacheCleaner(Callback): + """Callback to periodically clean CUDA cache to reduce memory fragmentation. + + This callback can help reduce memory usage by clearing the CUDA cache at strategic + points during training. However, use with caution as frequent cache clearing can + slow down training due to memory reallocation overhead. + + Recommended usage: + - Set clean_on_validation_end=True (default) - Most beneficial, frees eval memory + - Set clean_on_epoch_end=True only if experiencing OOM between epochs + - Avoid clean_on_train_batch_end unless absolutely necessary (performance impact) + + Args: + clean_on_epoch_end: Clean cache at the end of each training epoch. + Defaults to False. + clean_on_validation_end: Clean cache after validation. Defaults to True. + clean_on_train_batch_end: Clean cache after each training batch. + WARNING: This significantly slows down training. Defaults to False. + clean_every_n_epochs: Only clean every N epochs (if epoch cleaning enabled). + Defaults to 1. + clean_every_n_batches: Only clean every N batches (if batch cleaning enabled). + Defaults to 100. + run_gc: Also run Python garbage collection before clearing cache. + Defaults to True. + log_memory: Log memory usage before/after cleaning. Defaults to False. + """ + + def __init__( + self, + clean_on_epoch_end: bool = False, + clean_on_validation_end: bool = True, + clean_on_train_batch_end: bool = False, + clean_every_n_epochs: int = 1, + clean_every_n_batches: int = 100, + run_gc: bool = True, + log_memory: bool = False, + ) -> None: + super().__init__() + self.clean_on_epoch_end = clean_on_epoch_end + self.clean_on_validation_end = clean_on_validation_end + self.clean_on_train_batch_end = clean_on_train_batch_end + self.clean_every_n_epochs = clean_every_n_epochs + self.clean_every_n_batches = clean_every_n_batches + self.run_gc = run_gc + self.log_memory = log_memory + + def _clean_cache(self, stage: str) -> None: + """Clean CUDA cache and optionally run garbage collection. + + Args: + stage: Description of when cleaning is happening (for logging). 
+ """ + if not torch.cuda.is_available(): + return + + if self.log_memory: + before_allocated = torch.cuda.memory_allocated() / 1024**3 + before_reserved = torch.cuda.memory_reserved() / 1024**3 + + if self.run_gc: + gc.collect() + + torch.cuda.empty_cache() + + if self.log_memory: + after_allocated = torch.cuda.memory_allocated() / 1024**3 + after_reserved = torch.cuda.memory_reserved() / 1024**3 + freed = before_reserved - after_reserved + logger.info( + f"[{stage}] CUDA cache cleaned. " + f"Allocated: {before_allocated:.2f}GB -> {after_allocated:.2f}GB, " + f"Reserved: {before_reserved:.2f}GB -> {after_reserved:.2f}GB, " + f"Freed: {freed:.2f}GB" + ) + + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Clean cache at the end of training epoch if enabled.""" + if self.clean_on_epoch_end and (trainer.current_epoch + 1) % self.clean_every_n_epochs == 0: + self._clean_cache(f"epoch_{trainer.current_epoch}_end") + + def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Clean cache after validation if enabled.""" + if self.clean_on_validation_end: + self._clean_cache("validation_end") + + def on_train_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: OTXDataBatch, + batch: OTXDataBatch, + batch_idx: int, + ) -> None: + """Clean cache after training batch if enabled (use with caution).""" + if self.clean_on_train_batch_end and (batch_idx + 1) % self.clean_every_n_batches == 0: + self._clean_cache(f"batch_{batch_idx}_end") diff --git a/library/src/otx/backend/native/engine.py b/library/src/otx/backend/native/engine.py index 2f5c69589cf..3f3511c06d0 100644 --- a/library/src/otx/backend/native/engine.py +++ b/library/src/otx/backend/native/engine.py @@ -166,7 +166,7 @@ def train( min_epochs: int = 1, seed: int | None = None, deterministic: bool | Literal["warn"] = False, - precision: _PRECISION_INPUT | None = 16, + precision: _PRECISION_INPUT | None = "bf16-mixed", callbacks: list[Callback] | Callback | None = None, logger: Logger | Iterable[Logger] | bool | None = None, resume: bool = False, diff --git a/library/src/otx/backend/native/models/__init__.py b/library/src/otx/backend/native/models/__init__.py index c24a1dafad5..ee3a004b447 100644 --- a/library/src/otx/backend/native/models/__init__.py +++ b/library/src/otx/backend/native/models/__init__.py @@ -10,13 +10,14 @@ TVModel, VisionTransformer, ) -from .detection import ATSS, RTDETR, SSD, YOLOX, DEIMDFine, DFine, RTMDet +from .detection import ATSS, DEIMV2, RTDETR, SSD, YOLOX, DEIMDFine, DFine, RTMDet from .instance_segmentation import MaskRCNN, MaskRCNNTV, RTMDetInst from .keypoint_detection import RTMPose from .segmentation import DinoV2Seg, LiteHRNet, SegNext __all__ = [ "ATSS", + "DEIMV2", "RTDETR", "SSD", "YOLOX", diff --git a/library/src/otx/backend/native/models/base.py b/library/src/otx/backend/native/models/base.py index 093cc239121..58fa0d83bbe 100644 --- a/library/src/otx/backend/native/models/base.py +++ b/library/src/otx/backend/native/models/base.py @@ -141,6 +141,7 @@ def __init__( metric: MetricCallable = NullMetricCallable, torch_compile: bool = False, tile_config: TileConfig | dict = TileConfig(enable_tiler=False), + log_total_loss_only: bool = True, ) -> None: """Initialize the base model with the given parameters. 
@@ -167,6 +168,7 @@ def __init__( self._label_info = self._dispatch_label_info(label_info) self.model_name = model_name + self.log_total_loss_only = log_total_loss_only if isinstance(data_input_params, dict): data_input_params = DataInputParams(**data_input_params) elif data_input_params is None: @@ -212,14 +214,15 @@ def training_step(self, batch: OTXDataBatch, batch_idx: int) -> Tensor: ) return train_loss if isinstance(train_loss, dict): - for k, v in train_loss.items(): - self.log( - f"train/{k}", - v, - on_step=True, - on_epoch=False, - prog_bar=True, - ) + if not self.log_total_loss_only: + for k, v in train_loss.items(): + self.log( + f"train/{k}", + v, + on_step=True, + on_epoch=False, + prog_bar=True, + ) total_train_loss = train_loss.get("total_loss", sum(train_loss.values())) self.log( diff --git a/library/src/otx/backend/native/models/classification/utils/embed.py b/library/src/otx/backend/native/models/classification/utils/embed.py deleted file mode 100644 index 38b607a8f16..00000000000 --- a/library/src/otx/backend/native/models/classification/utils/embed.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -# Copyright (c) OpenMMLab. All rights reserved. - -"""Copy from mmpretrain/models/utils/embed.py.""" - -import torch -from torch.nn import functional - - -def resize_pos_embed( - pos_embed: torch.Tensor, - src_shape: tuple, - dst_shape: tuple, - mode: str = "bicubic", - num_extra_tokens: int = 1, -) -> torch.Tensor: - """Resize pos_embed weights. - - Args: - pos_embed (torch.Tensor): Position embedding weights with shape - [1, L, C]. - src_shape (tuple): The resolution of downsampled origin training - image, in format (H, W). - dst_shape (tuple): The resolution of downsampled new training - image, in format (H, W). - mode (str): Algorithm used for upsampling. Choose one from 'nearest', - 'linear', 'bilinear', 'bicubic' and 'trilinear'. - Defaults to 'bicubic'. - num_extra_tokens (int): The number of extra tokens, such as cls_token. - Defaults to 1. - - Returns: - torch.Tensor: The resized pos_embed of shape [1, L_new, C] - """ - if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]: - return pos_embed - assert pos_embed.ndim == 3, "shape of pos_embed must be [1, L, C]" # noqa: S101 - _, L, C = pos_embed.shape # noqa: N806 - src_h, src_w = src_shape - assert src_h * src_w + num_extra_tokens == L, ( # noqa: S101 - f"The length of `pos_embed` ({L}) doesn't match the expected " - f"shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the " - "`img_size` argument." 
- ) - extra_tokens = pos_embed[:, :num_extra_tokens] - - src_weight = pos_embed[:, num_extra_tokens:] - src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) - - # The cubic interpolate algorithm only accepts float32 - dst_weight = functional.interpolate(src_weight.float(), size=dst_shape, align_corners=False, mode=mode) - dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2) - dst_weight = dst_weight.to(src_weight.dtype) - - return torch.cat((extra_tokens, dst_weight), dim=1) diff --git a/library/src/otx/backend/native/models/classification/utils/swiglu_ffn.py b/library/src/otx/backend/native/models/classification/utils/swiglu_ffn.py index 9139419e6bd..fbf25da55b3 100644 --- a/library/src/otx/backend/native/models/classification/utils/swiglu_ffn.py +++ b/library/src/otx/backend/native/models/classification/utils/swiglu_ffn.py @@ -12,6 +12,7 @@ import torch from torch import nn +from otx.backend.native.models.common.layers.transformer_layers import ListForwardMixin from otx.backend.native.models.modules.drop import build_dropout from otx.backend.native.models.modules.norm import build_norm_layer @@ -100,3 +101,52 @@ def __init__( out_dims=out_dims, bias=bias, ) + + +class SwiGLUFFNV2(nn.Module, ListForwardMixin): + """SwiGLUFFN module. + + Args: + in_features (int): Input features. + hidden_features (int | None, optional): Hidden features. Defaults to None. + out_features (int | None, optional): Output features. Defaults to None. + act_layer (Callable[..., nn.Module] | None, optional): Activation layer. Defaults to None. + drop (float, optional): Dropout rate. Defaults to 0.0. + bias (bool, optional): Whether to use bias. Defaults to True. + align_to (int, optional): Number of columns to align the hidden features to. Defaults to 8. + device (torch.device, optional): Device to use. Defaults to None. + """ + + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + out_features: int | None = None, + act_layer: Callable[..., nn.Module] | None = None, + drop: float = 0.0, + bias: bool = True, + align_to: int = 8, + device: torch.device | str | None = None, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + d = int(hidden_features * 2 / 3) + swiglu_hidden_features = d + (-d % align_to) + self.w1 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device) + self.w2 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device) + self.w3 = nn.Linear(swiglu_hidden_features, out_features, bias=bias, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply SwiGLU transformation to input tensor. + + Args: + x: Input tensor of shape (..., in_features). + + Returns: + Output tensor of shape (..., out_features). + """ + x1 = self.w1(x) + x2 = self.w2(x) + hidden = nn.functional.silu(x1) * x2 + return self.w3(hidden) diff --git a/library/src/otx/backend/native/models/common/backbones/dinov3.py b/library/src/otx/backend/native/models/common/backbones/dinov3.py new file mode 100644 index 00000000000..161ebafdf4b --- /dev/null +++ b/library/src/otx/backend/native/models/common/backbones/dinov3.py @@ -0,0 +1,536 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This software may be used and distributed in accordance with +# the terms of the DINOv3 License Agreement. + +"""DINOv3 Vision Transformer backbone implementation. 
+ +This module implements the DINOv3 Vision Transformer architecture with +RoPE (Rotary Position Embedding) and various configuration options. +""" + +from __future__ import annotations + +import logging +from enum import Enum +from functools import partial +from typing import Any, Callable + +import torch +import torch.nn.init +from torch import Tensor, nn +from torch.utils.checkpoint import checkpoint as gradient_checkpoint + +from otx.backend.native.models.classification.utils.swiglu_ffn import SwiGLUFFNV2 +from otx.backend.native.models.common.layers.position_embed import RopePositionEmbedding +from otx.backend.native.models.common.layers.transformer_layers import MLP2L, LayerScale, SelfAttentionBlock +from otx.backend.native.models.modules.transformer import UnflattenPatchEmbed as PatchEmbed + + +def named_apply( + fn: Callable, + module: nn.Module, + name: str = "", + depth_first: bool = True, + include_root: bool = False, +) -> nn.Module: + """Apply a function to all named modules recursively. + + Args: + fn: Function to apply, should accept `module` and `name` kwargs. + module: Root module to start from. + name: Name prefix for the root module. + depth_first: If True, apply in depth-first order. + include_root: If True, also apply to the root module. + + Returns: + The input module (for chaining). + """ + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + full_name = f"{name}.{child_name}" if name else child_name + named_apply( + fn=fn, + module=child_module, + name=full_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class Weights(Enum): + """Pretrained weight options for DINOv3 models.""" + + LVD1689M = "LVD1689M" + SAT493M = "SAT493M" + + +#: Configuration dictionary mapping model names to their hyperparameters. +configs: dict[str, dict[str, Any]] = { + "dinov3_vits16": { + "img_size": 224, + "patch_size": 16, + "in_chans": 3, + "pos_embed_rope_base": 100, + "pos_embed_rope_normalize_coords": "separate", + "pos_embed_rope_rescale_coords": 2, + "pos_embed_rope_dtype": "fp32", + "embed_dim": 384, + "depth": 12, + "num_heads": 6, + "ffn_ratio": 4, + "qkv_bias": True, + "drop_path_rate": 0.0, + "layerscale_init": 1.0e-05, + "norm_layer": "layernormbf16", + "ffn_layer": "mlp", + "ffn_bias": True, + "proj_bias": True, + "n_storage_tokens": 4, + "mask_k_bias": True, + "pretrained": True, + "weights": Weights.LVD1689M, + "compact_arch_name": "vits", + "check_hash": False, + }, + "dinov3_vits16plus": { + "img_size": 224, + "patch_size": 16, + "in_chans": 3, + "pos_embed_rope_base": 100, + "pos_embed_rope_normalize_coords": "separate", + "pos_embed_rope_rescale_coords": 2, + "pos_embed_rope_dtype": "fp32", + "embed_dim": 384, + "depth": 12, + "num_heads": 6, + "ffn_ratio": 6, + "qkv_bias": True, + "drop_path_rate": 0.0, + "layerscale_init": 1.0e-05, + "norm_layer": "layernormbf16", + "ffn_layer": "swiglu", + "ffn_bias": True, + "proj_bias": True, + "n_storage_tokens": 4, + "mask_k_bias": True, + "pretrained": True, + "weights": Weights.LVD1689M, + "compact_arch_name": "vitsplus", + "check_hash": False, + }, +} + +logger = logging.getLogger("dinov3") + +#: Mapping from string FFN layer names to their class implementations. 
+ffn_layer_dict: dict[str, type | partial] = { + "mlp": MLP2L, + "swiglu": SwiGLUFFNV2, + "swiglu32": partial(SwiGLUFFNV2, align_to=32), + "swiglu64": partial(SwiGLUFFNV2, align_to=64), + "swiglu128": partial(SwiGLUFFNV2, align_to=128), +} + +#: Mapping from string norm layer names to their class implementations. +norm_layer_dict: dict[str, type | partial] = { + "layernorm": partial(nn.LayerNorm, eps=1e-6), + "layernormbf16": partial(nn.LayerNorm, eps=1e-5), +} + +#: Mapping from string dtype names to torch dtype objects. +dtype_dict: dict[str, torch.dtype] = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +def init_weights_vit(module: nn.Module, name: str = "") -> None: # noqa: ARG001 + """Initialize Vision Transformer module weights. + + Applies truncated normal initialization to Linear layers, and calls + reset_parameters on LayerNorm, LayerScale and PatchEmbed. + + Args: + module: The module to initialize. + name: Name of the module (unused, for compatibility with named_apply). + """ + if isinstance(module, nn.Linear): + torch.nn.init.trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + if isinstance(module, nn.LayerNorm): + module.reset_parameters() + if isinstance(module, LayerScale): + module.reset_parameters() + if isinstance(module, PatchEmbed): + module.reset_parameters() + + +class DinoVisionTransformer(nn.Module): + """DINOv3 Vision Transformer backbone. + + A Vision Transformer with RoPE (Rotary Position Embedding), optional + SwiGLU FFN layers, and LayerScale. Designed for self-supervised learning + with the DINOv3 methodology. + + Args: + name: Model configuration name from the configs dictionary. + Supported: 'dinov3_vits16', 'dinov3_vits16plus', 'dinov3_vitb16', + 'dinov3_vitb16plus', 'dinov3_vitl16plus'. + gradient_checkpointing: If True, use gradient checkpointing to reduce memory + at the cost of increased computation. Defaults to False. 
+ """ + + def __init__( + self, + name: str, + gradient_checkpointing: bool = False, + ) -> None: + super().__init__() + self.gradient_checkpointing = gradient_checkpointing + config = configs[name] + img_size = config["img_size"] + patch_size = config["patch_size"] + in_chans = config["in_chans"] + pos_embed_rope_min_period = None + pos_embed_rope_max_period = None + pos_embed_rope_shift_coords = None + pos_embed_rope_jitter_coords = None + pos_embed_rope_rescale_coords = None + pos_embed_rope_base = config["pos_embed_rope_base"] + pos_embed_rope_normalize_coords = config["pos_embed_rope_normalize_coords"] + pos_embed_rope_rescale_coords = config["pos_embed_rope_rescale_coords"] + pos_embed_rope_dtype = config["pos_embed_rope_dtype"] + embed_dim = config["embed_dim"] + depth = config["depth"] + num_heads = config["num_heads"] + ffn_ratio = config["ffn_ratio"] + qkv_bias = config["qkv_bias"] + drop_path_rate = config["drop_path_rate"] + layerscale_init = config["layerscale_init"] + norm_layer = config["norm_layer"] + ffn_layer = config["ffn_layer"] + ffn_bias = config["ffn_bias"] + proj_bias = config["proj_bias"] + n_storage_tokens = config["n_storage_tokens"] + mask_k_bias = config["mask_k_bias"] + untie_cls_and_patch_norms = False + untie_global_and_local_cls_norm = False + device = None + + norm_layer_cls = norm_layer_dict[norm_layer] + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + flatten_embedding=False, + ) + + self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device)) + self.n_storage_tokens = n_storage_tokens + if self.n_storage_tokens > 0: + self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device)) + + self.rope_embed = RopePositionEmbedding( + embed_dim=embed_dim, + num_heads=num_heads, + base=pos_embed_rope_base, + min_period=pos_embed_rope_min_period, + max_period=pos_embed_rope_max_period, + normalize_coords=pos_embed_rope_normalize_coords, + shift_coords=pos_embed_rope_shift_coords, + jitter_coords=pos_embed_rope_jitter_coords, + rescale_coords=pos_embed_rope_rescale_coords, + dtype=dtype_dict[pos_embed_rope_dtype], + device=device, + ) + ffn_layer_cls = ffn_layer_dict[ffn_layer] + ffn_ratio_sequence = [ffn_ratio] * depth + blocks_list = [ + SelfAttentionBlock( + dim=embed_dim, + num_heads=num_heads, + ffn_ratio=ffn_ratio_sequence[i], + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=drop_path_rate, + norm_layer=norm_layer_cls, + act_layer=nn.GELU, + ffn_layer=ffn_layer_cls, + init_values=layerscale_init, + mask_k_bias=mask_k_bias, + device=device, + ) + for i in range(depth) + ] + + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + # This norm is applied to everything, or when untying, to patch and mask tokens. + self.norm = norm_layer_cls(embed_dim) + + self.untie_cls_and_patch_norms = untie_cls_and_patch_norms + if untie_cls_and_patch_norms: + # When untying, this norm is applied to CLS tokens and registers. + self.cls_norm = norm_layer_cls(embed_dim) + else: + self.cls_norm = None + + self.untie_global_and_local_cls_norm = untie_global_and_local_cls_norm + if untie_global_and_local_cls_norm: + # When untying, this norm is applied to local CLS tokens and registers. + # This norm is never used during eval. 
+ self.local_cls_norm = norm_layer_cls(embed_dim) + else: + self.local_cls_norm = None + self.head = nn.Identity() + self.mask_token = nn.Parameter(torch.empty(1, embed_dim, device=device)) + + self.init_weights() + + def init_weights(self) -> None: + """Initialize model weights with proper initialization schemes.""" + self.rope_embed._init_weights() # noqa: SLF001 + nn.init.normal_(self.cls_token, std=0.02) + if self.n_storage_tokens > 0: + nn.init.normal_(self.storage_tokens, std=0.02) + nn.init.zeros_(self.mask_token) + named_apply(init_weights_vit, self) + + def prepare_tokens_with_masks(self, x: Tensor, masks: Tensor | None = None) -> tuple[Tensor, tuple[int, int]]: + """Prepare input tokens with optional mask tokens. + + Args: + x: Input image tensor of shape (B, C, H, W). + masks: Optional boolean mask tensor for masked image modeling. + + Returns: + Tuple of (tokens, (H, W)) where tokens has shape (B, N, D) with + cls_token, storage_tokens, and patch tokens concatenated. + """ + x = self.patch_embed(x) + B, H, W, _ = x.shape # noqa: N806 + x = x.flatten(1, 2) + + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + cls_token = self.cls_token + else: + cls_token = self.cls_token + 0 * self.mask_token + if self.n_storage_tokens > 0: + storage_tokens = self.storage_tokens + else: + storage_tokens = torch.empty( + 1, + 0, + cls_token.shape[-1], + dtype=cls_token.dtype, + device=cls_token.device, + ) + + x = torch.cat( + [ + cls_token.expand(B, -1, -1), + storage_tokens.expand(B, -1, -1), + x, + ], + dim=1, + ) + + return x, (H, W) + + def forward_features_list(self, x_list: list[Tensor], masks_list: list[Tensor | None]) -> list[dict[str, Tensor]]: + """Forward pass for a list of images with masks. + + Args: + x_list: List of input image tensors. + masks_list: List of corresponding mask tensors (can be None). + + Returns: + List of dictionaries containing normalized features. + """ + x = [] + rope = [] + for t_x, t_masks in zip(x_list, masks_list): + t2_x, hw_tuple = self.prepare_tokens_with_masks(t_x, t_masks) + x.append(t2_x) + rope.append(hw_tuple) + for _, blk in enumerate(self.blocks): + if self.rope_embed is not None: + rope_sincos = [self.rope_embed(h=h, w=w) for h, w in rope] + else: + rope_sincos = [None for r in rope] + if self.training and self.gradient_checkpointing: + x = gradient_checkpoint(blk, x, rope_sincos, use_reentrant=False) + else: + x = blk(x, rope_sincos) + all_x = x + output = [] + for idx, (x, masks) in enumerate(zip(all_x, masks_list)): + if self.untie_cls_and_patch_norms or self.untie_global_and_local_cls_norm: + if self.untie_global_and_local_cls_norm and self.training and idx == 1: + # Assume second entry of list corresponds to local crops. + # We only ever apply this during training. 
+ x_norm_cls_reg = self.local_cls_norm(x[:, : self.n_storage_tokens + 1]) # type: ignore[call-overload] + elif self.untie_cls_and_patch_norms: + x_norm_cls_reg = self.cls_norm(x[:, : self.n_storage_tokens + 1]) # type: ignore[call-overload] + else: + x_norm_cls_reg = self.norm(x[:, : self.n_storage_tokens + 1]) # type: ignore[call-overload] + x_norm_patch = self.norm(x[:, self.n_storage_tokens + 1 :]) # type: ignore[call-overload] + else: + x_norm = self.norm(x) + x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1] # type: ignore[call-overload] + x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :] # type: ignore[call-overload] + output.append( + { + "x_norm_clstoken": x_norm_cls_reg[:, 0], + "x_storage_tokens": x_norm_cls_reg[:, 1:], + "x_norm_patchtokens": x_norm_patch, + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features( + self, + x: Tensor | list[Tensor], + masks: Tensor | list[Tensor] | None = None, + ) -> dict[str, Tensor] | list[dict[str, Tensor]]: + """Extract features from input images. + + Args: + x: Input image tensor or list of tensors. + masks: Optional mask tensor for masked image modeling. + + Returns: + Dictionary (single tensor) or list of dictionaries containing + normalized CLS token, storage tokens, patch tokens, and pre-norm features. + """ + if isinstance(x, torch.Tensor): + masks_as_list: list[Tensor | None] = [masks] + return self.forward_features_list([x], masks_as_list)[0] + return self.forward_features_list(x, masks if masks is not None else [None] * len(x)) + + def _get_intermediate_layers_not_chunked(self, x: Tensor, n: int | list[int] = 1) -> list[Tensor]: + """Get intermediate layer outputs without chunking. + + Args: + x: Input tensor. + n: Number of last layers to return, or list of layer indices. + + Returns: + List of intermediate feature tensors. + """ + x, (h, w) = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output: list[Tensor] = [] + total_block_len = len(self.blocks) + blocks_to_take: range | list[int] = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + rope_sincos = self.rope_embed(h=h, w=w) if self.rope_embed is not None else None + if self.training and self.gradient_checkpointing: + # Use gradient checkpointing to reduce memory during training + x = gradient_checkpoint(blk, x, rope_sincos, use_reentrant=False) + else: + x = blk(x, rope_sincos) + if i in blocks_to_take: + output.append(x) + if len(output) != len(blocks_to_take): + msg = f"only {len(output)} / {len(blocks_to_take)} blocks found" + raise RuntimeError(msg) + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + *, + n: int | list[int] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + return_extra_tokens: bool = False, + norm: bool = True, + ) -> tuple[torch.Tensor | tuple[torch.Tensor, ...], ...]: + """Get intermediate layer representations. + + Args: + x: Input image tensor of shape (B, C, H, W). + n: Number of last layers to return, or list of specific layer indices. + reshape: If True, reshape outputs to spatial format (B, C, H', W'). + return_class_token: If True, also return class tokens. + return_extra_tokens: If True, also return extra/storage tokens. + norm: If True, apply layer normalization to outputs. + + Returns: + Tuple of outputs. 
Format depends on return flags: + - Default: (outputs,) for each layer + - With class token: ((output, cls_token),) for each layer + - With extra tokens: ((output, extra_tokens),) for each layer + - Both: ((output, cls_token, extra_tokens),) for each layer + """ + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs_normed = [] + for out in outputs: + if self.untie_cls_and_patch_norms: + x_norm_cls_reg = self.cls_norm(out[:, : self.n_storage_tokens + 1]) + x_norm_patch = self.norm(out[:, self.n_storage_tokens + 1 :]) + outputs_normed.append(torch.cat((x_norm_cls_reg, x_norm_patch), dim=1)) + else: + outputs_normed.append(self.norm(out)) + outputs = outputs_normed + class_tokens = [out[:, 0] for out in outputs] + extra_tokens = [out[:, 1 : self.n_storage_tokens + 1] for out in outputs] + outputs = [out[:, self.n_storage_tokens + 1 :] for out in outputs] + if reshape: + b, _, h, w = x.shape + outputs = [ + out.reshape(b, h // self.patch_size, w // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token and not return_extra_tokens: + return tuple(zip(outputs, class_tokens)) + if not return_class_token and return_extra_tokens: + return tuple(zip(outputs, extra_tokens)) + if return_class_token and return_extra_tokens: + return tuple(zip(outputs, class_tokens, extra_tokens)) + return tuple(outputs) + + def forward( + self, + *args: Any, # noqa: ANN401 + is_training: bool = False, + **kwargs: Any, # noqa: ANN401 + ) -> dict[str, Tensor] | list[dict[str, Tensor]] | Tensor: + """Forward pass through the model. + + Args: + *args: Positional arguments passed to forward_features. + is_training: If True, return full feature dict; otherwise return + classification head output. + **kwargs: Keyword arguments passed to forward_features. + + Returns: + Feature dictionary during training, or head output during inference. + """ + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + if isinstance(ret, list): + ret = ret[0] + return self.head(ret["x_norm_clstoken"]) diff --git a/library/src/otx/backend/native/models/common/layers/position_embed.py b/library/src/otx/backend/native/models/common/layers/position_embed.py index 7601f1eb29a..4af32ddf90a 100644 --- a/library/src/otx/backend/native/models/common/layers/position_embed.py +++ b/library/src/otx/backend/native/models/common/layers/position_embed.py @@ -6,9 +6,11 @@ from __future__ import annotations import math +from typing import Literal +import numpy as np import torch -from torch import nn +from torch import Tensor, nn class PositionEmbeddingSine(nn.Module): @@ -105,3 +107,147 @@ def gen_sineembed_for_position(pos_tensor: torch.Tensor) -> torch.Tensor: msg = f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}" raise ValueError(msg) return pos + + +class RopePositionEmbedding(nn.Module): + """Rotary Position Embedding (RoPE) for Vision Transformers. + + Computes sinusoidal position embeddings that are applied via rotation + to query and key vectors in attention layers. + + Args: + embed_dim: Total embedding dimension. + num_heads: Number of attention heads. + base: Base frequency for computing periods. + min_period: Minimum period (alternative to base). + max_period: Maximum period (alternative to base). + normalize_coords: How to normalize coordinates ('min', 'max', 'separate'). + shift_coords: Optional shift to apply to coordinates. + jitter_coords: Optional jitter range for data augmentation. + rescale_coords: Optional rescaling factor for coordinates. 
+ dtype: Data type for embeddings. + device: Device for embeddings. + """ + + def __init__( + self, + embed_dim: int, + *, + num_heads: int, + base: float | None = 100.0, + min_period: float | None = None, + max_period: float | None = None, + normalize_coords: Literal["min", "max", "separate"] = "separate", + shift_coords: float | None = None, + jitter_coords: float | None = None, + rescale_coords: float | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ): + super().__init__() + if embed_dim % (4 * num_heads) != 0: + msg = f"embed_dim ({embed_dim}) must be divisible by 4 * num_heads ({4 * num_heads})" + raise ValueError(msg) + both_periods = min_period is not None and max_period is not None + if (base is None and not both_periods) or (base is not None and both_periods): + msg = "Either `base` or `min_period`+`max_period` must be provided." + raise ValueError(msg) + + d_head = embed_dim // num_heads + self.base = base + self.min_period = min_period + self.max_period = max_period + self.D_head = d_head + self.normalize_coords = normalize_coords + self.shift_coords = shift_coords + self.jitter_coords = jitter_coords + self.rescale_coords = rescale_coords + + # Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher + self.dtype = dtype # Don't rely on self.periods.dtype + self.register_buffer( + "periods", + torch.empty(d_head // 4, device=device, dtype=dtype), + persistent=True, + ) + self._init_weights() + + def forward(self, *, h: int, w: int) -> tuple[Tensor, Tensor]: + """Compute sin and cos position embeddings. + + Args: + H: Height of the feature map. + W: Width of the feature map. + + Returns: + Tuple of (sin, cos) tensors for rotary position embedding. 
+ """ + device = self.periods.device + dtype = self.dtype + dd = {"device": device, "dtype": dtype} + + # Prepare coords in range [-1, +1] + if self.normalize_coords == "max": + max_hw = max(h, w) + coords_h = torch.arange(0.5, h, **dd) / max_hw # [H] + coords_w = torch.arange(0.5, w, **dd) / max_hw # [W] + elif self.normalize_coords == "min": + min_hw = min(h, w) + coords_h = torch.arange(0.5, h, **dd) / min_hw # [H] + coords_w = torch.arange(0.5, w, **dd) / min_hw # [W] + elif self.normalize_coords == "separate": + coords_h = torch.arange(0.5, h, **dd) / h # [h] + coords_w = torch.arange(0.5, w, **dd) / w # [W] + else: + msg = f"Unknown normalize_coords: {self.normalize_coords}" + raise ValueError(msg) + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # [H, W, 2] + coords = coords.flatten(0, 1) # [HW, 2] + coords = 2.0 * coords - 1.0 # Shift range [0, 1] to [-1, +1] + + # Shift coords by adding a uniform value in [-shift, shift] + if self.training and self.shift_coords is not None: + shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords) + coords += shift_hw[None, :] + + # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter] + if self.training and self.jitter_coords is not None: + jitter_max = np.log(self.jitter_coords) + jitter_min = -jitter_max + jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp() + coords *= jitter_hw[None, :] + + # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale] + if self.training and self.rescale_coords is not None: + rescale_max = np.log(self.rescale_coords) + rescale_min = -rescale_max + rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp() + coords *= rescale_hw + + # Prepare angles and sin/cos + angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] # [HW, 2, D//4] + angles = angles.flatten(1, 2) # [HW, D//2] + angles = angles.tile(2) # [HW, D] + cos = torch.cos(angles) # [HW, D] + sin = torch.sin(angles) # [HW, D] + + return (sin, cos) # 2 * [HW, D] + + def _init_weights(self) -> None: + device = self.periods.device + dtype = self.dtype + if self.base is not None: + periods = self.base ** ( + 2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2) + ) # [D//4] + else: + # min_period and max_period are guaranteed to be set when base is None + if self.min_period is None or self.max_period is None: + msg = "min_period and max_period must be set when base is None" + raise RuntimeError(msg) + base = self.max_period / self.min_period + exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype) # [D//4] range [0, 1] + periods = base**exponents # range [1, max_period / min_period] + periods = periods / base # range [min_period / max_period, 1] + periods = periods * self.max_period # range [min_period, max_period] + self.periods.data = periods diff --git a/library/src/otx/backend/native/models/common/layers/transformer_layers.py b/library/src/otx/backend/native/models/common/layers/transformer_layers.py index 8945cf01458..34c3997bc88 100644 --- a/library/src/otx/backend/native/models/common/layers/transformer_layers.py +++ b/library/src/otx/backend/native/models/common/layers/transformer_layers.py @@ -7,15 +7,21 @@ import copy import math -from typing import Callable +from functools import partial +from typing import Any, Callable, NoReturn import torch import torch.nn.functional as f +import torchvision from torch import 
Tensor, nn from torch.nn import init -from otx.backend.native.models.common.utils.utils import get_clones -from otx.backend.native.models.modules.transformer import deformable_attention_core_func +from otx.backend.native.models.common.utils.utils import get_clones, inverse_sigmoid +from otx.backend.native.models.modules.norm import RMSNorm +from otx.backend.native.models.modules.transformer import ( + deformable_attention_core_func, +) +from otx.backend.native.models.utils.weight_init import bias_init_with_prob class TransformerEncoderLayer(nn.Module): @@ -89,6 +95,58 @@ def forward( return src +class ListForwardMixin: + """Mixin class that provides list-based forward operations for transformers.""" + + def forward(self, x: Tensor) -> NoReturn: + """Forward pass - must be implemented by subclass.""" + raise NotImplementedError + + def forward_list(self, x_list: list[Tensor]) -> list[Tensor]: + """Process a list of tensors by concatenating, forwarding, and splitting. + + Args: + x_list: List of input tensors. + + Returns: + List of processed tensors with original shapes. + """ + x_flat, shapes, num_tokens = cat_keep_shapes(x_list) + x_flat = self.forward(x_flat) + return uncat_with_shapes(x_flat, shapes, num_tokens) + + +class LayerScale(nn.Module): + """Learnable per-channel scaling layer for transformer blocks. + + Args: + dim: Number of channels/features. + init_values: Initial scale value. + inplace: If True, apply scaling in-place. + device: Device for parameters. + """ + + def __init__( + self, + dim: int, + init_values: float | Tensor = 1e-5, + inplace: bool = False, + device: torch.device | str | None = None, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(torch.empty(dim, device=device)) + self.init_values = init_values + + def reset_parameters(self) -> None: + """Reset gamma parameter to initial value.""" + nn.init.constant_(self.gamma, self.init_values) + + def forward(self, x: Tensor) -> Tensor: + """Apply learnable scaling to input tensor.""" + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + class TransformerEncoder(nn.Module): """TransformerEncoder.""" @@ -160,6 +218,55 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class MLP2L(nn.Module, ListForwardMixin): + """Multi-Layer Perceptron for Vision Transformer with 2 fixed layers. + + A simple two-layer MLP with configurable hidden dimension and activation. + + Args: + in_features: Number of input features. + hidden_features: Number of hidden features. Defaults to in_features. + out_features: Number of output features. Defaults to in_features. + act_layer: Activation layer class. + drop: Dropout rate. + bias: Whether to use bias in linear layers. + device: Device to place tensors on. + """ + + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + out_features: int | None = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + device: torch.device | str | None = None, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, device=device) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, device=device) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass through the MLP. + + Args: + x: Input tensor of shape (B, N, C). 
+ + Returns: + Output tensor of shape (B, N, out_features). + """ + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return self.drop(x) + + class MSDeformableAttention(nn.Module): """Multi-Scale Deformable Attention Module. @@ -170,7 +277,13 @@ class MSDeformableAttention(nn.Module): num_points (int): The number of points in MSDeformableAttention. """ - def __init__(self, embed_dim: int = 256, num_heads: int = 8, num_levels: int = 4, num_points: int = 4) -> None: + def __init__( + self, + embed_dim: int = 256, + num_heads: int = 8, + num_levels: int = 4, + num_points: int = 4, + ) -> None: """Multi-Scale Deformable Attention Module.""" super().__init__() self.embed_dim = embed_dim @@ -329,12 +442,14 @@ def __init__( num_heads: int = 8, num_levels: int = 4, num_points_list: list[int] = [3, 6, 3], # noqa: B006 + method: str = "default", ) -> None: super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.num_levels = num_levels self.num_points_list = num_points_list + self.method = method num_points_scale = [1 / n for n in num_points_list for _ in range(n)] self.register_buffer( @@ -350,6 +465,10 @@ def __init__( self._reset_parameters() + if method == "discrete": + for p in self.sampling_offsets.parameters(): + p.requires_grad = False + def _reset_parameters(self) -> None: """Reset parameters of the model.""" init.constant_(self.sampling_offsets.weight, 0) @@ -617,3 +736,758 @@ def forward( output = layer(output, pos, reference_points, spatial_shapes, padding_mask) return output + + +def cat_keep_shapes(x_list: list[Tensor]) -> tuple[Tensor, list[tuple[int, ...]], list[int]]: + """Concatenate tensors while preserving their original shapes. + + Args: + x_list: List of tensors to concatenate. + + Returns: + Tuple of (flattened tensor, original shapes, token counts). + """ + shapes = [x.shape for x in x_list] + num_tokens = [x.select(dim=-1, index=0).numel() for x in x_list] + flattened = torch.cat([x.flatten(0, -2) for x in x_list]) + return flattened, shapes, num_tokens + + +def uncat_with_shapes(flattened: Tensor, shapes: list[tuple[int, ...]], num_tokens: list[int]) -> list[Tensor]: + """Split a flattened tensor back to original shapes. + + Args: + flattened: Concatenated tensor. + shapes: Original tensor shapes. + num_tokens: Token counts for splitting. + + Returns: + List of tensors with original shapes. + """ + outputs_splitted = torch.split_with_sizes(flattened, num_tokens, dim=0) + shapes_adjusted = [shape[:-1] + torch.Size([flattened.shape[-1]]) for shape in shapes] + return [o.reshape(shape) for o, shape in zip(outputs_splitted, shapes_adjusted)] + + +# RoPE-related functions: +def rope_rotate_half(x: Tensor) -> Tensor: + """Rotate half of the tensor elements for RoPE. + + Args: + x: Input tensor of shape [..., D]. + + Returns: + Rotated tensor where x[..., :D/2] and x[..., D/2:] are swapped and negated. + """ + # x: [ x0 x1 x2 x3 x4 x5] + # out: [-x3 -x4 -x5 x0 x1 x2] + x1, x2 = x.chunk(2, dim=-1) + return torch.cat([-x2, x1], dim=-1) + + +def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor: + """Apply rotary position embedding to tensor. + + Args: + x: Input tensor of shape [..., D]. + sin: Sine embeddings of shape [..., D]. + cos: Cosine embeddings of shape [..., D]. + + Returns: + Tensor with rotary position embedding applied. 
+ """ + # x: [..., D], eg [x0, x1, x2, x3, x4, x5] + # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2] + # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2] + return (x * cos) + (rope_rotate_half(x) * sin) + + +class LinearKMaskedBias(nn.Linear): + """Linear layer with masked bias for Q, K, V projection. + + Masks the K bias portion with NaN values for specific attention patterns. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 + super().__init__(*args, **kwargs) + o = self.out_features + if o % 3 != 0: + msg = f"out_features ({o}) must be divisible by 3" + raise ValueError(msg) + if self.bias is not None: + self.register_buffer("bias_mask", torch.full_like(self.bias, fill_value=math.nan)) + + def forward(self, input: Tensor) -> Tensor: # noqa: A002 + """Apply linear transformation with masked bias. + + Args: + input: Input tensor. + + Returns: + Transformed tensor. + """ + masked_bias = self.bias * self.bias_mask.to(self.bias.dtype) if self.bias is not None else None + return f.linear(input, self.weight, masked_bias) + + +class SelfAttention(nn.Module): + """Multi-head self-attention module. + + Args: + dim: Input/output feature dimension. + num_heads: Number of attention heads. + qkv_bias: If True, add bias to QKV projection. + proj_bias: If True, add bias to output projection. + attn_drop: Attention dropout rate. + proj_drop: Output projection dropout rate. + mask_k_bias: If True, mask the K bias. + device: Device for parameters. + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + mask_k_bias: bool = False, + device: torch.device | str | None = None, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + linear_class = LinearKMaskedBias if mask_k_bias else nn.Linear + self.qkv = linear_class(dim, dim * 3, bias=qkv_bias, device=device) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device) + self.proj_drop = nn.Dropout(proj_drop) + + def apply_rope(self, q: Tensor, k: Tensor, rope: Tensor | tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: + """Apply rotary position embeddings to query and key tensors. + + Args: + q: Query tensor of shape [B, heads, N, D//heads]. + k: Key tensor of shape [B, heads, N, D//heads]. + rope: Tuple of (sin, cos) tensors for position embedding. + + Returns: + Tuple of (q, k) with rotary embeddings applied. + """ + # All operations will use the dtype of rope, the output is cast back to the dtype of q and k + q_dtype = q.dtype + k_dtype = k.dtype + sin, cos = rope + rope_dtype = sin.dtype + q = q.to(dtype=rope_dtype) + k = k.to(dtype=rope_dtype) + n = q.shape[-2] + prefix = n - sin.shape[-2] + if prefix < 0: + msg = f"prefix ({prefix}) must be >= 0" + raise ValueError(msg) + q_prefix = q[:, :, :prefix, :] + q = rope_apply(q[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head] + q = torch.cat((q_prefix, q), dim=-2) # [B, head, N, D//head] + k_prefix = k[:, :, :prefix, :] + k = rope_apply(k[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head] + k = torch.cat((k_prefix, k), dim=-2) # [B, head, N, D//head] + q = q.to(dtype=q_dtype) + k = k.to(dtype=k_dtype) + return q, k + + def forward( + self, + x: Tensor, + attn_bias: Tensor | None = None, + rope: Tensor | tuple[Tensor, Tensor] | None = None, + ) -> Tensor: + """Forward pass for self-attention. 
+ + Args: + x: Input tensor of shape [B, N, D]. + attn_bias: Optional attention bias. + rope: Optional rotary position embedding. + + Returns: + Output tensor of shape [B, N, D]. + """ + qkv = self.qkv(x) + attn_v = self.compute_attention(qkv=qkv, attn_bias=attn_bias, rope=rope) + x = self.proj(attn_v) + return self.proj_drop(x) + + def forward_list( + self, + x_list: list[Tensor], + attn_bias: Tensor | None = None, + rope_list: list[tuple[Tensor, Tensor]] | None = None, + ) -> list[Tensor]: + """Forward pass for list of tensors. + + Args: + x_list: List of input tensors. + attn_bias: Optional attention bias. + rope_list: List of rotary position embeddings. + + Returns: + List of output tensors. + """ + if rope_list is None or len(x_list) != len(rope_list): + msg = "x_list and rope_list must have same length" + raise ValueError(msg) + x_flat, shapes, num_tokens = cat_keep_shapes(x_list) + qkv_flat = self.qkv(x_flat) + qkv_list = uncat_with_shapes(qkv_flat, shapes, num_tokens) + att_out = [] + for _, (qkv, _, rope) in enumerate(zip(qkv_list, shapes, rope_list)): + att_out.append(self.compute_attention(qkv, attn_bias=attn_bias, rope=rope)) + x_flat, shapes, num_tokens = cat_keep_shapes(att_out) + x_flat = self.proj(x_flat) + return uncat_with_shapes(x_flat, shapes, num_tokens) + + def compute_attention( + self, + qkv: Tensor, + attn_bias: Tensor | None = None, + rope: tuple[Tensor, Tensor] | None = None, + ) -> Tensor: + """Compute attention from QKV tensor. + + Args: + qkv: Combined query-key-value tensor. + attn_bias: Optional attention bias (must be None). + rope: Optional rotary position embedding. + + Returns: + Attention output tensor. + """ + if attn_bias is not None: + msg = "attn_bias must be None" + raise ValueError(msg) + B, N, _ = qkv.shape # noqa: N806 + C = self.qkv.in_features # noqa: N806 + + qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads) + q, k, v = torch.unbind(qkv, 2) + q, k, v = [t.transpose(1, 2) for t in [q, k, v]] + if rope is not None: + q, k = self.apply_rope(q, k, rope) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = x.transpose(1, 2) + return x.reshape([B, N, C]) + + +class SelfAttentionBlock(nn.Module): + """Transformer block with self-attention and FFN. + + Args: + dim: Input/output feature dimension. + num_heads: Number of attention heads. + mlp_ratio: Ratio of MLP hidden dim to embedding dim. + qkv_bias: If True, add bias to QKV projection. + proj_bias: If True, add bias to output projection. + drop: Dropout rate. + attn_drop: Attention dropout rate. + init_values: Initial values for LayerScale. + drop_path: Drop path rate. + act_layer: Activation layer class. + norm_layer: Normalization layer class. + rope_subset_list: List of RoPE subsets. + ffn_layer: FFN layer class. + mask_k_bias: If True, mask the K bias. + device: Device for parameters. 
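+
+    Example:
+        Minimal single-tensor usage sketch with the default FFN and no RoPE
+        (shapes are illustrative):
+
+        >>> block = SelfAttentionBlock(dim=256, num_heads=8)
+        >>> block(torch.randn(2, 197, 256)).shape
+        torch.Size([2, 197, 256])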
+ """ + + def __init__( + self, + dim: int, + num_heads: int, + ffn_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values: float | None = None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = SelfAttention, + ffn_layer: Callable[..., nn.Module] = MLP2L, + mask_k_bias: bool = False, + device: torch.device | str | None = None, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + mask_k_bias=mask_k_bias, + device=device, + ) + self.ls1 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * ffn_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + device=device, + ) + self.ls2 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity() + + self.sample_drop_ratio = drop_path + + @staticmethod + def _maybe_index_rope(rope: tuple[Tensor, Tensor] | None, indices: Tensor) -> tuple[Tensor, Tensor] | None: + if rope is None: + return None + + sin, cos = rope + if sin.ndim != cos.ndim: + msg = "sin and cos must have same ndim" + raise ValueError(msg) + if sin.ndim == 4: + # If the rope embedding has a batch dimension (is different for each batch element), index into it + return sin[indices], cos[indices] # [batch, heads, patches, embed_dim] + # No batch dimension, do not index + return sin, cos # [heads, patches, embed_dim] or [patches, embed_dim] + + def _forward(self, x: Tensor, rope: tuple[Tensor, Tensor] | None = None) -> Tensor: + """Forward pass for a single tensor. + + This is the reference implementation for a single tensor, matching what is done below for a list. + We call the list op on [x] instead of this function. + """ + b, _, _ = x.shape + sample_subset_size = max(int(b * (1 - self.sample_drop_ratio)), 1) + residual_scale_factor = b / sample_subset_size + + if self.training and self.sample_drop_ratio > 0.0: + indices_1 = (torch.randperm(b, device=x.device))[:sample_subset_size] + + x_subset_1 = x[indices_1] + rope_subset = self._maybe_index_rope(rope, indices_1) + residual_1 = self.attn(self.norm1(x_subset_1), rope=rope_subset) + + x_attn = torch.index_add( + x, + dim=0, + source=self.ls1(residual_1), + index=indices_1, + alpha=residual_scale_factor, + ) + + indices_2 = (torch.randperm(b, device=x.device))[:sample_subset_size] + + x_subset_2 = x_attn[indices_2] + residual_2 = self.mlp(self.norm2(x_subset_2)) + + x_ffn = torch.index_add( + x_attn, + dim=0, + source=self.ls2(residual_2), + index=indices_2, + alpha=residual_scale_factor, + ) + else: + x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope)) + x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn))) + + return x_ffn + + def _forward_list(self, x_list: list[Tensor], rope_list: list[tuple[Tensor, Tensor]] | None = None) -> list[Tensor]: + """Forward pass for list of tensors. + + This list operator concatenates the tokens from the list of inputs together to save + on the elementwise operations. 
Torch-compile memory-planning allows hiding the overhead + related to concat ops. + """ + b_list = [x.shape[0] for x in x_list] + sample_subset_sizes = [max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list] + residual_scale_factors = [b / sample_subset_size for b, sample_subset_size in zip(b_list, sample_subset_sizes)] + + if self.training and self.sample_drop_ratio > 0.0: + indices_1_list = [ + (torch.randperm(b, device=x.device))[:sample_subset_size] + for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes) + ] + x_subset_1_list = [x[indices_1] for x, indices_1 in zip(x_list, indices_1_list)] + + if rope_list is not None: + rope_subset_list: list[tuple[Tensor, Tensor] | None] | None = [ + self._maybe_index_rope(rope, indices_1) for rope, indices_1 in zip(rope_list, indices_1_list) + ] + else: + rope_subset_list = None + + flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list) + norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens) + residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list) + + x_attn_list = [ + torch.index_add( + x, + dim=0, + source=self.ls1(residual_1), + index=indices_1, + alpha=residual_scale_factor, + ) + for x, residual_1, indices_1, residual_scale_factor in zip( + x_list, residual_1_list, indices_1_list, residual_scale_factors + ) + ] + + indices_2_list = [ + (torch.randperm(b, device=x.device))[:sample_subset_size] + for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes) + ] + x_subset_2_list = [x[indices_2] for x, indices_2 in zip(x_attn_list, indices_2_list)] + flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list) + norm2_flat = self.norm2(flattened) + norm2_list = uncat_with_shapes(norm2_flat, shapes, num_tokens) + + residual_2_list = self.mlp.forward_list(norm2_list) + + x_ffn = [ + torch.index_add( + x_attn, + dim=0, + source=self.ls2(residual_2), + index=indices_2, + alpha=residual_scale_factor, + ) + for x_attn, residual_2, indices_2, residual_scale_factor in zip( + x_attn_list, residual_2_list, indices_2_list, residual_scale_factors + ) + ] + else: + x_out = [] + for i, x in enumerate(x_list): + rope = rope_list[i] if rope_list is not None else None + x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope)) + x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn))) + x_out.append(x_ffn) + x_ffn = x_out + + return x_ffn + + def forward( + self, + x_or_x_list: Tensor | list[Tensor], + rope_or_rope_list: tuple[Tensor, Tensor] | list[tuple[Tensor, Tensor] | None] | None = None, + ) -> Tensor | list[Tensor]: + """Forward pass supporting both single tensor and list of tensors. + + Args: + x_or_x_list: Input tensor or list of tensors. + rope_or_rope_list: Rotary position embedding or list of embeddings. + + Returns: + Output tensor or list of tensors. 
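+
+        Example:
+            List inputs with different token counts keep their shapes
+            (a shape-only sketch):
+
+            >>> blk = SelfAttentionBlock(dim=192, num_heads=3)
+            >>> outs = blk([torch.randn(2, 196, 192), torch.randn(2, 256, 192)])
+            >>> [o.shape for o in outs]
+            [torch.Size([2, 196, 192]), torch.Size([2, 256, 192])]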
+ """ + if isinstance(x_or_x_list, Tensor): + # for reference: + # return self._forward(x_or_x_list, rope=rope_or_rope_list) + # in order to match implementations we call the list op: + rope_as_list = [rope_or_rope_list] if not isinstance(rope_or_rope_list, list) else rope_or_rope_list + return self._forward_list([x_or_x_list], rope_list=rope_as_list)[0] # type: ignore[arg-type] + if isinstance(x_or_x_list, list): + if rope_or_rope_list is None: + rope_or_rope_list = [None for _ in x_or_x_list] + # return [self._forward(x, rope=rope) for x, rope in zip(x_or_x_list, rope_or_rope_list)] + return self._forward_list(x_or_x_list, rope_list=rope_or_rope_list) # type: ignore[arg-type] + msg = f"x_or_x_list must be Tensor or list, got {type(x_or_x_list)}" + raise TypeError(msg) + + +class Gate(nn.Module): + """Gated fusion module for combining two feature streams. + + Uses learnable gates to adaptively blend two input tensors. + + Args: + d_model: Feature dimension. + use_rmsnorm: Whether to use RMSNorm instead of LayerNorm. + """ + + def __init__(self, d_model: int, use_rmsnorm: bool = False) -> None: + super().__init__() + self.gate = nn.Linear(2 * d_model, 2 * d_model) + bias = bias_init_with_prob(0.5) + init.constant_(self.gate.bias, bias) + init.constant_(self.gate.weight, 0) + self.norm = RMSNorm(d_model) if use_rmsnorm else nn.LayerNorm(d_model) + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + """Gated fusion of two tensors. + + Args: + x1: First input tensor of shape (B, N, C). + x2: Second input tensor of shape (B, N, C). + + Returns: + Fused tensor of shape (B, N, C). + """ + gate_input = torch.cat([x1, x2], dim=-1) + gates = torch.sigmoid(self.gate(gate_input)) + gate1, gate2 = gates.chunk(2, dim=-1) + return self.norm(gate1 * x1 + gate2 * x2) + + +class Integral(nn.Module): + """Integral layer for distribution-based bounding box regression. + + Computes target location using: `sum{Pr(n) * W(n)}`, where Pr(n) is the + softmax probability vector and W(n) is the non-uniform weighting function. + + Args: + reg_max: Maximum number of discrete bins for regression. + """ + + def __init__(self, reg_max: int = 32) -> None: + super().__init__() + self.reg_max = reg_max + + def forward(self, x: Tensor, project: Tensor) -> Tensor: + """Compute integral over distribution. + + Args: + x: Distribution tensor of shape (B, N, 4*(reg_max+1)). + project: Projection weights for weighted sum. + + Returns: + Bounding box offsets of shape (B, N, 4). + """ + shape = x.shape + x = f.softmax(x.reshape(-1, self.reg_max + 1), dim=1) + x = f.linear(x, project.to(x.device)).reshape(-1, 4) + return x.reshape([*list(shape[:-1]), -1]) + + +class LQE(nn.Module): + """Location Quality Estimator. + + Estimates localization quality from corner distribution statistics + to refine classification scores. + + Args: + k: Number of top probabilities to use for statistics. + hidden_dim: Hidden dimension for MLP. + num_layers: Number of MLP layers. + reg_max: Maximum regression bins. + activation: Activation function class. 
+ """ + + def __init__( + self, + k: int, + hidden_dim: int, + num_layers: int, + reg_max: int, + activation: Callable[..., nn.Module] = partial(nn.ReLU, inplace=True), + ) -> None: + super().__init__() + self.k = k + self.reg_max = reg_max + self.reg_conf = MLP( + input_dim=4 * (k + 1), + hidden_dim=hidden_dim, + output_dim=1, + num_layers=num_layers, + activation=activation, + ) + init.constant_(self.reg_conf.layers[-1].bias, 0) + init.constant_(self.reg_conf.layers[-1].weight, 0) + + def forward(self, scores: Tensor, pred_corners: Tensor) -> Tensor: + """Refine scores based on corner distribution quality. + + Args: + scores: Classification scores of shape (B, N, num_classes). + pred_corners: Corner predictions of shape (B, N, 4*(reg_max+1)). + + Returns: + Refined scores of shape (B, N, num_classes). + """ + b, num_pred, _ = pred_corners.size() + prob = f.softmax(pred_corners.reshape(b, num_pred, 4, self.reg_max + 1), dim=-1) + prob_topk, _ = prob.topk(self.k, dim=-1) + stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1) + quality_score = self.reg_conf(stat.reshape(b, num_pred, -1)) + return scores + quality_score + + +class SwiGLUFFN(nn.Module): + """SwiGLU Feed-Forward Network. + + Implements the SwiGLU activation function as described in GLU Variants paper. + Uses gated linear units with SiLU activation for improved performance. + + Args: + in_features: Number of input features. + hidden_features: Number of hidden features. + out_features: Number of output features. + bias: Whether to use bias in linear layers. + """ + + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + self._reset_parameters() + + def _reset_parameters(self) -> None: + """Initialize weights with Xavier uniform and zero bias.""" + init.xavier_uniform_(self.w12.weight) + init.constant_(self.w12.bias, 0) + init.xavier_uniform_(self.w3.weight) + init.constant_(self.w3.bias, 0) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass with SwiGLU activation. + + Args: + x: Input tensor of shape (B, N, C). + + Returns: + Output tensor of shape (B, N, out_features). + """ + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = f.silu(x1) * x2 + return self.w3(hidden) + + +def get_contrastive_denoising_training_group( + targets: list[dict[str, torch.Tensor]], + num_classes: int, + num_queries: int, + class_embed: torch.nn.Module, + num_denoising: int = 100, + label_noise_ratio: float = 0.5, + box_noise_scale: float = 1.0, + max_denoising_queries: int = 1000, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]] | tuple[None, None, None, None]: + """Generate contrastive denoising training group. + + Args: + targets (List[Dict[str, torch.Tensor]]): List of target dictionaries. + num_classes (int): Number of classes. + num_queries (int): Number of queries. + class_embed (torch.nn.Module): Class embedding module. + num_denoising (int, optional): Number of denoising queries. Defaults to 100. + label_noise_ratio (float, optional): Ratio of label noise. Defaults to 0.5. + box_noise_scale (float, optional): Scale of box noise. Defaults to 1.0. + max_denoising_queries (int, optional): Maximum number of denoising queries to prevent OOM. + Defaults to 1000. 
+ + Returns: + Tuple[Tensor,Tensor,Tensor, dict[str, Tensor]] | tuple[None,None,None,None]: + Tuple containing input query class, input query bbox, attention mask, and denoising metadata. + """ + num_gts = [len(t["labels"]) for t in targets] + device = targets[0]["labels"].device + + max_gt_num = max(num_gts) + if max_gt_num == 0: + return None, None, None, None + + num_group = num_denoising // max_gt_num + num_group = 1 if num_group == 0 else num_group + + # Cap the number of denoising queries to prevent OOM with many ground truth objects + total_dn_queries = max_gt_num * 2 * num_group + if total_dn_queries > max_denoising_queries: + num_group = max(1, max_denoising_queries // (max_gt_num * 2)) + # pad gt to max_num of a batch + bs = len(num_gts) + + input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device) + input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device) + pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device) + + for i in range(bs): + num_gt = num_gts[i] + if num_gt > 0: + input_query_class[i, :num_gt] = targets[i]["labels"] + input_query_bbox[i, :num_gt] = targets[i]["boxes"] + pad_gt_mask[i, :num_gt] = 1 + # each group has positive and negative queries. + input_query_class = input_query_class.tile([1, 2 * num_group]) + input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1]) + pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group]) + # positive and negative mask + negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=device) + negative_gt_mask[:, max_gt_num:] = 1 + negative_gt_mask = negative_gt_mask.tile([1, num_group, 1]) + positive_gt_mask = 1 - negative_gt_mask + # contrastive denoising training positive index + positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask + dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1] + dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts]) + # total denoising queries + num_denoising = int(max_gt_num * 2 * num_group) + + if label_noise_ratio > 0: + mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5) + # randomly put a new one here + new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype) + input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class) + + if box_noise_scale > 0: + known_bbox = torchvision.ops.box_convert(input_query_bbox, in_fmt="cxcywh", out_fmt="xyxy") + diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale + rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 + rand_part = torch.rand_like(input_query_bbox) + rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask) + rand_part *= rand_sign + known_bbox += rand_part * diff + known_bbox.clip_(min=0.0, max=1.0) + input_query_bbox = torchvision.ops.box_convert(known_bbox, in_fmt="xyxy", out_fmt="cxcywh") + input_query_bbox = inverse_sigmoid(input_query_bbox) + + input_query_class = class_embed(input_query_class) + + tgt_size = num_denoising + num_queries + attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device) + # match query cannot see the reconstruction + attn_mask[num_denoising:, :num_denoising] = True + + # reconstruct cannot see each other + for i in range(num_group): + if i == 0: + attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1) : num_denoising] = True + if i == num_group - 1: + attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : 
max_gt_num * i * 2] = True + else: + attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1) : num_denoising] = True + attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * 2 * i] = True + + dn_meta = { + "dn_positive_idx": dn_positive_idx, + "dn_num_group": num_group, + "dn_num_split": [num_denoising, num_queries], + } + + return input_query_class, input_query_bbox, attn_mask, dn_meta diff --git a/library/src/otx/backend/native/models/common/losses/gfocal_loss.py b/library/src/otx/backend/native/models/common/losses/gfocal_loss.py index e036e847127..9767878864b 100644 --- a/library/src/otx/backend/native/models/common/losses/gfocal_loss.py +++ b/library/src/otx/backend/native/models/common/losses/gfocal_loss.py @@ -12,7 +12,7 @@ from functools import partial -import torch.nn.functional as F # noqa: N812 +import torch.nn.functional as f from torch import Tensor, nn from otx.backend.native.models.common.losses.utils import weighted_loss @@ -45,10 +45,10 @@ def quality_focal_loss_tensor_target( raise ValueError(msg) if activated: pred_sigmoid = pred - loss_function = F.binary_cross_entropy + loss_function = f.binary_cross_entropy else: pred_sigmoid = pred.sigmoid() - loss_function = F.binary_cross_entropy_with_logits + loss_function = f.binary_cross_entropy_with_logits scale_factor = pred_sigmoid target = target.type_as(pred) @@ -89,7 +89,7 @@ def quality_focal_loss(pred: Tensor, target: Tensor, beta: float = 2.0) -> Tenso pred_sigmoid = pred.sigmoid() scale_factor = pred_sigmoid zerolabel = scale_factor.new_zeros(pred.shape) - loss = F.binary_cross_entropy_with_logits(pred, zerolabel, reduction="none") * scale_factor.pow( + loss = f.binary_cross_entropy_with_logits(pred, zerolabel, reduction="none") * scale_factor.pow( beta, ) @@ -99,7 +99,7 @@ def quality_focal_loss(pred: Tensor, target: Tensor, beta: float = 2.0) -> Tenso pos_label = label[pos].long() # positives are supervised by bbox quality (IoU) score scale_factor = score[pos] - pred_sigmoid[pos, pos_label] - loss[pos, pos_label] = F.binary_cross_entropy_with_logits( + loss[pos, pos_label] = f.binary_cross_entropy_with_logits( pred[pos, pos_label], score[pos], reduction="none", @@ -134,7 +134,7 @@ def quality_focal_loss_with_prob(pred: Tensor, target: Tensor, beta: float = 2.0 pred_sigmoid = pred scale_factor = pred_sigmoid zerolabel = scale_factor.new_zeros(pred.shape) - loss = F.binary_cross_entropy(pred, zerolabel, reduction="none") * scale_factor.pow(beta) + loss = f.binary_cross_entropy(pred, zerolabel, reduction="none") * scale_factor.pow(beta) # FG cat_id: [0, num_classes -1], BG cat_id: num_classes bg_class_ind = pred.size(1) @@ -142,7 +142,7 @@ def quality_focal_loss_with_prob(pred: Tensor, target: Tensor, beta: float = 2.0 pos_label = label[pos].long() # positives are supervised by bbox quality (IoU) score scale_factor = score[pos] - pred_sigmoid[pos, pos_label] - loss[pos, pos_label] = F.binary_cross_entropy( + loss[pos, pos_label] = f.binary_cross_entropy( pred[pos, pos_label], score[pos], reduction="none", diff --git a/library/src/otx/backend/native/models/common/utils/assigners/hungarian_matcher.py b/library/src/otx/backend/native/models/common/utils/assigners/hungarian_matcher.py index 0d12b305edb..ab6d9f2a2c9 100644 --- a/library/src/otx/backend/native/models/common/utils/assigners/hungarian_matcher.py +++ b/library/src/otx/backend/native/models/common/utils/assigners/hungarian_matcher.py @@ -327,7 +327,6 @@ def forward( # Eliminate infinite values in cost_matrix to avoid the 
error ``ValueError: cost matrix is infeasible`` cost_matrix = torch.minimum(cost_matrix, torch.tensor(1e10)) cost_matrix = torch.maximum(cost_matrix, torch.tensor(-1e10)) - # Perform assignment using the hungarian algorithm in scipy assigned_indices = linear_sum_assignment(cost_matrix.cpu()) indices.append(assigned_indices) diff --git a/library/src/otx/backend/native/models/detection/__init__.py b/library/src/otx/backend/native/models/detection/__init__.py index 5caf5c62460..92dc445ac3a 100644 --- a/library/src/otx/backend/native/models/detection/__init__.py +++ b/library/src/otx/backend/native/models/detection/__init__.py @@ -6,9 +6,10 @@ from .atss import ATSS from .d_fine import DFine from .deim import DEIMDFine +from .deimv2 import DEIMV2 from .rtdetr import RTDETR from .rtmdet import RTMDet from .ssd import SSD from .yolox import YOLOX -__all__ = ["ATSS", "RTDETR", "SSD", "YOLOX", "DEIMDFine", "DFine", "RTMDet"] +__all__ = ["ATSS", "DEIMV2", "RTDETR", "SSD", "YOLOX", "DEIMDFine", "DFine", "RTMDet"] diff --git a/library/src/otx/backend/native/models/detection/backbones/dinov3sta.py b/library/src/otx/backend/native/models/detection/backbones/dinov3sta.py new file mode 100644 index 00000000000..b318eb580d9 --- /dev/null +++ b/library/src/otx/backend/native/models/detection/backbones/dinov3sta.py @@ -0,0 +1,322 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""DINOv3 with Spatial Token Attention (STA) backbone for DEIMv2 model. + +This module provides multi-scale feature extraction by combining DINOv3/ViT-Tiny +semantic features with spatial prior features from a lightweight CNN module. + +Modified from DEIMv2 (https://github.com/Intellindust-AI-Lab/DEIMv2) +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any, ClassVar + +import torch +import torch.distributed as dist +import torch.nn.functional as f +from torch import Tensor, nn + +from otx.backend.native.models.common.backbones.dinov3 import DinoVisionTransformer +from otx.backend.native.models.detection.backbones.vit_tiny import VisionTransformer + +logger = logging.getLogger(__name__) + + +def get_norm_layer(num_features: int, use_sync_bn: bool = True) -> nn.Module: + """Get appropriate normalization layer based on distributed training context. + + Uses SyncBatchNorm for multi-GPU training, regular BatchNorm otherwise. + + Args: + num_features: Number of features for the normalization layer. + use_sync_bn: If True, use SyncBatchNorm when in multi-GPU setting. + + Returns: + BatchNorm2d or SyncBatchNorm based on training context. + """ + if use_sync_bn and dist.is_initialized() and dist.get_world_size() > 1: + return nn.SyncBatchNorm(num_features) + return nn.BatchNorm2d(num_features) + + +class SpatialPriorModulev2(nn.Module): + """Lightweight Spatial Prior Module for extracting multi-scale detail features. + + This module extracts fine-grained spatial details at multiple scales (1/8, 1/16, 1/32) + using a series of convolutional layers. These features are fused with semantic + features from the ViT backbone for improved detection performance. + + Args: + inplanes: Base number of channels for the convolutional layers. Defaults to 16. + use_sync_bn: Whether to use SyncBatchNorm for multi-GPU training. Defaults to True. 
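+
+    Example:
+        Expected output shapes for a 640x640 input with the default ``inplanes=16``
+        (single-process run, so plain BatchNorm is used):
+
+        >>> spm = SpatialPriorModulev2(inplanes=16, use_sync_bn=False)
+        >>> c2, c3, c4 = spm(torch.randn(1, 3, 640, 640))
+        >>> c2.shape, c3.shape, c4.shape
+        (torch.Size([1, 32, 80, 80]), torch.Size([1, 64, 40, 40]), torch.Size([1, 64, 20, 20]))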
+ """ + + def __init__(self, inplanes: int = 16, use_sync_bn: bool = True) -> None: + super().__init__() + + # 1/4 scale stem + self.stem = nn.Sequential( + nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + get_norm_layer(inplanes, use_sync_bn), + nn.GELU(), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + # 1/8 scale + self.conv2 = nn.Sequential( + nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + get_norm_layer(2 * inplanes, use_sync_bn), + ) + # 1/16 scale + self.conv3 = nn.Sequential( + nn.GELU(), + nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + get_norm_layer(4 * inplanes, use_sync_bn), + ) + # 1/32 scale + self.conv4 = nn.Sequential( + nn.GELU(), + nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + get_norm_layer(4 * inplanes, use_sync_bn), + ) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]: + """Extract multi-scale spatial features. + + Args: + x: Input image tensor of shape (B, 3, H, W). + + Returns: + Tuple of three feature tensors at scales 1/8, 1/16, and 1/32: + - c2: Shape (B, 2*inplanes, H/8, W/8) + - c3: Shape (B, 4*inplanes, H/16, W/16) + - c4: Shape (B, 4*inplanes, H/32, W/32) + """ + c1 = self.stem(x) + c2 = self.conv2(c1) # 1/8 + c3 = self.conv3(c2) # 1/16 + c4 = self.conv4(c3) # 1/32 + + return c2, c3, c4 + + +class DINOv3STAsModule(nn.Module): + """DINOv3/ViT backbone with Spatial Token Attention for multi-scale feature extraction. + + Combines semantic features from DINOv3 or ViT-Tiny backbone with spatial prior + features from a lightweight CNN module. Produces multi-scale features suitable + for object detection. + + Args: + name: Model name. Use 'dinov3_*' for DINOv3 variants or other names for ViT-Tiny. + weights_path: Path to pretrained weights. Defaults to None. + interaction_indexes: Layer indices to extract intermediate features from. + Defaults to empty list. + finetune: Whether to finetune the backbone. If False, backbone is frozen. + Defaults to True. + embed_dim: Embedding dimension for ViT-Tiny. Defaults to 192. + num_heads: Number of attention heads for ViT-Tiny. Defaults to 3. + patch_size: Patch size for the ViT backbone. Defaults to 16. + use_sta: Whether to use the Spatial Token Attention module. Defaults to True. + conv_inplane: Base channel number for STA module. Defaults to 16. + hidden_dim: Hidden dimension for output projection. Defaults to embed_dim. + gradient_checkpointing: If True, use gradient checkpointing in backbone + to reduce memory at the cost of increased computation. Defaults to False. 
+ """ + + def __init__( + self, + name: str, + weights_path: str | None = None, + interaction_indexes: list[int] | None = None, + finetune: bool = True, + embed_dim: int = 192, + num_heads: int = 3, + patch_size: int = 16, + use_sta: bool = True, + conv_inplane: int = 16, + hidden_dim: int | None = None, + gradient_checkpointing: bool = False, + ) -> None: + super().__init__() + if interaction_indexes is None: + interaction_indexes = [] + + self.dinov3: DinoVisionTransformer | VisionTransformer + if "dinov3" in name: + self.dinov3 = DinoVisionTransformer(name=name, gradient_checkpointing=gradient_checkpointing) + if weights_path is not None and Path(weights_path).exists(): + logger.info("Loading checkpoint from %s...", weights_path) + self.dinov3.load_state_dict(torch.load(weights_path)) + else: + logger.info("Training DINOv3 from scratch...") + else: + self.dinov3 = VisionTransformer( + embed_dim=embed_dim, + num_heads=num_heads, + return_layers=interaction_indexes, + gradient_checkpointing=gradient_checkpointing, + ) + if weights_path is not None and Path(weights_path).exists(): + logger.info("Loading checkpoint from %s...", weights_path) + self.dinov3._model.load_state_dict(torch.load(weights_path)) # noqa: SLF001 + else: + logger.info("Training ViT-Tiny from scratch...") + + embed_dim = self.dinov3.embed_dim + self.interaction_indexes = interaction_indexes + self.patch_size = patch_size + + if not finetune: + self.dinov3.eval() + self.dinov3.requires_grad_(False) + + # Initialize the spatial prior module for detail features + self.use_sta = use_sta + self.sta: SpatialPriorModulev2 | None = None + if use_sta: + logger.info("Using Lite Spatial Prior Module with inplanes=%d", conv_inplane) + self.sta = SpatialPriorModulev2(inplanes=conv_inplane) + else: + conv_inplane = 0 + + # Linear projection layers for fusing semantic and spatial features + hidden_dim = hidden_dim if hidden_dim is not None else embed_dim + self.convs = nn.ModuleList( + [ + nn.Conv2d(embed_dim + conv_inplane * 2, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False), + nn.Conv2d(embed_dim + conv_inplane * 4, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False), + nn.Conv2d(embed_dim + conv_inplane * 4, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False), + ] + ) + # Normalization layers - use BatchNorm or SyncBatchNorm based on distributed context + use_sync = dist.is_initialized() and dist.get_world_size() > 1 + self.norms = nn.ModuleList( + [ + get_norm_layer(hidden_dim, use_sync), + get_norm_layer(hidden_dim, use_sync), + get_norm_layer(hidden_dim, use_sync), + ] + ) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]: + """Extract multi-scale features from input image. + + Args: + x: Input image tensor of shape (B, C, H, W). 
+ + Returns: + Tuple of three feature tensors at different scales: + - c2: Features at 1/8 scale, shape (B, hidden_dim, H/8, W/8) + - c3: Features at 1/16 scale, shape (B, hidden_dim, H/16, W/16) + - c4: Features at 1/32 scale, shape (B, hidden_dim, H/32, W/32) + """ + h_c, w_c = x.shape[2] // 16, x.shape[3] // 16 + bs = x.shape[0] + + # Extract semantic features from backbone + all_layers: list[tuple[Tensor, Tensor]] + if len(self.interaction_indexes) > 0 and not isinstance(self.dinov3, VisionTransformer): + result = self.dinov3.get_intermediate_layers(x, n=self.interaction_indexes, return_class_token=True) + all_layers = [(out, cls) for out, cls in result] # type: ignore[misc] + else: + all_layers = list(self.dinov3(x)) + + # Repeat single layer for all three scales if needed + if len(all_layers) == 1: + all_layers = [all_layers[0], all_layers[0], all_layers[0]] + + # Process semantic features at multiple scales + sem_feats: list[Tensor] = [] + num_scales = len(all_layers) - 2 + for i, layer_output in enumerate(all_layers): + feat, _ = layer_output + sem_feat = feat.transpose(1, 2).view(bs, -1, h_c, w_c).contiguous() # [B, D, H, W] + resize_h, resize_w = int(h_c * 2 ** (num_scales - i)), int(w_c * 2 ** (num_scales - i)) + sem_feat = f.interpolate(sem_feat, size=[resize_h, resize_w], mode="bilinear", align_corners=False) + sem_feats.append(sem_feat) + + # Fuse semantic and spatial features + fused_feats: list[Tensor] + if self.use_sta and self.sta is not None: + detail_feats = self.sta(x) + fused_feats = [ + torch.cat([sem_feat, detail_feat], dim=1) for sem_feat, detail_feat in zip(sem_feats, detail_feats) + ] + else: + fused_feats = sem_feats + + # Apply projection and normalization + c2 = self.norms[0](self.convs[0](fused_feats[0])) + c3 = self.norms[1](self.convs[1](fused_feats[1])) + c4 = self.norms[2](self.convs[2](fused_feats[2])) + + return c2, c3, c4 + + +class DINOv3STAs(nn.Module): + """Factory class for creating DINOv3/ViT with Spatial Token Attention backbones. + + This class provides predefined configurations for different DEIMv2 model variants. + Use the model_name to select a configuration: + - 'deimv2_x': DINOv3 ViT-S/16+ (largest) + - 'deimv2_l': DINOv3 ViT-S/16 (large) + - 'deimv2_m': ViT-Tiny+ (medium) + - 'deimv2_s': ViT-Tiny (small) + + Example: + >>> backbone = DINOv3STAs("deimv2_s") + >>> features = backbone(images) # Returns (c2, c3, c4) multi-scale features + """ + + backbone_cfg: ClassVar[dict[str, dict[str, Any]]] = { + "deimv2_x": { + "name": "dinov3_vits16plus", + "weights_path": None, + "interaction_indexes": [5, 8, 11], + "conv_inplane": 64, + "hidden_dim": 256, + }, + "deimv2_l": { + "name": "dinov3_vits16", + "weights_path": None, + "interaction_indexes": [5, 8, 11], + "conv_inplane": 32, + "hidden_dim": 224, + }, + "deimv2_m": { + "name": "vit_tinyplus", + "embed_dim": 256, + "weights_path": None, + "interaction_indexes": [3, 7, 11], + "num_heads": 4, + }, + "deimv2_s": { + "name": "vit_tiny", + "embed_dim": 192, + "weights_path": None, + "interaction_indexes": [3, 7, 11], + "num_heads": 3, + }, + } + + def __new__(cls, model_name: str, gradient_checkpointing: bool = False) -> DINOv3STAsModule: + """Create a DINOv3STAs backbone instance. + + Args: + model_name: Name of the model configuration to use. + Must be one of: 'deimv2_x', 'deimv2_l', 'deimv2_m', 'deimv2_s'. + gradient_checkpointing: If True, use gradient checkpointing in backbone. + + Returns: + Configured DINOv3STAsModule backbone instance. 
+ + Raises: + KeyError: If model_name is not in backbone_cfg. + """ + cfg = cls.backbone_cfg[model_name].copy() + cfg["gradient_checkpointing"] = gradient_checkpointing + return DINOv3STAsModule(**cfg) diff --git a/library/src/otx/backend/native/models/detection/backbones/vit_tiny.py b/library/src/otx/backend/native/models/detection/backbones/vit_tiny.py new file mode 100644 index 00000000000..6a3f631094e --- /dev/null +++ b/library/src/otx/backend/native/models/detection/backbones/vit_tiny.py @@ -0,0 +1,423 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Vision Transformer (ViT) Tiny implementation for object detection. + +Modified from DEIMv2 (https://github.com/Intellindust-AI-Lab/DEIMv2) +Modified from DINOv3 (https://github.com/facebookresearch/dinov3) +Modified from https://huggingface.co/spaces/Hila/RobustViT/blob/main/ViT/ViT_new.py +""" + +from __future__ import annotations + +from functools import partial +from typing import Callable + +import torch +from torch import Tensor, nn +from torch.utils.checkpoint import checkpoint as gradient_checkpoint + +from otx.backend.native.models.common.layers.position_embed import RopePositionEmbedding +from otx.backend.native.models.common.layers.transformer_layers import MLP2L as MLP +from otx.backend.native.models.utils.weight_init import trunc_normal_ + + +def rotate_half(x: Tensor) -> Tensor: + """Rotate half the hidden dims of the input for RoPE. + + Splits the last dimension in half and swaps the two halves with negation. + + Args: + x: Input tensor of shape (..., D) where D is even. + + Returns: + Rotated tensor of the same shape. + """ + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rope(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor: + """Apply Rotary Position Embedding (RoPE) to the input tensor. + + Args: + x: Input tensor to apply RoPE to. + sin: Precomputed sine values for position encoding. + cos: Precomputed cosine values for position encoding. + + Returns: + Tensor with RoPE applied. + """ + return (x * cos) + (rotate_half(x) * sin) + + +class SimplifiedPatchEmbed(nn.Module): + """Patch Embedding layer for Vision Transformer. + + Converts an image into a sequence of patch embeddings using a convolutional layer. + + Args: + img_size: Input image size. Defaults to 224. + patch_size: Size of each patch. Defaults to 16. + in_chans: Number of input channels. Defaults to 3. + embed_dim: Embedding dimension. Defaults to 768. + """ + + def __init__( + self, + img_size: int | tuple[int, int] = 224, + patch_size: int | tuple[int, int] = 16, + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + super().__init__() + img_size = (img_size, img_size) if isinstance(img_size, int) else img_size + patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x: Tensor) -> Tensor: + """Convert image to patch embeddings. + + Args: + x: Input image tensor of shape (B, C, H, W). + + Returns: + Patch embeddings of shape (B, num_patches, embed_dim). + """ + return self.proj(x).flatten(2).transpose(1, 2) + + +def drop_path(x: Tensor, drop_prob: float = 0.0, training: bool = False) -> Tensor: + """Drop paths (Stochastic Depth) per sample. 
+ + When applied in main path of residual blocks, this implements stochastic depth. + + Args: + x: Input tensor. + drop_prob: Probability of dropping the path. Defaults to 0.0. + training: Whether the model is in training mode. Defaults to False. + + Returns: + Output tensor with drop path applied during training. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + return x.div(keep_prob) * random_tensor.floor() + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample. + + A module wrapper for the drop_path function. + + Args: + drop_prob: Probability of dropping the path. Defaults to None (0.0). + """ + + def __init__(self, drop_prob: float | None = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: Tensor) -> Tensor: + """Apply drop path to input. + + Args: + x: Input tensor. + + Returns: + Output tensor with drop path applied. + """ + return drop_path(x, self.drop_prob or 0.0, self.training) + + +class Attention(nn.Module): + """Multi-head self-attention module with optional RoPE support. + + Args: + dim: Input dimension. + num_heads: Number of attention heads. Defaults to 8. + qkv_bias: Whether to add bias to QKV projection. Defaults to False. + attn_drop: Attention dropout rate. Defaults to 0.0. + proj_drop: Output projection dropout rate. Defaults to 0.0. + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward( + self, + x: Tensor, + rope_sincos: tuple[Tensor, Tensor] | None = None, + ) -> Tensor: + """Forward pass for multi-head attention. + + Args: + x: Input tensor of shape (B, N, C). + rope_sincos: Optional tuple of (sin, cos) tensors for RoPE. + + Returns: + Output tensor of shape (B, N, C). + """ + B, N, C = x.shape # noqa: N806 + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + if rope_sincos is not None: + sin, cos = rope_sincos + q_cls, q_patch = q[:, :, :1, :], q[:, :, 1:, :] + k_cls, k_patch = k[:, :, :1, :], k[:, :, 1:, :] + + q_patch = apply_rope(q_patch, sin, cos) + k_patch = apply_rope(k_patch, sin, cos) + + q = torch.cat((q_cls, q_patch), dim=2) + k = torch.cat((k_cls, k_patch), dim=2) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop) + x = x.transpose(1, 2).reshape([B, N, C]) + x = self.proj(x) + return self.proj_drop(x) + + +class Block(nn.Module): + """Transformer block with attention and MLP. + + Standard transformer encoder block with pre-normalization. + + Args: + dim: Input dimension. + num_heads: Number of attention heads. + mlp_ratio: Ratio of MLP hidden dim to embedding dim. Defaults to 4.0. + qkv_bias: Whether to add bias to QKV projection. Defaults to False. + drop: Dropout rate for MLP. Defaults to 0.0. + attn_drop: Attention dropout rate. Defaults to 0.0. + drop_path: Drop path rate. Defaults to 0.0. + act_layer: Activation layer class. Defaults to nn.GELU. + norm_layer: Normalization layer class. Defaults to nn.LayerNorm. 
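+
+    Example:
+        Shape-preserving usage without RoPE (``rope_sincos=None``):
+
+        >>> block = Block(dim=192, num_heads=3)
+        >>> block(torch.randn(2, 197, 192)).shape
+        torch.Size([2, 197, 192])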
+ """ + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + attn_drop: float = 0.0, + drop_path: float = 0.0, + drop: float = 0.0, + act_layer: type[nn.Module] = nn.GELU, + norm_layer: type[nn.Module] | Callable[..., nn.Module] = nn.LayerNorm, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = MLP( + in_features=dim, hidden_features=int(dim * mlp_ratio), out_features=dim, act_layer=act_layer, drop=drop + ) + + def forward(self, x: Tensor, rope_sincos: tuple[Tensor, Tensor] | None = None) -> Tensor: + """Forward pass through transformer block. + + Args: + x: Input tensor of shape (B, N, C). + rope_sincos: Optional tuple of (sin, cos) tensors for RoPE. + + Returns: + Output tensor of shape (B, N, C). + """ + attn_output = self.attn(self.norm1(x), rope_sincos=rope_sincos) + x = x + self.drop_path(attn_output) + return x + self.drop_path(self.mlp(self.norm2(x))) + + +class VisionTransformer(nn.Module): + """Vision Transformer (ViT) backbone for object detection. + + A Vision Transformer with Rotary Position Embedding (RoPE) support, + designed for multi-scale feature extraction. + + Args: + img_size: Input image size. Defaults to 224. + patch_size: Size of each patch. Defaults to 16. + in_chans: Number of input channels. Defaults to 3. + embed_dim: Embedding dimension. Defaults to 192. + depth: Number of transformer blocks. Defaults to 12. + num_heads: Number of attention heads. Defaults to 3. + mlp_ratio: Ratio of MLP hidden dim to embedding dim. Defaults to 4.0. + qkv_bias: Whether to add bias to QKV projection. Defaults to True. + drop_rate: Dropout rate. Defaults to 0.0. + attn_drop_rate: Attention dropout rate. Defaults to 0.0. + drop_path_rate: Drop path rate. Defaults to 0.0. + return_layers: List of layer indices to return features from. + Defaults to [3, 7, 11]. + embed_layer: Patch embedding layer class. Defaults to SimplifiedPatchEmbed. + norm_layer: Normalization layer class. Defaults to LayerNorm with eps=1e-6. + act_layer: Activation layer class. Defaults to nn.GELU. + gradient_checkpointing: If True, use gradient checkpointing to reduce memory. + Defaults to False. 
+ """ + + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 192, + depth: int = 12, + num_heads: int = 3, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + return_layers: list[int] | None = None, + embed_layer: type[nn.Module] = SimplifiedPatchEmbed, + norm_layer: type[nn.Module] | Callable[..., nn.Module] | None = None, + act_layer: type[nn.Module] | None = None, + gradient_checkpointing: bool = False, + ) -> None: + super().__init__() + if return_layers is None: + return_layers = [3, 7, 11] + self.num_features = self.embed_dim = embed_dim + self.num_tokens = 1 + self.return_layers = return_layers + self.gradient_checkpointing = gradient_checkpointing + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self._model = nn.Module() + self._model.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim + ) + self.patch_size = patch_size + self._model.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + self._model.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + attn_drop=attn_drop_rate, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ) + for i in range(depth) + ] + ) + + self._model.rope_embed = RopePositionEmbedding( + embed_dim=embed_dim, + num_heads=num_heads, + base=100.0, + normalize_coords="separate", + shift_coords=None, + jitter_coords=None, + rescale_coords=None, + dtype=None, + device=None, + ) + self.init_weights() + + def init_weights(self) -> None: + """Initialize model weights.""" + trunc_normal_(self._model.cls_token, std=0.02) + self._model.rope_embed._init_weights() # noqa: SLF001 + self.apply(self._init_vit_weights) + + def _init_vit_weights(self, m: nn.Module) -> None: + """Initialize weights for ViT layers. + + Args: + m: Module to initialize. + """ + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): + nn.init.zeros_(m.bias) + nn.init.ones_(m.weight) + + @torch.jit.ignore + def no_weight_decay(self) -> set[str]: + """Return set of parameter names that should not use weight decay. + + Returns: + Set of parameter names to exclude from weight decay. + """ + return {"cls_token"} + + def get_model(self) -> nn.Module: + """Get the internal model module. + + Returns: + The internal _model module containing all layers. + """ + return self._model + + def feature_dim(self) -> int: + """Get the feature dimension. + + Returns: + The embedding dimension. + """ + return self.embed_dim + + def forward(self, x: Tensor) -> list[tuple[Tensor, Tensor]]: + """Forward pass through Vision Transformer. + + Args: + x: Input image tensor of shape (B, C, H, W). + + Returns: + List of tuples (patch_features, cls_token) for each return layer. + patch_features has shape (B, num_patches, embed_dim). + cls_token has shape (B, embed_dim). 
+ """ + outs = [] + B, C, H, W = x.shape # noqa: N806 + + x_embed = self._model.patch_embed(x) + cls_token = self._model.cls_token.expand(x_embed.shape[0], -1, -1) + x = torch.cat((cls_token, x_embed), dim=1) + + patch_grid_h = H // self.patch_size + patch_grid_w = W // self.patch_size + rope_sincos = self._model.rope_embed(h=patch_grid_h, w=patch_grid_w) + + for i, blk in enumerate(self._model.blocks): + if self.training and self.gradient_checkpointing: + x = gradient_checkpoint(blk, x, rope_sincos, use_reentrant=False) + else: + x = blk(x, rope_sincos=rope_sincos) + if i in self.return_layers: + outs.append((x[:, 1:], x[:, 0])) + return outs diff --git a/library/src/otx/backend/native/models/detection/deim.py b/library/src/otx/backend/native/models/detection/deim.py index 0bd5a47925e..8f8be529cab 100644 --- a/library/src/otx/backend/native/models/detection/deim.py +++ b/library/src/otx/backend/native/models/detection/deim.py @@ -106,6 +106,7 @@ def _create_model(self, num_classes: int | None = None) -> DETR: decoder = DFINETransformer( model_name=self.model_name, num_classes=num_classes, + eval_spatial_size=self.data_input_params.input_size, ) criterion = DEIMCriterion( weight_dict={ diff --git a/library/src/otx/backend/native/models/detection/deimv2.py b/library/src/otx/backend/native/models/detection/deimv2.py new file mode 100644 index 00000000000..7c44eff924e --- /dev/null +++ b/library/src/otx/backend/native/models/detection/deimv2.py @@ -0,0 +1,165 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""DEIM-DFine model implementations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar, Literal + +from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable +from otx.backend.native.models.detection.backbones.dinov3sta import DINOv3STAs +from otx.backend.native.models.detection.detectors import DETR +from otx.backend.native.models.detection.heads.deim_decoder import DEIMTransformer +from otx.backend.native.models.detection.losses.deim_loss import DEIMCriterion +from otx.backend.native.models.detection.necks.dfine_hybrid_encoder import HybridEncoder +from otx.backend.native.models.utils.utils import load_checkpoint +from otx.config.data import TileConfig +from otx.metrics.fmeasure import MeanAveragePrecisionFMeasureCallable + +from .deim import DEIMDFine + +if TYPE_CHECKING: + from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable + + from otx.backend.native.schedulers import LRSchedulerListCallable + from otx.metrics import MetricCallable + from otx.types.label import LabelInfoTypes + + +class DEIMV2(DEIMDFine): + """OTX Detection model class for DEIMV2. + + DEIMV2 is an improved version of DEIMV1, which introduces DINOV3 backbone and improved decoder. + + It is based on the DEIMV2 paper: https://arxiv.org/abs/2412.04234 + The original implementation is available at: https://github.com/Intellindust-AI-Lab/DEIMv2/tree/main + + The model should be used with + :class:`~otx.backend.native.callbacks.aug_scheduler.DataAugSwitch` and + :class:`~otx.backend.native.callbacks.aug_scheduler.AugmentationSchedulerCallback` + for dynamic augmentation scheduling. + + Attributes: + _pretrained_weights (ClassVar[dict[str, str]]): Dictionary containing URLs for pretrained weights. + input_size_multiplier (int): Multiplier for the input size. + + Args: + label_info (LabelInfoTypes): Information about the labels. 
+ data_input_params (DataInputParams | None): Parameters for the image data preprocessing. + If None, uses _default_preprocessing_params. + model_name (literal, optional): Name of the model to use. Defaults to "deim_dfine_hgnetv2_x". + optimizer (OptimizerCallable, optional): Callable for the optimizer. Defaults to DefaultOptimizerCallable. + scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Callable for the learning rate scheduler. + Defaults to DefaultSchedulerCallable. + metric (MetricCallable, optional): Callable for the metric. Defaults to MeanAveragePrecisionFMeasureCallable. + multi_scale (bool, optional): Whether to use multi-scale training. Defaults to False. + torch_compile (bool, optional): Whether to use torch compile. Defaults to False. + tile_config (TileConfig, optional): Configuration for tiling. Defaults to TileConfig(enable_tiler=False). + gradient_checkpointing (bool, optional): Whether to use gradient checkpointing to reduce memory + at the cost of increased computation. Recommended for large models or limited GPU memory. + Defaults to False. + """ + + _pretrained_weights: ClassVar[dict[str, str]] = { + "deimv2_x": "https://github.com/kprokofi/DEIMv2/releases/download/1.0.0/deimv2_dinov3_x_coco.pth", + "deimv2_l": "https://github.com/kprokofi/DEIMv2/releases/download/1.0.0/deimv2_dinov3_l_coco.pth", + "deimv2_m": "https://github.com/kprokofi/DEIMv2/releases/download/1.0.0/deimv2_dinov3_m_coco.pth", + "deimv2_s": "https://github.com/kprokofi/DEIMv2/releases/download/1.0.0/deimv2_dinov3_s_coco.pth", + } + + input_size_multiplier = 32 + + def __init__( + self, + label_info: LabelInfoTypes, + data_input_params: DataInputParams | None = None, + model_name: Literal[ + "deimv2_x", + "deimv2_l", + "deimv2_m", + "deimv2_s", + ] = "deimv2_x", + optimizer: OptimizerCallable = DefaultOptimizerCallable, + scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + metric: MetricCallable = MeanAveragePrecisionFMeasureCallable, + multi_scale: bool = False, + torch_compile: bool = False, + tile_config: TileConfig = TileConfig(enable_tiler=False), + gradient_checkpointing: bool = False, + ) -> None: + self.gradient_checkpointing = gradient_checkpointing + super().__init__( + model_name=model_name, # type: ignore[arg-type] + label_info=label_info, + data_input_params=data_input_params, + optimizer=optimizer, + scheduler=scheduler, + metric=metric, + torch_compile=torch_compile, + tile_config=tile_config, + multi_scale=multi_scale, + ) + + def _create_model(self, num_classes: int | None = None) -> DETR: + """Create DEIM-DFine model.""" + num_classes = num_classes if num_classes is not None else self.num_classes + backbone = DINOv3STAs(model_name=self.model_name, gradient_checkpointing=self.gradient_checkpointing) + encoder = HybridEncoder(model_name=self.model_name) + decoder = DEIMTransformer( + model_name=self.model_name, + num_classes=num_classes, + eval_spatial_size=self.data_input_params.input_size, + ) + + criterion = DEIMCriterion( + weight_dict={ + "loss_vfl": 1, + "loss_bbox": 5, + "loss_giou": 2, + "loss_fgl": 0.15, + "loss_ddf": 1.5, + "loss_mal": 1.0, + }, + alpha=0.75, + gamma=1.5, + reg_max=32, + num_classes=num_classes, + ) + + backbone_lr_mapping = { + "deimv2_x": 0.00001, + "deimv2_l": 0.0000125, + "deimv2_m": 0.000025, + "deimv2_s": 0.000025, + } + + try: + backbone_lr = backbone_lr_mapping[self.model_name] + except KeyError as err: + msg = f"Unsupported model name: {self.model_name}" + raise ValueError(msg) from err + + 
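+        # The "params" entries below are regex patterns matched against parameter
+        # names when the optimizer parameter groups are built: DINOv3/ViT backbone
+        # weights train with the reduced, model-size-dependent learning rate above,
+        # backbone norm/bias parameters additionally drop weight decay, and norm/bn/bias
+        # parameters in the STA module, encoder, and decoder skip weight decay as well.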
optimizer_configuration = [ + {"params": "^(?=.*.dinov3)(?!.*(?:norm|bn|bias)).*$", "lr": backbone_lr}, + {"params": "^(?=.*.dinov3)(?=.*(?:norm|bn|bias)).*$", "lr": backbone_lr, "weight_decay": 0.0}, + {"params": "^(?=.*(?:sta|encoder|decoder))(?=.*(?:norm|bn|bias)).*$", "weight_decay": 0.0}, + ] + + model = DETR( + multi_scale=self.multi_scale, + backbone=backbone, + encoder=encoder, + decoder=decoder, + criterion=criterion, + num_classes=num_classes, + optimizer_configuration=optimizer_configuration, + input_size=self.data_input_params.input_size[0], + ) + model.init_weights() + load_checkpoint(model, self._pretrained_weights[self.model_name], map_location="cpu") + return model + + @property + def _default_preprocessing_params(self) -> DataInputParams | dict[str, DataInputParams]: + return DataInputParams(input_size=(640, 640), mean=(123.675, 116.280, 103.530), std=(58.395, 57.120, 57.375)) diff --git a/library/src/otx/backend/native/models/detection/heads/deim_decoder.py b/library/src/otx/backend/native/models/detection/heads/deim_decoder.py new file mode 100644 index 00000000000..e5c09631752 --- /dev/null +++ b/library/src/otx/backend/native/models/detection/heads/deim_decoder.py @@ -0,0 +1,1069 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""DEIM Transformer Decoder. + +Modified from DEIMv2 (https://github.com/Intellindust-AI-Lab/DEIMv2) +""" + +from __future__ import annotations + +import copy +from collections import OrderedDict +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, ClassVar + +import torch +import torch.nn.functional as f +from torch import Tensor, nn +from torch.nn import init + +from otx.backend.native.models.common.layers.transformer_layers import ( + LQE, + MLP, + Gate, + Integral, + MSDeformableAttentionV2, + SwiGLUFFN, + get_contrastive_denoising_training_group, +) +from otx.backend.native.models.common.utils.utils import inverse_sigmoid +from otx.backend.native.models.detection.utils.utils import dfine_distance2bbox, dfine_weighting_function +from otx.backend.native.models.modules.norm import RMSNorm +from otx.backend.native.models.utils.weight_init import bias_init_with_prob + +if TYPE_CHECKING: + from torch.nn import ModuleList + +__all__ = ["DEIMTransformer"] + + +class TransformerDecoderLayer(nn.Module): + """Single transformer decoder layer with self-attention, cross-attention, and FFN. + + Args: + d_model: Model dimension. + n_head: Number of attention heads. + dim_feedforward: FFN hidden dimension. + dropout: Dropout rate. + n_levels: Number of feature levels for deformable attention. + n_points: Number of sampling points per level. + layer_scale: Optional scale factor for wide layers. + use_gateway: Whether to use gated fusion for cross-attention. 
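+
+    Example:
+        Construction sketch (illustrative; the values mirror ``DEIMTransformerModule``'s
+        defaults and are not tuned settings)::
+
+            layer = TransformerDecoderLayer(
+                d_model=256,
+                n_head=8,
+                dim_feedforward=2048,
+                n_levels=3,
+                n_points=[3, 6, 3],
+                use_gateway=True,
+            )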
+ """ + + def __init__( + self, + d_model: int = 256, + n_head: int = 8, + dim_feedforward: int = 1024, + dropout: float = 0.0, + n_levels: int = 4, + n_points: int | list[int] = 4, + layer_scale: float | None = None, + use_gateway: bool = False, + ) -> None: + super().__init__() + + if layer_scale is not None: + dim_feedforward = round(layer_scale * dim_feedforward) + d_model = round(layer_scale * d_model) + + # self attention - use memory-efficient scaled_dot_product_attention + self.n_head = n_head + self.head_dim = d_model // n_head + self.qkv_proj = nn.Linear(d_model, 3 * d_model) + self.out_proj = nn.Linear(d_model, d_model) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = RMSNorm(d_model) + + # cross attention + n_points_list = [n_points] * n_levels if isinstance(n_points, int) else n_points + self.cross_attn = MSDeformableAttentionV2(d_model, n_head, n_levels, n_points_list) + self.dropout2 = nn.Dropout(dropout) + + self.use_gateway = use_gateway + if use_gateway: + self.gateway = Gate(d_model, use_rmsnorm=True) + else: + self.norm2 = RMSNorm(d_model) + + # ffn + self.swish_ffn = SwiGLUFFN(d_model, dim_feedforward // 2, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = RMSNorm(d_model) + + def with_pos_embed(self, tensor: Tensor, pos: Tensor | None) -> Tensor: + """Add positional embedding to tensor if provided.""" + return tensor if pos is None else tensor + pos + + def _self_attention( + self, + q: Tensor, + k: Tensor, + v: Tensor, + attn_mask: Tensor | None = None, + ) -> Tensor: + """Memory-efficient self-attention using scaled_dot_product_attention. + + Uses Flash Attention when available (PyTorch 2.0+, CUDA, no mask or causal mask). + + Args: + q: Query tensor of shape (B, N, C). + k: Key tensor of shape (B, N, C). + v: Value tensor of shape (B, N, C). + attn_mask: Optional attention mask of shape (N, N) or (B, N, N). + + Returns: + Attention output of shape (B, N, C). + """ + B, N, C = q.shape # noqa: N806 + + # Project Q, K, V together for efficiency + qkv = self.qkv_proj(q) + qkv = qkv.reshape(B, N, 3, self.n_head, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # Each: (B, n_head, N, head_dim) + + # Convert boolean mask to float mask for scaled_dot_product_attention + # True means "mask out" (don't attend), so we use -inf for those positions + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_mask = attn_mask.float().masked_fill(attn_mask, float("-inf")) + # Expand mask for multi-head attention: (N, N) -> (1, 1, N, N) + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0).unsqueeze(0) + + # Use scaled_dot_product_attention - automatically uses Flash Attention when possible + out = f.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) + + # Reshape back: (B, n_head, N, head_dim) -> (B, N, C) + out = out.transpose(1, 2).reshape(B, N, C) + return self.out_proj(out) + + def forward( + self, + target: Tensor, + reference_points: Tensor, + value: tuple[Tensor, ...], + spatial_shapes: list[list[int]], + attn_mask: Tensor | None = None, + query_pos_embed: Tensor | None = None, + ) -> Tensor: + """Forward pass through decoder layer. + + Args: + target: Query features of shape (B, N, C). + reference_points: Reference points of shape (B, N, 1, 4). + value: Multi-scale value features. + spatial_shapes: Spatial shapes of each feature level. + attn_mask: Optional attention mask. + query_pos_embed: Optional positional embedding for queries. + + Returns: + Updated query features of shape (B, N, C). 
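+
+        Note:
+            A boolean ``attn_mask`` follows the convention ``True == "do not attend"``.
+            ``_self_attention`` converts it into an additive float mask before calling
+            ``F.scaled_dot_product_attention``, roughly::
+
+                attn_mask = attn_mask.float().masked_fill(attn_mask, float("-inf"))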
+ """ + # self attention using memory-efficient scaled_dot_product_attention + q = k = self.with_pos_embed(target, query_pos_embed) + + target2 = self._self_attention(q, k, target, attn_mask=attn_mask) + target = target + self.dropout1(target2) + target = self.norm1(target) + + # cross attention + target2 = self.cross_attn(self.with_pos_embed(target, query_pos_embed), reference_points, value, spatial_shapes) + + if self.use_gateway: + target = self.gateway(target, self.dropout2(target2)) + else: + target = target + self.dropout2(target2) + target = self.norm2(target) + + # ffn + target2 = self.swish_ffn(target) + target = target + self.dropout4(target2) + return self.norm3(target.clamp(min=-65504, max=65504)) + + +class TransformerDecoder(nn.Module): + """Transformer Decoder with Fine-grained Distribution Refinement (FDR). + + Refines object detection predictions through iterative updates across multiple layers, + utilizing attention mechanisms, location quality estimators, and distribution refinement + techniques to improve bounding box accuracy. + + Args: + hidden_dim: Hidden dimension. + decoder_layer: Standard decoder layer. + decoder_layer_wide: Wide decoder layer for later stages. + num_layers: Total number of decoder layers. + num_head: Number of attention heads. + reg_max: Maximum regression bins. + reg_scale: Regression scale factor. + up: Up-sampling parameter. + eval_idx: Index of layer used for evaluation. + layer_scale: Scale factor for wide layers. + act: Activation function class. + """ + + def __init__( + self, + hidden_dim: int, + decoder_layer: TransformerDecoderLayer, + decoder_layer_wide: TransformerDecoderLayer, + num_layers: int, + num_head: int, + reg_max: int, + reg_scale: nn.Parameter, + up: nn.Parameter, + eval_idx: int = -1, + layer_scale: int = 2, + act: Callable[..., nn.Module] = partial(nn.ReLU, inplace=True), + ) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.layer_scale = layer_scale + self.num_head = num_head + self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx + self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max + self.layers = nn.ModuleList( + [copy.deepcopy(decoder_layer) for _ in range(self.eval_idx + 1)] + + [copy.deepcopy(decoder_layer_wide) for _ in range(num_layers - self.eval_idx - 1)] + ) + self.lqe_layers = nn.ModuleList( + [copy.deepcopy(LQE(4, 64, 2, reg_max, activation=act)) for _ in range(num_layers)] + ) + + def value_op( + self, + memory: Tensor, + value_proj: nn.Module | None, + value_scale: int | None, + memory_mask: Tensor | None, + memory_spatial_shapes: list[list[int]], + ) -> tuple[Tensor, ...]: + """Preprocess values for MSDeformableAttention. + + Args: + memory: Encoder memory of shape (B, L, C). + value_proj: Optional projection layer. + value_scale: Optional scale for interpolation. + memory_mask: Optional memory mask. + memory_spatial_shapes: Spatial shapes of each level. + + Returns: + Tuple of value tensors split by level. 
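+
+        Note:
+            Shape sketch with illustrative numbers: for ``memory`` of shape
+            ``(2, 8400, 256)``, ``num_head=8`` and
+            ``memory_spatial_shapes=[[80, 80], [40, 40], [20, 20]]``, the returned
+            tuple holds tensors of shapes ``(2, 8, 32, 6400)``, ``(2, 8, 32, 1600)``
+            and ``(2, 8, 32, 400)``.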
+ """ + value = value_proj(memory) if value_proj is not None else memory + value = f.interpolate(memory, size=value_scale) if value_scale is not None else value + if memory_mask is not None: + value = value * memory_mask.to(value.dtype).unsqueeze(-1) + value = value.reshape(value.shape[0], value.shape[1], self.num_head, -1) + split_shape = [h * w for h, w in memory_spatial_shapes] + return value.permute(0, 2, 3, 1).split(split_shape, dim=-1) + + def convert_to_deploy(self) -> None: + """Convert model for deployment by removing unused layers.""" + self.project = dfine_weighting_function(self.reg_max, self.up, self.reg_scale) + self.layers = self.layers[: self.eval_idx + 1] + self.lqe_layers = nn.ModuleList([nn.Identity()] * self.eval_idx + [self.lqe_layers[self.eval_idx]]) + + def forward( + self, + target: Tensor, + ref_points_unact: Tensor, + memory: Tensor, + spatial_shapes: list[list[int]], + bbox_head: ModuleList, + score_head: ModuleList, + query_pos_head: MLP, + pre_bbox_head: MLP, + integral: Integral, + up: nn.Parameter, + reg_scale: nn.Parameter, + attn_mask: Tensor | None = None, + memory_mask: Tensor | None = None, + dn_meta: dict[str, Any] | None = None, + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Forward pass through decoder. + + Args: + target: Query features of shape (B, N, C). + ref_points_unact: Unactivated reference points of shape (B, N, 4). + memory: Encoder memory of shape (B, L, C). + spatial_shapes: Spatial shapes of each feature level. + bbox_head: Bounding box regression heads. + score_head: Classification heads. + query_pos_head: Query position embedding head. + pre_bbox_head: Pre-bbox head for initial predictions. + integral: Integral layer for distribution regression. + up: Up-sampling parameter. + reg_scale: Regression scale parameter. + attn_mask: Optional attention mask. + memory_mask: Optional memory mask. + dn_meta: Optional denoising metadata. + + Returns: + Tuple of (bboxes, logits, corners, refs, pre_bboxes, pre_scores). 
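+
+        Note:
+            Shape sketch (illustrative, with batch size 2, 300 queries, 6 layers,
+            ``reg_max=32`` and denoising queries left out): during training every
+            decoder layer contributes, so ``bboxes`` is ``(6, 2, 300, 4)`` and
+            ``corners`` is ``(6, 2, 300, 132)``; at inference only the ``eval_idx``
+            layer is kept, so the leading dimension becomes 1.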
+ """ + output = target + output_detach = pred_corners_undetach = 0 + value = self.value_op(memory, None, None, memory_mask, spatial_shapes) + + dec_out_bboxes = [] + dec_out_logits = [] + dec_out_pred_corners = [] + dec_out_refs = [] + if not hasattr(self, "project"): + project = dfine_weighting_function(self.reg_max, up, reg_scale) + else: + project = self.project + + ref_points_detach = f.sigmoid(ref_points_unact) + query_pos_embed = query_pos_head(ref_points_detach).clamp(min=-10, max=10) + + for i, layer in enumerate(self.layers): + ref_points_input = ref_points_detach.unsqueeze(2) + + if i >= self.eval_idx + 1 and self.layer_scale > 1: + query_pos_embed = f.interpolate(query_pos_embed, scale_factor=self.layer_scale) + value = self.value_op(memory, None, query_pos_embed.shape[-1], memory_mask, spatial_shapes) + output = f.interpolate(output, size=query_pos_embed.shape[-1]) + output_detach = output.detach() + + output = layer(output, ref_points_input, value, spatial_shapes, attn_mask, query_pos_embed) + + if i == 0: + # Initial bounding box predictions with inverse sigmoid refinement + pre_bboxes = f.sigmoid(pre_bbox_head(output) + inverse_sigmoid(ref_points_detach)) + pre_scores = score_head[0](output) + ref_points_initial = pre_bboxes.detach() + + # Refine bounding box corners using FDR, integrating previous layer's corrections + pred_corners = bbox_head[i](output + output_detach) + pred_corners_undetach + inter_ref_bbox = dfine_distance2bbox(ref_points_initial, integral(pred_corners, project), reg_scale) + + if self.training or i == self.eval_idx: + scores = score_head[i](output) + # Lqe does not affect the performance here. + scores = self.lqe_layers[i](scores, pred_corners) + dec_out_logits.append(scores) + dec_out_bboxes.append(inter_ref_bbox) + dec_out_pred_corners.append(pred_corners) + dec_out_refs.append(ref_points_initial) + + if not self.training: + break + + pred_corners_undetach = pred_corners + ref_points_detach = inter_ref_bbox.detach() + output_detach = output.detach() + + return ( + torch.stack(dec_out_bboxes), + torch.stack(dec_out_logits), + torch.stack(dec_out_pred_corners), + torch.stack(dec_out_refs), + pre_bboxes, + pre_scores, + ) + + +class DEIMTransformerModule(nn.Module): + """DEIM Transformer module for object detection. + + This module implements the DEIM (Detection Transformer with Efficient + Integration Module) architecture with Fine-grained Distribution Refinement + (FDR) for accurate object detection. + + Attributes: + __share__: List of attributes shared across instances. + hidden_dim: Hidden dimension size. + nhead: Number of attention heads. + feat_strides: Feature strides for each level. + num_levels: Number of feature levels. + num_classes: Number of object classes. + num_queries: Number of detection queries. + eps: Small epsilon for numerical stability. + num_layers: Number of decoder layers. + eval_spatial_size: Spatial size for evaluation. + aux_loss: Whether to use auxiliary losses. + reg_max: Maximum regression value for FDR. 
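+
+    Example:
+        Minimal inference sketch (illustrative only, not taken from the test suite;
+        shapes assume a 640x640 input, strides 8/16/32 and the default 300 queries)::
+
+            module = DEIMTransformerModule(num_classes=80, eval_spatial_size=(640, 640)).eval()
+            feats = [
+                torch.rand(1, 256, 80, 80),
+                torch.rand(1, 256, 40, 40),
+                torch.rand(1, 256, 20, 20),
+            ]
+            out = module(feats)
+            # out["pred_logits"]: (1, 300, 80), out["pred_boxes"]: (1, 300, 4)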
+ """ + + __share__: ClassVar[list[str]] = ["num_classes", "eval_spatial_size"] + + def __init__( # noqa: PLR0913 + self, + num_classes: int = 80, + hidden_dim: int = 256, + num_queries: int = 300, + feat_channels: list[int] | None = None, + feat_strides: list[int] | None = None, + num_levels: int = 3, + num_points: list[int] | None = None, + nhead: int = 8, + num_layers: int = 6, + dim_feedforward: int = 2048, + dropout: float = 0.0, + activation: Callable[..., nn.Module] = nn.SiLU, + num_denoising: int = 100, + label_noise_ratio: float = 0.5, + box_noise_scale: float = 1.0, + learn_query_content: bool = False, + eval_spatial_size: tuple[int, int] | None = None, + eval_idx: int = -1, + eps: float = 1e-2, + aux_loss: bool = True, + cross_attn_method: str = "default", + query_select_method: str = "default", + reg_max: int = 32, + reg_scale: float = 4.0, + layer_scale: int = 1, + use_gateway: bool = True, + share_bbox_head: bool = False, + share_score_head: bool = False, + ) -> None: + """Initialize DEIMTransformerModule. + + Args: + num_classes: Number of object classes. + hidden_dim: Hidden dimension size. + num_queries: Number of detection queries. + feat_channels: Feature channels for each input level. + feat_strides: Feature strides for each level. + num_levels: Number of feature levels. + num_points: Number of sampling points per level. + nhead: Number of attention heads. + num_layers: Number of decoder layers. + dim_feedforward: Feedforward network dimension. + dropout: Dropout rate. + activation: Activation function class. + num_denoising: Number of denoising queries for training. + label_noise_ratio: Label noise ratio for denoising. + box_noise_scale: Box noise scale for denoising. + learn_query_content: Whether to learn query content. + eval_spatial_size: Spatial size for evaluation (H, W). + eval_idx: Evaluation layer index (-1 for last). + eps: Epsilon for numerical stability. + aux_loss: Whether to use auxiliary losses. + cross_attn_method: Cross attention method ('default' or 'discrete'). + query_select_method: Query selection method. + reg_max: Maximum regression value for FDR. + reg_scale: Regression scale factor. + layer_scale: Scale factor for wide layers. + use_gateway: Whether to use gateway fusion. + share_bbox_head: Whether to share bbox head across layers. + share_score_head: Whether to share score head across layers. 
+ """ + super().__init__() + if feat_channels is None: + feat_channels = [256, 256, 256] + if feat_strides is None: + feat_strides = [8, 16, 32] + if num_points is None: + num_points = [3, 6, 3] + if len(feat_channels) > num_levels: + msg = f"feat_channels ({len(feat_channels)}) must be <= num_levels ({num_levels})" + raise ValueError(msg) + if len(feat_strides) != len(feat_channels): + msg = f"feat_strides ({len(feat_strides)}) must match feat_channels ({len(feat_channels)})" + raise ValueError(msg) + + for _ in range(num_levels - len(feat_strides)): + feat_strides.append(feat_strides[-1] * 2) + + self.hidden_dim = hidden_dim + scaled_dim = round(layer_scale * hidden_dim) + self.nhead = nhead + self.feat_strides = feat_strides + self.num_levels = num_levels + self.num_classes = num_classes + self.num_queries = num_queries + self.eps = eps + self.num_layers = num_layers + self.eval_spatial_size = eval_spatial_size + self.aux_loss = aux_loss + self.reg_max = reg_max + + self.cross_attn_method = cross_attn_method + self.query_select_method = query_select_method + + # backbone feature projection + self._build_input_proj_layer(feat_channels) + + # Transformer module + self.up = nn.Parameter(torch.tensor([0.5]), requires_grad=False) + self.reg_scale = nn.Parameter(torch.tensor([reg_scale]), requires_grad=False) + decoder_layer = TransformerDecoderLayer( + hidden_dim, + nhead, + dim_feedforward, + dropout, + num_levels, + num_points, + use_gateway=use_gateway, + ) + decoder_layer_wide = TransformerDecoderLayer( + hidden_dim, + nhead, + dim_feedforward, + dropout, + num_levels, + num_points, + layer_scale=layer_scale, + use_gateway=use_gateway, + ) + self.decoder = TransformerDecoder( + hidden_dim, + decoder_layer, + decoder_layer_wide, + num_layers, + nhead, + reg_max, + self.reg_scale, + self.up, + eval_idx, + layer_scale, + act=partial(activation, inplace=True), + ) + # denoising + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + if num_denoising > 0: + self.denoising_class_embed = nn.Embedding(num_classes + 1, hidden_dim, padding_idx=num_classes) + init.normal_(self.denoising_class_embed.weight[:-1]) + + # decoder embedding + self.learn_query_content = learn_query_content + if learn_query_content: + self.tgt_embed = nn.Embedding(num_queries, hidden_dim) + + if query_select_method == "agnostic": + self.enc_score_head = nn.Linear(hidden_dim, 1) + else: + self.enc_score_head = nn.Linear(hidden_dim, num_classes) + self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, activation=partial(activation, inplace=True)) + + self.query_pos_head = MLP(4, hidden_dim, hidden_dim, 3, activation=partial(activation, inplace=True)) + + # decoder head + self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, activation=partial(activation, inplace=True)) + self.integral = Integral(self.reg_max) + + self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx + dec_score_head = nn.Linear(hidden_dim, num_classes) + self.dec_score_head = nn.ModuleList( + [dec_score_head if share_score_head else copy.deepcopy(dec_score_head) for _ in range(self.eval_idx + 1)] + + [copy.deepcopy(dec_score_head) for _ in range(num_layers - self.eval_idx - 1)] + ) + + # Share the same bbox head for all layers + dec_bbox_head = MLP( + hidden_dim, hidden_dim, 4 * (self.reg_max + 1), 3, activation=partial(activation, inplace=True) + ) + self.dec_bbox_head = nn.ModuleList( + [dec_bbox_head if share_bbox_head else copy.deepcopy(dec_bbox_head) for _ in 
range(self.eval_idx + 1)] + + [ + MLP(scaled_dim, scaled_dim, 4 * (self.reg_max + 1), 3, activation=partial(activation, inplace=True)) + for _ in range(num_layers - self.eval_idx - 1) + ] + ) + + # init encoder output anchors and valid_mask + if self.eval_spatial_size: + anchors, valid_mask = self._generate_anchors() + self.register_buffer("anchors", anchors) + self.register_buffer("valid_mask", valid_mask) + # init encoder output anchors and valid_mask + if self.eval_spatial_size: + self.anchors, self.valid_mask = self._generate_anchors() + + self._reset_parameters(feat_channels) + + def convert_to_deploy(self) -> None: + """Convert model to deployment mode by pruning unused components.""" + self.dec_score_head = nn.ModuleList([nn.Identity()] * (self.eval_idx) + [self.dec_score_head[self.eval_idx]]) + self.dec_bbox_head = nn.ModuleList( + [self.dec_bbox_head[i] if i <= self.eval_idx else nn.Identity() for i in range(len(self.dec_bbox_head))] + ) + + def _reset_parameters(self, feat_channels: list[int]) -> None: + """Reset model parameters with appropriate initialization. + + Args: + feat_channels: List of feature channel dimensions. + """ + bias = bias_init_with_prob(0.01) + init.constant_(self.enc_score_head.bias, bias) + init.constant_(self.enc_bbox_head.layers[-1].weight, 0) + init.constant_(self.enc_bbox_head.layers[-1].bias, 0) + + init.constant_(self.pre_bbox_head.layers[-1].weight, 0) + init.constant_(self.pre_bbox_head.layers[-1].bias, 0) + + for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): + init.constant_(cls_.bias, bias) + if hasattr(reg_, "layers"): + init.constant_(reg_.layers[-1].weight, 0) + init.constant_(reg_.layers[-1].bias, 0) + + if self.learn_query_content: + init.xavier_uniform_(self.tgt_embed.weight) + init.xavier_uniform_(self.query_pos_head.layers[0].weight) + init.xavier_uniform_(self.query_pos_head.layers[1].weight) + init.xavier_uniform_(self.query_pos_head.layers[-1].weight) + for m, in_channels in zip(self.input_proj, feat_channels): + if in_channels != self.hidden_dim: + init.xavier_uniform_(m[0].weight) + + def _build_input_proj_layer(self, feat_channels: list[int]) -> None: + """Build input projection layers for feature transformation. + + Args: + feat_channels: List of input feature channel dimensions. + """ + self.input_proj = nn.ModuleList() + for in_channels in feat_channels: + if in_channels == self.hidden_dim: + self.input_proj.append(nn.Identity()) + else: + self.input_proj.append( + nn.Sequential( + OrderedDict( + [ + ("conv", nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)), + ("norm", nn.BatchNorm2d(self.hidden_dim)), + ] + ) + ) + ) + + in_channels = feat_channels[-1] + + for _ in range(self.num_levels - len(feat_channels)): + if in_channels == self.hidden_dim: + self.input_proj.append(nn.Identity()) + else: + self.input_proj.append( + nn.Sequential( + OrderedDict( + [ + ("conv", nn.Conv2d(in_channels, self.hidden_dim, 3, 2, padding=1, bias=False)), + ("norm", nn.BatchNorm2d(self.hidden_dim)), + ] + ) + ) + ) + in_channels = self.hidden_dim + + def _get_encoder_input(self, feats: list[Tensor]) -> tuple[Tensor, list[list[int]]]: + """Get encoder input from multi-scale features. + + Projects input features to hidden dimension and flattens them + for transformer processing. + + Args: + feats: List of feature tensors from backbone. + + Returns: + Tuple of (flattened features, spatial shapes per level). 
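+
+        Note:
+            Shape sketch (illustrative): for projected features of shapes
+            ``(B, 256, 80, 80)``, ``(B, 256, 40, 40)`` and ``(B, 256, 20, 20)``,
+            the flattened output is ``(B, 8400, 256)`` and
+            ``spatial_shapes == [[80, 80], [40, 40], [20, 20]]``.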
+ """ + # get projection features + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + if self.num_levels > len(proj_feats): + len_srcs = len(proj_feats) + for i in range(len_srcs, self.num_levels): + if i == len_srcs: + proj_feats.append(self.input_proj[i](feats[-1])) + else: + proj_feats.append(self.input_proj[i](proj_feats[-1])) + + # get encoder inputs + feat_flatten = [] + spatial_shapes = [] + for feat in proj_feats: + _, _, h, w = feat.shape + # [b, c, h, w] -> [b, h*w, c] + feat_flatten.append(feat.flatten(2).permute(0, 2, 1)) + # [num_levels, 2] + spatial_shapes.append([h, w]) + + # [b, l, c] + feat_flatten = torch.concat(feat_flatten, 1) + return feat_flatten, spatial_shapes + + def _generate_anchors( + self, + spatial_shapes: list[list[int]] | None = None, + grid_size: float = 0.05, + dtype: torch.dtype = torch.float32, + device: str | torch.device = "cpu", + ) -> tuple[Tensor, Tensor]: + """Generate anchor points for all feature levels. + + Args: + spatial_shapes: Spatial shapes for each level. If None, computed from eval_spatial_size. + grid_size: Base grid size for anchors. + dtype: Data type for anchor tensors. + device: Device to place anchor tensors on. + + Returns: + Tuple of (anchor coordinates, validity mask). + """ + if spatial_shapes is None: + if self.eval_spatial_size is None: + msg = "eval_spatial_size must be set when spatial_shapes is None" + raise ValueError(msg) + spatial_shapes = [] + eval_h, eval_w = self.eval_spatial_size + for s in self.feat_strides: + spatial_shapes.append([int(eval_h / s), int(eval_w / s)]) + + anchor_list: list[Tensor] = [] + for lvl, (h, w) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") + grid_xy = torch.stack([grid_x, grid_y], dim=-1) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=dtype) + wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl) + lvl_anchors = torch.concat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4) + anchor_list.append(lvl_anchors) + + anchors = torch.concat(anchor_list, dim=1).to(device) + valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.inf) + + return anchors, valid_mask + + def _get_decoder_input( + self, + memory: Tensor, + spatial_shapes: list[list[int]], + denoising_logits: Tensor | None = None, + denoising_bbox_unact: Tensor | None = None, + ) -> tuple[Tensor, Tensor, list[Tensor], list[Tensor], Tensor]: + """Prepare input for the decoder. + + Generates anchors, selects top-k queries, and prepares content + embeddings for decoder processing. + + Args: + memory: Encoder memory of shape (B, L, C). + spatial_shapes: Spatial shapes for each feature level. + denoising_logits: Optional denoising logits for training. + denoising_bbox_unact: Optional denoising bbox for training. + + Returns: + Tuple of (content, bbox_unact, topk_bboxes_list, topk_logits_list, enc_logits). 
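+
+        Note:
+            The candidate anchors cover every feature-map cell, e.g.
+            80*80 + 40*40 + 20*20 = 8400 positions for a 640x640 evaluation size
+            with strides 8/16/32; only the top ``num_queries`` of them
+            (300 by default) are kept as decoder queries.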
+ """ + # prepare input for decoder + if self.training or self.eval_spatial_size is None: + anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device) + else: + anchors = self.anchors + valid_mask = self.valid_mask + if memory.shape[0] > 1: + anchors = anchors.repeat(memory.shape[0], 1, 1) + + # memory = torch.where(valid_mask, memory, 0) + memory = valid_mask.to(memory.dtype) * memory + + enc_outputs_logits: Tensor = self.enc_score_head(memory) + + # select topk queries + enc_topk_memory, enc_topk_logits, enc_topk_anchors = self._select_topk( + memory, enc_outputs_logits, anchors, self.num_queries + ) + + enc_topk_bbox_unact: Tensor = self.enc_bbox_head(enc_topk_memory) + enc_topk_anchors + + enc_topk_bboxes_list, enc_topk_logits_list = [], [] + if self.training: + enc_topk_bboxes = f.sigmoid(enc_topk_bbox_unact) + enc_topk_bboxes_list.append(enc_topk_bboxes) + enc_topk_logits_list.append(enc_topk_logits) + + if self.learn_query_content: + content = self.tgt_embed.weight.unsqueeze(0).tile([memory.shape[0], 1, 1]) + else: + content = enc_topk_memory.detach() + + enc_topk_bbox_unact = enc_topk_bbox_unact.detach() + + if denoising_bbox_unact is not None: + enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1) + content = torch.concat([denoising_logits, content], dim=1) + + return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list, enc_outputs_logits + + def _select_topk( + self, + memory: Tensor, + outputs_logits: Tensor, + outputs_anchors_unact: Tensor, + topk: int, + ) -> tuple[Tensor, Tensor | None, Tensor]: + """Select top-k queries based on classification scores. + + Args: + memory: Encoder memory of shape (B, L, C). + outputs_logits: Classification logits of shape (B, L, num_classes). + outputs_anchors_unact: Unactivated anchor coordinates. + topk: Number of top queries to select. + + Returns: + Tuple of (topk_memory, topk_logits, topk_anchors). + """ + topk_ind: Tensor + if self.query_select_method == "default": + _, topk_ind = torch.topk(outputs_logits.max(-1).values, topk, dim=-1) + + elif self.query_select_method == "one2many": + _, topk_ind = torch.topk(outputs_logits.flatten(1), topk, dim=-1) + topk_ind = topk_ind // self.num_classes + + elif self.query_select_method == "agnostic": + _, topk_ind = torch.topk(outputs_logits.squeeze(-1), topk, dim=-1) + + topk_anchors = outputs_anchors_unact.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_anchors_unact.shape[-1]) + ) + + topk_logits = ( + outputs_logits.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_logits.shape[-1])) + if self.training + else None + ) + + topk_memory = memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, memory.shape[-1])) + + return topk_memory, topk_logits, topk_anchors + + def forward( + self, + feats: list[Tensor], + targets: list[dict[str, Any]] | None = None, + explain_mode: bool = False, + ) -> dict[str, Any]: + """Forward pass of the DEIM Transformer module. + + Args: + feats: List of multi-scale feature tensors from backbone. + targets: Optional list of target dictionaries for training. + explain_mode: Whether to include raw logits for explainability. + + Returns: + Dictionary containing predictions and optional auxiliary outputs: + - pred_logits: Classification logits. + - pred_boxes: Predicted bounding boxes. + - pred_corners: Corner predictions (training only). + - ref_points: Reference points (training only). + - aux_outputs: Auxiliary outputs from intermediate layers. 
+ - dn_outputs: Denoising outputs (training only). + """ + # input projection and embedding + memory, spatial_shapes = self._get_encoder_input(feats) + + # prepare denoising training + if self.training and self.num_denoising > 0 and targets is not None: + denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = get_contrastive_denoising_training_group( + targets, + self.num_classes, + self.num_queries, + self.denoising_class_embed, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=1.0, + ) + else: + denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None + + init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list, enc_outputs_logits = ( + self._get_decoder_input(memory, spatial_shapes, denoising_logits, denoising_bbox_unact) + ) + + # decoder + out_bboxes, out_logits, out_corners, out_refs, pre_bboxes, pre_logits = self.decoder( + init_ref_contents, + init_ref_points_unact, + memory, + spatial_shapes, + self.dec_bbox_head, + self.dec_score_head, + self.query_pos_head, + self.pre_bbox_head, + self.integral, + self.up, + self.reg_scale, + attn_mask=attn_mask, + dn_meta=dn_meta, + ) + + out_bboxes = out_bboxes.clamp(min=1e-8) + + if self.training and dn_meta is not None: + # the output from the first decoder layer, only one + dn_pre_logits, pre_logits = torch.split(pre_logits, dn_meta["dn_num_split"], dim=1) + dn_pre_bboxes, pre_bboxes = torch.split(pre_bboxes, dn_meta["dn_num_split"], dim=1) + + dn_out_logits, out_logits = torch.split(out_logits, dn_meta["dn_num_split"], dim=2) + dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta["dn_num_split"], dim=2) + + dn_out_corners, out_corners = torch.split(out_corners, dn_meta["dn_num_split"], dim=2) + dn_out_refs, out_refs = torch.split(out_refs, dn_meta["dn_num_split"], dim=2) + + if self.training: + out = { + "pred_logits": out_logits[-1], + "pred_boxes": out_bboxes[-1], + "pred_corners": out_corners[-1], + "ref_points": out_refs[-1], + "up": self.up, + "reg_scale": self.reg_scale, + } + else: + out = {"pred_logits": out_logits[-1], "pred_boxes": out_bboxes[-1]} + + if self.training and self.aux_loss: + out["aux_outputs"] = self._set_aux_loss2( + out_logits[:-1], out_bboxes[:-1], out_corners[:-1], out_refs[:-1], out_corners[-1], out_logits[-1] + ) + out["enc_aux_outputs"] = self._set_aux_loss(enc_topk_logits_list, enc_topk_bboxes_list) + out["pre_outputs"] = {"pred_logits": pre_logits, "pred_boxes": pre_bboxes} + out["enc_meta"] = {"class_agnostic": self.query_select_method == "agnostic"} + + if dn_meta is not None: + out["dn_outputs"] = self._set_aux_loss2( + dn_out_logits, dn_out_bboxes, dn_out_corners, dn_out_refs, dn_out_corners[-1], dn_out_logits[-1] + ) + out["dn_pre_outputs"] = {"pred_logits": dn_pre_logits, "pred_boxes": dn_pre_bboxes} + out["dn_meta"] = dn_meta + + if explain_mode: + out["raw_logits"] = enc_outputs_logits + + return out + + @torch.jit.unused + def _set_aux_loss( + self, + outputs_class: list[Tensor], + outputs_coord: list[Tensor], + ) -> list[dict[str, Tensor]]: + """Set auxiliary loss outputs for encoder. + + This is a workaround to make torchscript happy, as torchscript + doesn't support dictionary with non-homogeneous values. + + Args: + outputs_class: List of classification outputs. + outputs_coord: List of coordinate outputs. + + Returns: + List of dictionaries with pred_logits and pred_boxes. 
+ """ + return [{"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)] + + @torch.jit.unused + def _set_aux_loss2( + self, + outputs_class: list[Tensor], + outputs_coord: list[Tensor], + outputs_corners: list[Tensor], + outputs_ref: list[Tensor], + teacher_corners: Tensor | None = None, + teacher_logits: Tensor | None = None, + ) -> list[dict[str, Tensor | None]]: + """Set auxiliary loss outputs for decoder with FDR. + + This is a workaround to make torchscript happy, as torchscript + doesn't support dictionary with non-homogeneous values. + + Args: + outputs_class: List of classification outputs. + outputs_coord: List of coordinate outputs. + outputs_corners: List of corner outputs. + outputs_ref: List of reference point outputs. + teacher_corners: Optional teacher corner predictions. + teacher_logits: Optional teacher logits. + + Returns: + List of dictionaries with predictions and teacher outputs. + """ + return [ + { + "pred_logits": a, + "pred_boxes": b, + "pred_corners": c, + "ref_points": d, + "teacher_corners": teacher_corners, + "teacher_logits": teacher_logits, + } + for a, b, c, d in zip(outputs_class, outputs_coord, outputs_corners, outputs_ref) + ] + + +class DEIMTransformer: + """Factory class for creating DEIMTransformerModule instances. + + Provides predefined configurations for different model sizes (x, l, m, s) + with appropriate hidden dimensions, number of layers, and feedforward dimensions. + + Attributes: + decoder_cfg: Dictionary mapping model names to their configurations. + """ + + decoder_cfg: ClassVar[dict[str, Any]] = { + "deimv2_x": { + "num_layers": 6, + "eval_idx": -1, + "feat_channels": [256, 256, 256], + "hidden_dim": 256, + "dim_feedforward": 2048, + }, + "deimv2_l": { + "feat_channels": [224, 224, 224], + "hidden_dim": 224, + "num_layers": 4, + "eval_idx": -1, + "dim_feedforward": 1792, + }, + "deimv2_m": { + "feat_channels": [256, 256, 256], + "hidden_dim": 256, + "dim_feedforward": 512, + "num_layers": 4, + "eval_idx": -1, + }, + "deimv2_s": { + "feat_channels": [192, 192, 192], + "hidden_dim": 192, + "dim_feedforward": 512, + "num_layers": 4, + "eval_idx": -1, + }, + } + + def __new__( + cls, + model_name: str, + num_classes: int, + eval_spatial_size: tuple[int, int] = (640, 640), + ) -> DEIMTransformerModule: + """Create a new DEIMTransformerModule instance. + + Args: + model_name: Name of the model configuration (e.g., 'deimv2_x'). + num_classes: Number of object classes. + eval_spatial_size: Spatial size for evaluation (H, W). + + Returns: + Configured DEIMTransformerModule instance. + + Raises: + KeyError: If model_name is not found in decoder_cfg. 
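+
+        Example:
+            Illustrative usage; the quoted values come from ``decoder_cfg``::
+
+                decoder = DEIMTransformer("deimv2_s", num_classes=20, eval_spatial_size=(640, 640))
+                # -> DEIMTransformerModule with hidden_dim=192, num_layers=4, dim_feedforward=512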
+ """ + cfg = cls.decoder_cfg[model_name] + return DEIMTransformerModule(num_classes=num_classes, eval_spatial_size=eval_spatial_size, **cfg) diff --git a/library/src/otx/backend/native/models/detection/heads/dfine_decoder.py b/library/src/otx/backend/native/models/detection/heads/dfine_decoder.py index 8e6e63cd6ab..9fa90f53c90 100644 --- a/library/src/otx/backend/native/models/detection/heads/dfine_decoder.py +++ b/library/src/otx/backend/native/models/detection/heads/dfine_decoder.py @@ -15,9 +15,15 @@ from torch import Tensor, nn from torch.nn import init -from otx.backend.native.models.common.layers.transformer_layers import MLP, MSDeformableAttentionV2 +from otx.backend.native.models.common.layers.transformer_layers import ( + LQE, + MLP, + Gate, + Integral, + MSDeformableAttentionV2, + get_contrastive_denoising_training_group, +) from otx.backend.native.models.common.utils.utils import inverse_sigmoid -from otx.backend.native.models.detection.heads.rtdetr_decoder import get_contrastive_denoising_training_group from otx.backend.native.models.detection.utils.utils import dfine_distance2bbox, dfine_weighting_function from otx.backend.native.models.utils.weight_init import bias_init_with_prob @@ -137,109 +143,6 @@ def forward( return self.norm3(target.clamp(min=-65504, max=65504)) -class Gate(nn.Module): - """Target Gating Layers. - - Args: - d_model (int): The number of expected features in the input. - """ - - def __init__(self, d_model: int) -> None: - super().__init__() - self.gate = nn.Linear(2 * d_model, 2 * d_model) - bias = bias_init_with_prob(0.5) - init.constant_(self.gate.bias, bias) - init.constant_(self.gate.weight, 0) - self.norm = nn.LayerNorm(d_model) - - def forward(self, x1: Tensor, x2: Tensor) -> Tensor: - """Forward function of the gate. - - Args: - x1 (Tensor): first target input tensor. - x2 (Tensor): second target input tensor. - - Returns: - Tensor: gated target tensor. - """ - gate_input = torch.cat([x1, x2], dim=-1) - gates = torch.sigmoid(self.gate(gate_input)) - gate1, gate2 = gates.chunk(2, dim=-1) - return self.norm(gate1 * x1 + gate2 * x2) - - -class Integral(nn.Module): - """A static layer that calculates integral results from a distribution. - - This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`, - where Pr(n) is the softmax probability vector representing the discrete - distribution, and W(n) is the non-uniform Weighting Function. - - Args: - reg_max (int): Max number of the discrete bins. Default is 32. - It can be adjusted based on the dataset or task requirements. - """ - - def __init__(self, reg_max: int = 32): - super().__init__() - self.reg_max = reg_max - - def forward(self, x: Tensor, box_distance_weight: Tensor) -> Tensor: - """Forward function of the Integral layer.""" - shape = x.shape - x = f.softmax(x.reshape(-1, self.reg_max + 1), dim=1) - x = f.linear(x, box_distance_weight).reshape(-1, 4) - return x.reshape([*list(shape[:-1]), -1]) - - -class LQE(nn.Module): - """Localization Quality Estimation. - - Args: - k (int): number of edge points. - hidden_dim (int): The number of expected features in the input. - num_layers (int): The number of layers in the MLP. - reg_max (int): Max number of the discrete bins. 
- """ - - def __init__( - self, - k: int, - hidden_dim: int, - num_layers: int, - reg_max: int, - ): - super().__init__() - self.k = k - self.reg_max = reg_max - self.reg_conf = MLP( - input_dim=4 * (k + 1), - hidden_dim=hidden_dim, - output_dim=1, - num_layers=num_layers, - activation=partial(nn.ReLU, inplace=True), - ) - init.constant_(self.reg_conf.layers[-1].bias, 0) - init.constant_(self.reg_conf.layers[-1].weight, 0) - - def forward(self, scores: Tensor, pred_corners: Tensor) -> Tensor: - """Forward function of the LQE layer. - - Args: - scores (Tensor): Prediction scores. - pred_corners (Tensor): Predicted bounding box corners. - - Returns: - Tensor: Updated scores. - """ - b, num_pred, _ = pred_corners.size() - prob = f.softmax(pred_corners.reshape(b, num_pred, 4, self.reg_max + 1), dim=-1) - prob_topk, _ = prob.topk(self.k, dim=-1) - stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1) - quality_score = self.reg_conf(stat.reshape(b, num_pred, -1)) - return scores + quality_score - - class TransformerDecoder(nn.Module): """Transformer Decoder implementing Fine-grained Distribution Refinement (FDR). diff --git a/library/src/otx/backend/native/models/detection/heads/rtdetr_decoder.py b/library/src/otx/backend/native/models/detection/heads/rtdetr_decoder.py index d6e308a7aaa..6b7f3574164 100644 --- a/library/src/otx/backend/native/models/detection/heads/rtdetr_decoder.py +++ b/library/src/otx/backend/native/models/detection/heads/rtdetr_decoder.py @@ -11,123 +11,20 @@ from typing import Any, Callable, ClassVar import torch -import torchvision from torch import nn from torch.nn import init -from otx.backend.native.models.common.layers.transformer_layers import MLP, MSDeformableAttention +from otx.backend.native.models.common.layers.transformer_layers import ( + MLP, + MSDeformableAttention, + get_contrastive_denoising_training_group, +) from otx.backend.native.models.common.utils.utils import inverse_sigmoid from otx.backend.native.models.modules.base_module import BaseModule __all__ = ["RTDETRTransformer"] -def get_contrastive_denoising_training_group( - targets: list[dict[str, torch.Tensor]], - num_classes: int, - num_queries: int, - class_embed: torch.nn.Module, - num_denoising: int = 100, - label_noise_ratio: float = 0.5, - box_noise_scale: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]] | tuple[None, None, None, None]: - """Generate contrastive denoising training group. - - Args: - targets (List[Dict[str, torch.Tensor]]): List of target dictionaries. - num_classes (int): Number of classes. - num_queries (int): Number of queries. - class_embed (torch.nn.Module): Class embedding module. - num_denoising (int, optional): Number of denoising queries. Defaults to 100. - label_noise_ratio (float, optional): Ratio of label noise. Defaults to 0.5. - box_noise_scale (float, optional): Scale of box noise. Defaults to 1.0. - - Returns: - Tuple[Tensor,Tensor,Tensor, dict[str, Tensor]] | tuple[None,None,None,None]: - Tuple containing input query class, input query bbox, attention mask, and denoising metadata. 
- """ - num_gts = [len(t["labels"]) for t in targets] - device = targets[0]["labels"].device - - max_gt_num = max(num_gts) - if max_gt_num == 0: - return None, None, None, None - - num_group = num_denoising // max_gt_num - num_group = 1 if num_group == 0 else num_group - # pad gt to max_num of a batch - bs = len(num_gts) - - input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device) - input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device) - pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device) - - for i in range(bs): - num_gt = num_gts[i] - if num_gt > 0: - input_query_class[i, :num_gt] = targets[i]["labels"] - input_query_bbox[i, :num_gt] = targets[i]["boxes"] - pad_gt_mask[i, :num_gt] = 1 - # each group has positive and negative queries. - input_query_class = input_query_class.tile([1, 2 * num_group]) - input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1]) - pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group]) - # positive and negative mask - negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=device) - negative_gt_mask[:, max_gt_num:] = 1 - negative_gt_mask = negative_gt_mask.tile([1, num_group, 1]) - positive_gt_mask = 1 - negative_gt_mask - # contrastive denoising training positive index - positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask - dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1] - dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts]) - # total denoising queries - num_denoising = int(max_gt_num * 2 * num_group) - - if label_noise_ratio > 0: - mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5) - # randomly put a new one here - new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype) - input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class) - - if box_noise_scale > 0: - known_bbox = torchvision.ops.box_convert(input_query_bbox, in_fmt="cxcywh", out_fmt="xyxy") - diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale - rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 - rand_part = torch.rand_like(input_query_bbox) - rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask) - rand_part *= rand_sign - known_bbox += rand_part * diff - known_bbox.clip_(min=0.0, max=1.0) - input_query_bbox = torchvision.ops.box_convert(known_bbox, in_fmt="xyxy", out_fmt="cxcywh") - input_query_bbox = inverse_sigmoid(input_query_bbox) - - input_query_class = class_embed(input_query_class) - - tgt_size = num_denoising + num_queries - attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device) - # match query cannot see the reconstruction - attn_mask[num_denoising:, :num_denoising] = True - - # reconstruct cannot see each other - for i in range(num_group): - if i == 0: - attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1) : num_denoising] = True - if i == num_group - 1: - attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * i * 2] = True - else: - attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1) : num_denoising] = True - attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * 2 * i] = True - - dn_meta = { - "dn_positive_idx": dn_positive_idx, - "dn_num_group": num_group, - "dn_num_split": [num_denoising, num_queries], - } - - return input_query_class, input_query_bbox, attn_mask, dn_meta - 
- class TransformerDecoderLayer(nn.Module): """TransformerDecoderLayer. diff --git a/library/src/otx/backend/native/models/detection/necks/dfine_hybrid_encoder.py b/library/src/otx/backend/native/models/detection/necks/dfine_hybrid_encoder.py index 3a552c3267f..600ce3d131a 100644 --- a/library/src/otx/backend/native/models/detection/necks/dfine_hybrid_encoder.py +++ b/library/src/otx/backend/native/models/detection/necks/dfine_hybrid_encoder.py @@ -1,36 +1,320 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""D-FINE Hybrid Encoder. Modified from D-FINE (https://github.com/Peterande/D-FINE).""" +"""D-FINE Hybrid Encoder. + +Modified from D-FINE (https://github.com/Peterande/D-FINE). +""" from __future__ import annotations import copy from collections import OrderedDict from functools import partial -from typing import Any, Callable, ClassVar +from typing import Any, Callable, ClassVar, Literal import torch import torch.nn.functional as f from torch import Tensor, nn -from otx.backend.native.models.common.layers.transformer_layers import TransformerEncoder, TransformerEncoderLayer +from otx.backend.native.models.common.layers.transformer_layers import ( + TransformerEncoder, + TransformerEncoderLayer, +) from otx.backend.native.models.detection.layers.csp_layer import CSPRepLayer from otx.backend.native.models.detection.utils.utils import auto_pad from otx.backend.native.models.modules.activation import build_activation_layer from otx.backend.native.models.modules.conv_module import Conv2dModule from otx.backend.native.models.modules.norm import build_norm_layer +# ============================================================================= +# Helper Layers +# ============================================================================= + + +class ConvNormLayer(nn.Module): + """Convolution + BatchNorm + Activation layer. + + Args: + ch_in: Input channels. + ch_out: Output channels. + kernel_size: Convolution kernel size. + stride: Convolution stride. + groups: Number of groups for grouped convolution. + padding: Padding size. If None, uses (kernel_size-1)//2. + bias: Whether to use bias in convolution. + act: Activation function name or None. + """ + + def __init__( + self, + ch_in: int, + ch_out: int, + kernel_size: int, + stride: int, + groups: int = 1, + padding: int | None = None, + bias: bool = False, + act: Callable[..., nn.Module] | None = None, + ) -> None: + super().__init__() + padding = (kernel_size - 1) // 2 if padding is None else padding + self.conv = nn.Conv2d( + ch_in, + ch_out, + kernel_size, + stride, + groups=groups, + padding=padding, + bias=bias, + ) + self.norm = nn.BatchNorm2d(ch_out) + self.act = nn.Identity() if act is None else act() + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + return self.act(self.norm(self.conv(x))) + + +class ConvNormLayerFusable(nn.Module): + """Fusable Convolution + BatchNorm + Activation layer. + + Supports fusing Conv and BatchNorm for deployment optimization. + + Args: + ch_in: Input channels. + ch_out: Output channels. + kernel_size: Convolution kernel size. + stride: Convolution stride. + groups: Number of groups for grouped convolution. + padding: Padding size. If None, uses (kernel_size-1)//2. + bias: Whether to use bias in convolution. + act: Activation function class or None. 
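+
+    Example:
+        Fusion sketch (illustrative; the equality holds in eval mode, where BatchNorm
+        uses its running statistics)::
+
+            layer = ConvNormLayerFusable(64, 128, 3, 1, act=nn.SiLU).eval()
+            x = torch.rand(1, 64, 32, 32)
+            y_ref = layer(x)
+            layer.convert_to_deploy()  # folds BatchNorm into the convolution weights
+            assert torch.allclose(layer(x), y_ref, atol=1e-5)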
+ """ + + def __init__( + self, + ch_in: int, + ch_out: int, + kernel_size: int, + stride: int, + groups: int = 1, + padding: int | None = None, + bias: bool = False, + act: Callable[..., nn.Module] | None = None, + ) -> None: + super().__init__() + padding = (kernel_size - 1) // 2 if padding is None else padding + self.conv = nn.Conv2d(ch_in, ch_out, kernel_size, stride, groups=groups, padding=padding, bias=bias) + self.norm = nn.BatchNorm2d(ch_out) + self.act = nn.Identity() if act is None else act() + # Store params for deployment conversion + self._ch_in = ch_in + self._ch_out = ch_out + self._kernel_size = kernel_size + self._stride = stride + self._groups = groups + self._padding = padding + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + if hasattr(self, "conv_bn_fused"): + return self.act(self.conv_bn_fused(x)) + return self.act(self.norm(self.conv(x))) + + def convert_to_deploy(self) -> None: + """Fuse conv and batchnorm for deployment.""" + if not hasattr(self, "conv_bn_fused"): + self.conv_bn_fused = nn.Conv2d( + self._ch_in, + self._ch_out, + self._kernel_size, + self._stride, + groups=self._groups, + padding=self._padding, + bias=True, + ) + kernel, bias = self._get_fused_kernel_bias() + self.conv_bn_fused.weight.data = kernel + self.conv_bn_fused.bias.data = bias + delattr(self, "conv") + delattr(self, "norm") + + def _get_fused_kernel_bias(self) -> tuple[Tensor, Tensor]: + """Get fused kernel and bias from conv and batchnorm.""" + kernel = self.conv.weight + running_mean = self.norm.running_mean + running_var = self.norm.running_var + gamma = self.norm.weight + beta = self.norm.bias + eps = self.norm.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class VGGBlock(nn.Module): + """VGG-style block with parallel 3x3 and 1x1 convolutions. + + Can be converted to a single 3x3 conv for deployment. + + Args: + ch_in: Input channels. + ch_out: Output channels. + act: Activation function class. 
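+
+    Example:
+        Re-parameterization sketch (illustrative; the fused 3x3 conv reproduces the
+        two-branch output in eval mode)::
+
+            block = VGGBlock(64, 64, act=nn.SiLU).eval()
+            x = torch.rand(1, 64, 32, 32)
+            y_ref = block(x)
+            block.convert_to_deploy()  # 3x3 branch + zero-padded 1x1 branch -> single 3x3 conv
+            assert torch.allclose(block(x), y_ref, atol=1e-5)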
+ """ + + def __init__( + self, + ch_in: int, + ch_out: int, + act: Callable[..., nn.Module] = nn.ReLU, + ) -> None: + super().__init__() + self.ch_in = ch_in + self.ch_out = ch_out + self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None) + self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None) + self.act = nn.Identity() if act is None else act() + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + if hasattr(self, "conv"): + return self.act(self.conv(x)) + return self.act(self.conv1(x) + self.conv2(x)) + + def convert_to_deploy(self) -> None: + """Fuse parallel branches into single conv for deployment.""" + if not hasattr(self, "conv"): + self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1) + kernel, bias = self._get_equivalent_kernel_bias() + self.conv.weight.data = kernel + self.conv.bias.data = bias + delattr(self, "conv1") + delattr(self, "conv2") + + def _get_equivalent_kernel_bias(self) -> tuple[Tensor, Tensor]: + """Get equivalent 3x3 kernel and bias.""" + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1) + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2) + return kernel3x3 + f.pad(kernel1x1, [1, 1, 1, 1]), bias3x3 + bias1x1 + + def _fuse_bn_tensor(self, branch: ConvNormLayer) -> tuple[Tensor, Tensor]: + """Fuse batchnorm into conv weights.""" + kernel = branch.conv.weight + running_mean = branch.norm.running_mean + running_var = branch.norm.running_var + gamma = branch.norm.weight + beta = branch.norm.bias + eps = branch.norm.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +# ============================================================================= +# CSP Layers +# ============================================================================= + + +class CSPLayerV2(nn.Module): + """Cross Stage Partial Layer V2. + + Args: + in_channels: Input channels. + out_channels: Output channels. + num_blocks: Number of bottleneck blocks. + expansion: Channel expansion ratio. + bias: Whether to use bias. + act: Activation function class. + bottletype: Bottleneck block type. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: int = 3, + expansion: float = 1.0, + bias: bool = False, + act: Callable[..., nn.Module] = nn.SiLU, + bottletype: type[nn.Module] = VGGBlock, + ) -> None: + super().__init__() + hidden_channels = int(out_channels * expansion) + self.conv1 = ConvNormLayerFusable(in_channels, hidden_channels, 1, 1, bias=bias, act=act) + self.conv2 = ConvNormLayerFusable(in_channels, hidden_channels, 1, 1, bias=bias, act=act) + self.bottlenecks = nn.Sequential( + *[bottletype(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)] + ) + self.conv3: nn.Module = ( + ConvNormLayerFusable(hidden_channels, out_channels, 1, 1, bias=bias, act=act) + if hidden_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x_1 = self.bottlenecks(self.conv1(x)) + x_2 = self.conv2(x) + return self.conv3(x_1 + x_2) + + +class CSPLayer2(nn.Module): + """Cross Stage Partial Layer with chunk-based split. + + Args: + in_channels: Input channels. + out_channels: Output channels. + num_blocks: Number of bottleneck blocks. + expansion: Channel expansion ratio. + bias: Whether to use bias. + act: Activation function class. + bottletype: Bottleneck block type. 
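+
+    Example:
+        Channel-flow sketch (illustrative, for ``in_channels=128``, ``out_channels=128``,
+        ``expansion=0.5``): ``conv1`` expands to 128 channels, the result is chunked into
+        two 64-channel halves, the second half runs through the bottlenecks, and the
+        element-wise sum of both halves is projected back to 128 channels by ``conv3``::
+
+            layer = CSPLayer2(128, 128, num_blocks=3, expansion=0.5)
+            y = layer(torch.rand(1, 128, 40, 40))  # -> (1, 128, 40, 40)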
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: int = 3, + expansion: float = 1.0, + bias: bool = False, + act: Callable[..., nn.Module] = nn.SiLU, + bottletype: type[nn.Module] = VGGBlock, + ) -> None: + super().__init__() + hidden_channels = int(out_channels * expansion) + self.conv1 = ConvNormLayerFusable(in_channels, hidden_channels * 2, 1, 1, bias=bias, act=act) + self.bottlenecks = nn.Sequential( + *[bottletype(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)] + ) + self.conv3: nn.Module = ( + ConvNormLayerFusable(hidden_channels, out_channels, 1, 1, bias=bias, act=act) + if hidden_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + y = list(self.conv1(x).chunk(2, 1)) + return self.conv3(y[0] + self.bottlenecks(y[1])) + + +# ============================================================================= +# Downsampling Modules +# ============================================================================= + class SCDown(nn.Module): - """SCDown downsampling module. + """Spatial-Channel Downsampling module. Args: - c1 (int): Number of channels in the input feature map. - c2 (int): Number of channels produced by the convolution. - k (int): Kernel size of the convolving kernel. - s (int): Stride of the convolution. - normalization (Callable[..., nn.Module] | None): Normalization layer module. + c1: Input channels. + c2: Output channels. + k: Kernel size. + s: Stride. + normalization: Normalization layer builder. """ def __init__( @@ -66,21 +350,43 @@ def forward(self, x: Tensor) -> Tensor: return self.cv2(self.cv1(x)) -class RepNCSPELAN4(nn.Module): - """GELANModule from YOLOv9. +class SCDownFusable(nn.Module): + """Fusable Spatial-Channel Downsampling module. - Note: - Might not be replaceable as layer implementation is very different from GELANModule in YOLOv9. + Args: + c1: Input channels. + c2: Output channels. + k: Kernel size. + s: Stride. + """ + + def __init__(self, c1: int, c2: int, k: int, s: int) -> None: + super().__init__() + self.cv1 = ConvNormLayerFusable(c1, c2, 1, 1) + self.cv2 = ConvNormLayerFusable(c2, c2, k, s, groups=c2) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + return self.cv2(self.cv1(x)) + + +# ============================================================================= +# RepNCSPELAN Blocks +# ============================================================================= + + +class RepNCSPELAN4(nn.Module): + """RepNCSPELAN4 block (GELAN-style). Args: - c1 (int): c1 channel size. Refer to GELAN paper. - c2 (int): c2 channel size. Refer to GELAN paper. - c3 (int): c3 channel size. Refer to GELAN paper. - c4 (int): c4 channel size. Refer to GELAN paper. - n (int, optional): number of blocks. Defaults to 3. - bias (bool, optional): use bias. Defaults to False. - activation (Callable[..., nn.Module] | None, optional): activation function. Defaults to None. - normalization (Callable[..., nn.Module] | None, optional): norm layer. Defaults to None. + c1: Input channels. + c2: Output channels. + c3: Internal channels. + c4: Bottleneck channels. + num_blocks: Number of blocks. + bias: Whether to use bias. + activation: Activation layer builder. + normalization: Normalization layer builder. 
""" def __init__( @@ -106,17 +412,8 @@ def __init__( activation=build_activation_layer(activation), normalization=build_norm_layer(normalization, num_features=c3), ) - self.cv2 = nn.Sequential( - CSPRepLayer( - c3 // 2, - c4, - num_blocks, - 1, - bias=bias, - activation=activation, - normalization=normalization, - ), + CSPRepLayer(c3 // 2, c4, num_blocks, 1, bias=bias, activation=activation, normalization=normalization), Conv2dModule( c4, c4, @@ -128,17 +425,8 @@ def __init__( normalization=build_norm_layer(normalization, num_features=c4), ), ) - self.cv3 = nn.Sequential( - CSPRepLayer( - c4, - c4, - num_blocks, - 1, - bias=bias, - activation=activation, - normalization=normalization, - ), + CSPRepLayer(c4, c4, num_blocks, 1, bias=bias, activation=activation, normalization=normalization), Conv2dModule( c4, c4, @@ -150,7 +438,6 @@ def __init__( normalization=build_norm_layer(normalization, num_features=c4), ), ) - self.cv4 = Conv2dModule( c3 + (2 * c4), c2, @@ -168,61 +455,105 @@ def forward(self, x: Tensor) -> Tensor: return self.cv4(torch.cat(y, 1)) +class RepNCSPELAN5(nn.Module): + """RepNCSPELAN5 block (DEIM-style, fusable implementation). + + Args: + c1: Input channels. + c2: Output channels. + c3: Internal channels. + c4: Bottleneck channels. + num_blocks: Number of blocks. + bias: Whether to use bias. + act: Activation function class. + """ + + def __init__( + self, + c1: int, + c2: int, + c3: int, + c4: int, + num_blocks: int = 3, + bias: bool = False, + act: Callable[..., nn.Module] = nn.SiLU, + ) -> None: + super().__init__() + self.c = c3 // 2 + self.cv1 = ConvNormLayerFusable(c1, c3, 1, 1, bias=bias, act=act) + self.cv2 = nn.Sequential(CSPLayer2(c3 // 2, c4, num_blocks, 1, bias=bias, act=act, bottletype=VGGBlock)) + self.cv3 = nn.Sequential(CSPLayer2(c4, c4, num_blocks, 1, bias=bias, act=act, bottletype=VGGBlock)) + self.cv4 = ConvNormLayerFusable(c3 + (2 * c4), c2, 1, 1, bias=bias, act=act) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + y = list(self.cv1(x).split((self.c, self.c), 1)) + y.extend(m(y[-1]) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + +# ============================================================================= +# Main Hybrid Encoder Module +# ============================================================================= + + class HybridEncoderModule(nn.Module): - """HybridEncoder for DFine. + """Unified Hybrid Encoder for D-FINE and DEIM detection models. - TODO(Eugene): Merge with current rtdetr.HybridEncoderModule in next PR. + Combines transformer encoder with FPN and PAN for multi-scale feature fusion. Args: - in_channels (list[int], optional): List of input channels for each feature map. - Defaults to [512, 1024, 2048]. - feat_strides (list[int], optional): List of stride values for - each feature map. Defaults to [8, 16, 32]. - hidden_dim (int, optional): Hidden dimension size. Defaults to 256. - nhead (int, optional): Number of attention heads in the transformer encoder. - Defaults to 8. - dim_feedforward (int, optional): Dimension of the feedforward network - in the transformer encoder. Defaults to 1024. - dropout (float, optional): Dropout rate. Defaults to 0.0. - enc_activation (Callable[..., nn.Module]): Activation layer module. - Defaults to ``nn.GELU``. - normalization (Callable[..., nn.Module]): Normalization layer module. - Defaults to ``partial(build_norm_layer, nn.BatchNorm2d, layer_name="norm")``. - use_encoder_idx (list[int], optional): List of indices of the encoder to use. - Defaults to [2]. 
- num_encoder_layers (int, optional): Number of layers in the transformer encoder. - Defaults to 1. - pe_temperature (float, optional): Temperature parameter for positional encoding. - Defaults to 10000. - expansion (float, optional): Expansion factor for the CSPRepLayer. - Defaults to 1.0. - depth_mult (float, optional): Depth multiplier for the CSPRepLayer. - Defaults to 1.0. - activation (Callable[..., nn.Module]): Activation layer module. - Defaults to ``nn.SiLU``. - eval_spatial_size (tuple[int, int] | None, optional): Spatial size for - evaluation. Defaults to None. + in_channels: List of input channel sizes for each feature level. + feat_strides: List of stride values for each feature level. + hidden_dim: Hidden dimension for the encoder. + nhead: Number of attention heads. + dim_feedforward: Feedforward dimension in transformer. + dropout: Dropout rate. + enc_activation: Activation for transformer encoder. + use_encoder_idx: Indices of feature levels to apply transformer encoder. + num_encoder_layers: Number of transformer encoder layers. + pe_temperature: Temperature for positional encoding. + expansion: Channel expansion factor. + depth_mult: Depth multiplier for CSP blocks. + activation: Activation function class for FPN/PAN blocks. + normalization: Normalization layer builder. + eval_spatial_size: Spatial size for evaluation (caches positional embeddings). + fuse_op: Feature fusion operation ('cat' or 'sum'). + use_fusable_layers: Whether to use fusable layers (for DEIM models). """ def __init__( self, - in_channels: list[int] = [512, 1024, 2048], # noqa: B006 - feat_strides: list[int] = [8, 16, 32], # noqa: B006 + in_channels: list[int] | None = None, + feat_strides: list[int] | None = None, hidden_dim: int = 256, nhead: int = 8, dim_feedforward: int = 1024, dropout: float = 0.0, enc_activation: Callable[..., nn.Module] = nn.GELU, - normalization: Callable[..., nn.Module] = partial(build_norm_layer, nn.BatchNorm2d, layer_name="norm"), - use_encoder_idx: list[int] = [2], # noqa: B006 + use_encoder_idx: list[int] | None = None, num_encoder_layers: int = 1, - pe_temperature: int = 10000, + pe_temperature: float = 10000.0, expansion: float = 1.0, depth_mult: float = 1.0, activation: Callable[..., nn.Module] = nn.SiLU, + normalization: Callable[..., nn.Module] | None = None, eval_spatial_size: tuple[int, int] | None = None, - ): + fuse_op: Literal["cat", "sum"] = "cat", + use_fusable_layers: bool = False, + ) -> None: super().__init__() + + # Set defaults + if in_channels is None: + in_channels = [512, 1024, 2048] + if feat_strides is None: + feat_strides = [8, 16, 32] + if use_encoder_idx is None: + use_encoder_idx = [2] + if normalization is None: + normalization = partial(build_norm_layer, nn.BatchNorm2d, layer_name="norm") + self.in_channels = in_channels self.feat_strides = feat_strides self.hidden_dim = hidden_dim @@ -230,24 +561,14 @@ def __init__( self.num_encoder_layers = num_encoder_layers self.pe_temperature = pe_temperature self.eval_spatial_size = eval_spatial_size - self.out_channels = [hidden_dim for _ in range(len(in_channels))] + self.fuse_op = fuse_op + self.out_channels = [hidden_dim] * len(in_channels) self.out_strides = feat_strides - # channel projection - self.input_proj = nn.ModuleList() - for in_channel in in_channels: - self.input_proj.append( - nn.Sequential( - OrderedDict( - [ - ("conv", nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False)), - ("norm", nn.BatchNorm2d(hidden_dim)), - ], - ), - ), - ) + # Build input projection + self.input_proj = 
self._build_input_proj(in_channels, hidden_dim, use_fusable_layers) - # encoder transformer + # Build transformer encoder encoder_layer = TransformerEncoderLayer( hidden_dim, nhead=nhead, @@ -255,68 +576,114 @@ def __init__( dropout=dropout, activation=enc_activation, ) - self.encoder = nn.ModuleList( - [TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx))], + [TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx))] + ) + + # Build FPN and PAN + self._build_fpn_pan( + in_channels, hidden_dim, expansion, depth_mult, activation, normalization, fuse_op, use_fusable_layers ) - # top-down fpn + self._reset_parameters() + + def _build_input_proj( + self, + in_channels: list[int], + hidden_dim: int, + use_fusable_layers: bool, + ) -> nn.ModuleList: + """Build input projection layers.""" + input_proj = nn.ModuleList() + for in_channel in in_channels: + if use_fusable_layers and in_channel == hidden_dim: + input_proj.append(nn.Identity()) + else: + input_proj.append( + nn.Sequential( + OrderedDict( + [ + ("conv", nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False)), + ("norm", nn.BatchNorm2d(hidden_dim)), + ] + ) + ) + ) + return input_proj + + def _build_fpn_pan( + self, + in_channels: list[int], + hidden_dim: int, + expansion: float, + depth_mult: float, + activation: Callable[..., nn.Module], + normalization: Callable[..., nn.Module], + fuse_op: str, + use_fusable_layers: bool, + ) -> None: + """Build FPN and PAN layers.""" + num_levels = len(in_channels) + input_dim = hidden_dim if fuse_op == "sum" else hidden_dim * 2 + num_blocks = round(3 * depth_mult) + c4 = round(expansion * hidden_dim // 2) + self.lateral_convs = nn.ModuleList() self.fpn_blocks = nn.ModuleList() - for _ in range(len(in_channels) - 1, 0, -1): - self.lateral_convs.append( - Conv2dModule( - hidden_dim, - hidden_dim, - 1, - 1, - normalization=build_norm_layer(normalization, num_features=hidden_dim), - activation=None, - ), - ) - self.fpn_blocks.append( - RepNCSPELAN4( - hidden_dim * 2, - hidden_dim, - hidden_dim * 2, - round(expansion * hidden_dim // 2), - round(3 * depth_mult), - activation=activation, - normalization=normalization, - ), - ) - - # bottom-up pan self.downsample_convs = nn.ModuleList() self.pan_blocks = nn.ModuleList() - for _ in range(len(in_channels) - 1): - self.downsample_convs.append( - nn.Sequential( - SCDown( + + for _ in range(num_levels - 1): + if use_fusable_layers: + # DEIM-style fusable layers + self.lateral_convs.append(ConvNormLayerFusable(hidden_dim, hidden_dim, 1, 1)) + self.fpn_blocks.append( + RepNCSPELAN5(input_dim, hidden_dim, hidden_dim * 2, c4, num_blocks, act=activation) + ) + self.downsample_convs.append(nn.Sequential(SCDownFusable(hidden_dim, hidden_dim, 3, 2))) + self.pan_blocks.append( + RepNCSPELAN5(input_dim, hidden_dim, hidden_dim * 2, c4, num_blocks, act=activation) + ) + else: + # D-FINE style with OTX layers + self.lateral_convs.append( + Conv2dModule( hidden_dim, hidden_dim, - 3, - 2, + 1, + 1, + normalization=build_norm_layer(normalization, num_features=hidden_dim), + activation=None, + ) + ) + self.fpn_blocks.append( + RepNCSPELAN4( + hidden_dim * 2, + hidden_dim, + hidden_dim * 2, + c4, + num_blocks, + activation=activation, normalization=normalization, - ), - ), - ) - self.pan_blocks.append( - RepNCSPELAN4( - hidden_dim * 2, - hidden_dim, - hidden_dim * 2, - round(expansion * hidden_dim // 2), - round(3 * depth_mult), - activation=activation, - 
normalization=normalization, - ), - ) - - self._reset_parameters() + ) + ) + self.downsample_convs.append( + nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2, normalization=normalization)) + ) + self.pan_blocks.append( + RepNCSPELAN4( + hidden_dim * 2, + hidden_dim, + hidden_dim * 2, + c4, + num_blocks, + activation=activation, + normalization=normalization, + ) + ) def _reset_parameters(self) -> None: - """Reset parameters.""" + """Initialize cached positional embeddings for evaluation.""" if self.eval_spatial_size: for idx in self.use_encoder_idx: stride = self.feat_strides[idx] @@ -335,13 +702,25 @@ def build_2d_sincos_position_embedding( embed_dim: int = 256, temperature: float = 10000.0, ) -> Tensor: - """Build 2D sin-cos position embedding.""" + """Build 2D sinusoidal-cosine position embedding. + + Args: + w: Width of the feature map. + h: Height of the feature map. + embed_dim: Embedding dimension (must be divisible by 4). + temperature: Temperature for positional encoding. + + Returns: + Position embedding tensor of shape (1, h*w, embed_dim). + """ grid_w = torch.arange(int(w), dtype=torch.float32) grid_h = torch.arange(int(h), dtype=torch.float32) grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") + if embed_dim % 4 != 0: msg = "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" raise ValueError(msg) + pos_dim = embed_dim // 4 omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim omega = 1.0 / (temperature**omega) @@ -351,22 +730,31 @@ def build_2d_sincos_position_embedding( return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :] - def forward(self, feats: Tensor) -> list[Tensor]: - """Forward pass.""" + def forward(self, feats: list[Tensor]) -> list[Tensor]: + """Forward pass. + + Args: + feats: List of feature tensors from backbone. + + Returns: + List of fused multi-scale feature tensors. 
+ """ if len(feats) != len(self.in_channels): - msg = f"Input feature size {len(feats)} does not match the number of input channels {len(self.in_channels)}" + msg = f"Input feature size {len(feats)} does not match expected {len(self.in_channels)}" raise ValueError(msg) + + # Project input features proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] - # encoder + # Apply transformer encoder if self.num_encoder_layers > 0: for i, enc_ind in enumerate(self.use_encoder_idx): h, w = proj_feats[enc_ind].shape[2:] - # flatten [B, C, H, W] to [B, HxW, C] src_flatten = proj_feats[enc_ind].flatten(2).permute(0, 2, 1) + if self.training or self.eval_spatial_size is None: pos_embed = self.build_2d_sincos_position_embedding(w, h, self.hidden_dim, self.pe_temperature).to( - src_flatten.device, + src_flatten.device ) else: pos_embed = getattr(self, f"pos_embed{enc_ind}").to(src_flatten.device) @@ -374,32 +762,57 @@ def forward(self, feats: Tensor) -> list[Tensor]: memory = self.encoder[i](src_flatten, pos_embed=pos_embed) proj_feats[enc_ind] = memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous() - # broadcasting and fusion + # Top-down FPN inner_outs = [proj_feats[-1]] for idx in range(len(self.in_channels) - 1, 0, -1): - feat_heigh = inner_outs[0] + feat_high = inner_outs[0] feat_low = proj_feats[idx - 1] - feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_heigh) - inner_outs[0] = feat_heigh - upsample_feat = f.interpolate(feat_heigh, scale_factor=2.0, mode="nearest") - inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](torch.concat([upsample_feat, feat_low], dim=1)) + + feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high) + inner_outs[0] = feat_high + + upsample_feat = f.interpolate(feat_high, scale_factor=2.0, mode="nearest") + + if self.fuse_op == "sum": + fused_feat = upsample_feat + feat_low + else: + fused_feat = torch.concat([upsample_feat, feat_low], dim=1) + + inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](fused_feat) inner_outs.insert(0, inner_out) + # Bottom-up PAN outs = [inner_outs[0]] for idx in range(len(self.in_channels) - 1): feat_low = outs[-1] - feat_height = inner_outs[idx + 1] + feat_high = inner_outs[idx + 1] + downsample_feat = self.downsample_convs[idx](feat_low) - out = self.pan_blocks[idx](torch.concat([downsample_feat, feat_height], dim=1)) + + if self.fuse_op == "sum": + fused_feat = downsample_feat + feat_high + else: + fused_feat = torch.concat([downsample_feat, feat_high], dim=1) + + out = self.pan_blocks[idx](fused_feat) outs.append(out) return outs +# ============================================================================= +# Factory Class +# ============================================================================= + + class HybridEncoder: - """HybridEncoder factory for D-Fine detection.""" + """Factory class for creating HybridEncoder instances. - encoder_cfg: ClassVar[dict[str, Any]] = { + Supports D-FINE (dfine_*), DEIM-DFINE (deim_dfine_*), and DEIMv2 (deimv2_*) models. 
+ """ + + encoder_cfg: ClassVar[dict[str, dict[str, Any]]] = { + # D-FINE models (use concat fusion, OTX layers) "dfine_hgnetv2_n": { "in_channels": [512, 1024], "feat_strides": [16, 32], @@ -428,6 +841,7 @@ class HybridEncoder: "hidden_dim": 384, "dim_feedforward": 2048, }, + # DEIM-DFINE models (use concat fusion, OTX layers) "deim_dfine_hgnetv2_n": { "in_channels": [512, 1024], "feat_strides": [16, 32], @@ -456,11 +870,56 @@ class HybridEncoder: "hidden_dim": 384, "dim_feedforward": 2048, }, + # DEIMv2 models (use sum fusion, fusable layers) + "deimv2_x": { + "in_channels": [256, 256, 256], + "hidden_dim": 256, + "dim_feedforward": 1024, + "expansion": 1.25, + "depth_mult": 1.37, + "fuse_op": "sum", + "use_fusable_layers": True, + }, + "deimv2_l": { + "in_channels": [224, 224, 224], + "hidden_dim": 224, + "dim_feedforward": 896, + "fuse_op": "sum", + "use_fusable_layers": True, + }, + "deimv2_m": { + "in_channels": [256, 256, 256], + "depth_mult": 1.0, + "expansion": 0.67, + "hidden_dim": 256, + "dim_feedforward": 512, + "fuse_op": "sum", + "use_fusable_layers": True, + }, + "deimv2_s": { + "in_channels": [192, 192, 192], + "depth_mult": 0.67, + "expansion": 0.34, + "hidden_dim": 192, + "dim_feedforward": 512, + "fuse_op": "sum", + "use_fusable_layers": True, + }, } def __new__(cls, model_name: str) -> HybridEncoderModule: - """Constructor for HybridEncoder.""" + """Create a HybridEncoder instance. + + Args: + model_name: Model configuration name. + + Returns: + Configured HybridEncoderModule instance. + + Raises: + KeyError: If model_name is not in encoder_cfg. + """ if model_name not in cls.encoder_cfg: - msg = f"model type '{model_name}' is not supported" + msg = f"Model type '{model_name}' is not supported. Available: {list(cls.encoder_cfg.keys())}" raise KeyError(msg) return HybridEncoderModule(**cls.encoder_cfg[model_name]) diff --git a/library/src/otx/backend/native/models/detection/utils/utils.py b/library/src/otx/backend/native/models/detection/utils/utils.py index 42e189ee1c9..12f18550807 100644 --- a/library/src/otx/backend/native/models/detection/utils/utils.py +++ b/library/src/otx/backend/native/models/detection/utils/utils.py @@ -275,7 +275,6 @@ def dfine_weighting_function(reg_max: int, up: Tensor, reg_scale: Tensor) -> Ten reg_scale (Tensor): Controls the curvature of the Weighting Function. Larger values result in flatter weights near the central axis W(reg_max/2)=0 and steeper weights at both ends. - deploy (bool): If True, uses deployment mode settings. Returns: Tensor: Sequence of Weighting Function. diff --git a/library/src/otx/backend/native/models/modules/norm.py b/library/src/otx/backend/native/models/modules/norm.py index 0155c310ba7..c5514308619 100644 --- a/library/src/otx/backend/native/models/modules/norm.py +++ b/library/src/otx/backend/native/models/modules/norm.py @@ -12,12 +12,51 @@ from typing import Any, Callable import torch -from torch import nn +from torch import Tensor, nn from torch.nn import SyncBatchNorm from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.instancenorm import _InstanceNorm +class RMSNorm(nn.Module): + """Root Mean Square Layer Normalization. + + Args: + dim (int): The number of features in the input. + eps (float, optional): A value added for numerical stability. Defaults to 1e-6. 
+ """ + + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__() + self.dim = dim + self.eps = eps + self.scale = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: Tensor) -> Tensor: + """Compute RMS normalization.""" + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass of RMSNorm. + + Args: + x (Tensor): Input tensor. + + Returns: + Tensor: Normalized and scaled tensor. + """ + output = self._norm(x.float()).type_as(x) + return output * self.scale + + def extra_repr(self) -> str: + """Extra representation string.""" + return f"dim={self.dim}, eps={self.eps}" + + def reset_parameters(self) -> None: + """Reset scale parameter to ones.""" + nn.init.constant_(self.scale, 1) + + class FrozenBatchNorm2d(nn.Module): """Copy and modified from https://github.com/facebookresearch/detr/blob/master/models/backbone.py. diff --git a/library/src/otx/backend/native/models/modules/transformer.py b/library/src/otx/backend/native/models/modules/transformer.py index 5436dc3cde7..560922a2cab 100644 --- a/library/src/otx/backend/native/models/modules/transformer.py +++ b/library/src/otx/backend/native/models/modules/transformer.py @@ -243,6 +243,87 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x, out_size +class UnflattenPatchEmbed(nn.Module): + """2D image to patch embedding: (B,C,H,W) -> (B,N,D). + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: int | tuple[int, int] = 224, + patch_size: int | tuple[int, int] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Callable | None = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_hw = img_size if isinstance(img_size, tuple) else (img_size, img_size) + patch_hw = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size) + patch_grid_size = ( + image_hw[0] // patch_hw[0], + image_hw[1] // patch_hw[1], + ) + + self.img_size = image_hw + self.patch_size = patch_hw + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass to embed image patches. + + Args: + x: Input image tensor of shape (B, C, H, W). + + Returns: + Patch embeddings of shape (B, N, D) or (B, H, W, D) if not flattened. + """ + _, _, h, w = x.shape + + x = self.proj(x) # B C H W + h, w = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, h, w, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + """Calculate FLOPs for patch embedding. + + Returns: + Number of floating point operations. 
+ """ + ho, wo = self.patches_resolution + flops = ho * wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += ho * wo * self.embed_dim + return flops + + def reset_parameters(self) -> None: + """Reset projection layer parameters using uniform initialization.""" + k = 1 / (self.in_chans * (self.patch_size[0] ** 2)) + nn.init.uniform_(self.proj.weight, -math.sqrt(k), math.sqrt(k)) + if self.proj.bias is not None: + nn.init.uniform_(self.proj.bias, -math.sqrt(k), math.sqrt(k)) + + class FFN(BaseModule): """Implements feed-forward networks (FFNs) with identity connection. diff --git a/library/src/otx/backend/native/models/utils/utils.py b/library/src/otx/backend/native/models/utils/utils.py index 5f6645d0655..6043d59fdb0 100644 --- a/library/src/otx/backend/native/models/utils/utils.py +++ b/library/src/otx/backend/native/models/utils/utils.py @@ -70,14 +70,30 @@ def load_checkpoint( map_location: str = "cpu", strict: bool = False, prefix: str = "", + key_mapping: dict[str, str] | None = None, ) -> None: - """Load state dict from path of checkpoint and dump to model.""" + """Load state dict from path of checkpoint and dump to model. + + Args: + model: The PyTorch model to load the checkpoint into. + checkpoint: Path or URL to the checkpoint file. + map_location: Device to map tensors to. Defaults to "cpu". + strict: Whether to strictly enforce key matching. Defaults to False. + prefix: Prefix to strip from state dict keys. Defaults to "". + key_mapping: Dictionary mapping old key names to new key names + for remapping pretrained weights with different parameter names. + Example: {"patch_embed.proj": "patch_embed.projection"} will remap + any key containing "patch_embed.proj" (e.g., "patch_embed.proj.weight") to + "patch_embed.projection.weight". + Defaults to None. + """ if Path(checkpoint).exists(): load_checkpoint_to_model( model, torch.load(checkpoint, map_location), strict=strict, prefix=prefix, + key_mapping=key_mapping, ) else: load_checkpoint_to_model( @@ -85,6 +101,7 @@ def load_checkpoint( load_from_http(checkpoint, map_location), strict=strict, prefix=prefix, + key_mapping=key_mapping, ) @@ -203,11 +220,47 @@ def load(module: nn.Module, local_state_dict: dict, prefix: str = "") -> None: warn("\n".join(err_msg), stacklevel=1) +def remap_state_dict_keys( + state_dict: dict[str, Any], + key_mapping: dict[str, str], +) -> dict[str, Any]: + """Remap state dict keys based on provided mapping. + + Args: + state_dict: Original state dictionary. + key_mapping: Dictionary mapping old key names to new key names. + Supports exact matches and prefix matching with wildcards. + Example: {"patch_embed.proj": "patch_embed.projection"} + + Returns: + State dict with remapped keys. + """ + new_state_dict = OrderedDict() + + for key, value in state_dict.items(): + new_key = key + + # Check for exact match first + if key in key_mapping: + new_key = key_mapping[key] + else: + # Check for prefix matches (e.g., "old_prefix" -> "new_prefix") + for old_pattern, new_pattern in key_mapping.items(): + if old_pattern in key: + new_key = key.replace(old_pattern, new_pattern, 1) + break + + new_state_dict[new_key] = value + + return new_state_dict + + def load_checkpoint_to_model( model: nn.Module, checkpoint: dict, strict: bool = False, prefix: str = "", + key_mapping: dict[str, str] | None = None, ) -> None: """Loads a checkpoint dictionary into a PyTorch model. 
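As a quick illustration of the key remapping introduced above (an annotation, not part of the patch): the snippet below mirrors the exact-match-then-substring logic of ``remap_state_dict_keys`` and applies the ``{"patch_embed.proj": "patch_embed.projection"}`` example from the docstring to a few dummy keys; the integer values stand in for real weight tensors.

    from collections import OrderedDict

    def _remap(state_dict: dict, key_mapping: dict[str, str]) -> OrderedDict:
        # Mirror of remap_state_dict_keys: an exact key match wins, otherwise the first
        # substring pattern found is replaced once; unmatched keys pass through unchanged.
        out = OrderedDict()
        for key, value in state_dict.items():
            if key in key_mapping:
                new_key = key_mapping[key]
            else:
                new_key = key
                for old, new in key_mapping.items():
                    if old in key:
                        new_key = key.replace(old, new, 1)
                        break
            out[new_key] = value
        return out

    weights = OrderedDict(
        [("patch_embed.proj.weight", 0), ("patch_embed.proj.bias", 1), ("blocks.0.attn.qkv.weight", 2)],
    )
    print(list(_remap(weights, {"patch_embed.proj": "patch_embed.projection"})))
    # ['patch_embed.projection.weight', 'patch_embed.projection.bias', 'blocks.0.attn.qkv.weight']
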
@@ -218,6 +271,12 @@ def load_checkpoint_to_model( checkpoint (dict): The checkpoint dictionary containing the model's state_dict. strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model's state_dict. Defaults to False. + prefix (str, optional): Prefix to strip from state dict keys. Defaults to "". + key_mapping (dict[str, str] | None, optional): Dictionary mapping old key names to new key names + for remapping pretrained weights with different parameter names. + Example: {"patch_embed.proj": "patch_embed.projection"} will remap + "patch_embed.proj.weight" to "patch_embed.projection.weight". + Defaults to None. Returns: None @@ -234,6 +293,10 @@ def load_checkpoint_to_model( for p, r in [(r"^module\.", ""), (rf"^{prefix}\.", "")]: state_dict = OrderedDict({re.sub(p, r, k): v for k, v in state_dict.items()}) + # Remap keys if mapping is provided + if key_mapping is not None: + state_dict = remap_state_dict_keys(state_dict, key_mapping) + # Keep metadata in state_dict state_dict._metadata = metadata # noqa: SLF001 diff --git a/library/src/otx/models/__init__.py b/library/src/otx/models/__init__.py index 3d5d70ebdb0..3937da970ad 100644 --- a/library/src/otx/models/__init__.py +++ b/library/src/otx/models/__init__.py @@ -5,6 +5,7 @@ from otx.backend.native.models import ( ATSS, + DEIMV2, RTDETR, SSD, YOLOX, @@ -38,6 +39,7 @@ __all__ = [ # detection "ATSS", + "DEIMV2", "RTDETR", "SSD", "YOLOX", diff --git a/library/src/otx/recipe/detection/deim_dfine_m.yaml b/library/src/otx/recipe/detection/deim_dfine_m.yaml index 0a8337cceba..18fb009f875 100644 --- a/library/src/otx/recipe/detection/deim_dfine_m.yaml +++ b/library/src/otx/recipe/detection/deim_dfine_m.yaml @@ -54,6 +54,7 @@ callbacks: - class_path: otx.backend.native.callbacks.aug_scheduler.AugmentationSchedulerCallback init_args: data_aug_switch: + class_path: otx.backend.native.callbacks.aug_scheduler.DataAugSwitch init_args: policy_epochs: [4, 40, 72] policies: diff --git a/library/src/otx/recipe/detection/deimv2_l.yaml b/library/src/otx/recipe/detection/deimv2_l.yaml new file mode 100644 index 00000000000..ba4e185f236 --- /dev/null +++ b/library/src/otx/recipe/detection/deimv2_l.yaml @@ -0,0 +1,268 @@ +task: DETECTION +model: + class_path: otx.backend.native.models.detection.DEIMV2 + init_args: + model_name: deimv2_l + label_info: 80 + multi_scale: false + gradient_checkpointing: false # Set to true to reduce memory usage at cost of ~20% slower training + + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.0005 + betas: [0.9, 0.999] + weight_decay: 0.000125 + + scheduler: + class_path: otx.backend.native.schedulers.LinearWarmupSchedulerCallable + init_args: + num_warmup_steps: 30 + main_scheduler_callable: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + mode: max + factor: 0.5 + patience: 10 + monitor: val/f1-score + +engine: + device: auto + +callback_monitor: val/f1-score + +callbacks: + - class_path: otx.backend.native.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling + init_args: + max_interval: 1 + min_lrschedule_patience: 3 + - class_path: otx.backend.native.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup + init_args: + mode: max + patience: 10 + min_delta: 0.001 + warmup_iters: 50 + warmup_epochs: 10 + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + dirpath: "" # use engine.work_dir + monitor: val/f1-score + mode: max + save_top_k: 1 + save_last: true + auto_insert_metric_name: false + filename: 
"checkpoints/epoch_{epoch:03d}" + - class_path: otx.backend.native.callbacks.aug_scheduler.AugmentationSchedulerCallback + init_args: + data_aug_switch: + class_path: otx.backend.native.callbacks.aug_scheduler.DataAugSwitch + init_args: + policy_epochs: [4, 23, 40] + policies: + no_aug: + to_tv_image: false + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] + keep_ratio: false + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + light_aug: + to_tv_image: true + transforms: + - class_path: torchvision.transforms.v2.ToImage + - class_path: torchvision.transforms.v2.RandomPhotometricDistort + init_args: + p: 0.5 + - class_path: torchvision.transforms.v2.RandomZoomOut + init_args: + fill: 0 + p: 0.5 + - class_path: otx.data.transform_libs.torchvision.RandomIoUCrop + init_args: + probability: 0.8 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] + transform_bbox: true + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + strong_aug_1: + to_tv_image: false + transforms: + - class_path: otx.data.transform_libs.torchvision.CachedMosaic + init_args: + random_pop: true + max_cached_images: 20 + img_scale: [640, 640] + - class_path: otx.data.transform_libs.torchvision.CachedMixUp + init_args: + img_scale: [640, 640] # (H, W) + ratio_range: + - 1.0 + - 1.0 + probability: 0.5 + random_pop: true + max_cached_images: 10 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] # (H, W) + keep_ratio: false + transform_bbox: true + - class_path: otx.data.transform_libs.torchvision.YOLOXHSVRandomAug + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + strong_aug_2: + to_tv_image: true + transforms: + - class_path: torchvision.transforms.v2.ToImage + - class_path: torchvision.transforms.v2.RandomPhotometricDistort + init_args: + p: 0.5 + - class_path: torchvision.transforms.v2.RandomZoomOut # Can't be used when using CachedMosaic + init_args: + fill: 0 + p: 0.5 + side_range: [1.0, 2.0] + - class_path: otx.data.transform_libs.torchvision.RandomIoUCrop # Can't be used when using CachedMosaic + init_args: + probability: 0.8 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: 
otx.data.transform_libs.torchvision.CachedMixUp + init_args: + img_scale: [640, 640] # (H, W) + ratio_range: + - 1.0 + - 1.0 + probability: 0.5 + random_pop: true + max_cached_images: 10 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] + transform_bbox: true + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] +data: ../_base_/data/torchvision_base.yaml +overrides: + callbacks: + - class_path: otx.backend.native.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling + init_args: + max_interval: 1 + min_lrschedule_patience: 3 + - class_path: otx.backend.native.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup + init_args: + warmup_iters: 100 + warmup_epochs: 7 + + data: + input_size: + - 640 + - 640 + task: DETECTION + data_format: coco_instances + train_subset: + batch_size: 8 + num_workers: 4 + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: $(input_size) + keep_ratio: false + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + sampler: + class_path: otx.data.samplers.balanced_sampler.BalancedSampler + + val_subset: + batch_size: 8 + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: $(input_size) + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + test_subset: + batch_size: 8 + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: $(input_size) + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] diff --git a/library/src/otx/recipe/detection/deimv2_m.yaml b/library/src/otx/recipe/detection/deimv2_m.yaml new file mode 100644 index 00000000000..35e36535774 --- /dev/null +++ b/library/src/otx/recipe/detection/deimv2_m.yaml @@ -0,0 +1,272 @@ +task: DETECTION +model: + class_path: otx.backend.native.models.detection.DEIMV2 + init_args: + model_name: deimv2_m + label_info: 80 + multi_scale: false + gradient_checkpointing: true # Set to true to reduce memory usage at cost of ~20% slower training + + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.0004 + betas: [0.9, 0.999] + weight_decay: 0.0001 + + scheduler: + class_path: otx.backend.native.schedulers.LinearWarmupSchedulerCallable + init_args: + num_warmup_steps: 30 + main_scheduler_callable: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + mode: max + factor: 0.5 + patience: 15 + monitor: val/f1-score + +engine: + device: auto + +callback_monitor: val/f1-score + +callbacks: + - class_path: 
otx.backend.native.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling + init_args: + max_interval: 1 + min_lrschedule_patience: 3 + - class_path: otx.backend.native.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup + init_args: + mode: max + patience: 10 + min_delta: 0.001 + warmup_iters: 50 + warmup_epochs: 10 + - class_path: otx.backend.native.callbacks.cuda_cache_cleaner.CUDACacheCleaner + init_args: + clean_on_validation_end: true + clean_on_epoch_end: true + log_memory: false + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + dirpath: "" # use engine.work_dir + monitor: val/f1-score + mode: max + save_top_k: 1 + save_last: true + auto_insert_metric_name: false + filename: "checkpoints/epoch_{epoch:03d}" + - class_path: otx.backend.native.callbacks.aug_scheduler.AugmentationSchedulerCallback + init_args: + data_aug_switch: + class_path: otx.backend.native.callbacks.aug_scheduler.DataAugSwitch + init_args: + policy_epochs: [4, 40, 72] + policies: + no_aug: + to_tv_image: false + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] + keep_ratio: false + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + light_aug: + to_tv_image: true + transforms: + - class_path: torchvision.transforms.v2.ToImage + - class_path: torchvision.transforms.v2.RandomPhotometricDistort + init_args: + p: 0.5 + - class_path: torchvision.transforms.v2.RandomZoomOut + init_args: + fill: 0 + p: 0.5 + - class_path: otx.data.transform_libs.torchvision.RandomIoUCrop + init_args: + probability: 0.8 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] + transform_bbox: true + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + strong_aug_1: + to_tv_image: false + transforms: + - class_path: otx.data.transform_libs.torchvision.CachedMosaic + init_args: + random_pop: true + max_cached_images: 20 + img_scale: [640, 640] + - class_path: otx.data.transform_libs.torchvision.CachedMixUp + init_args: + img_scale: [640, 640] # (H, W) + ratio_range: + - 1.0 + - 1.0 + probability: 0.5 + random_pop: true + max_cached_images: 10 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] # (H, W) + keep_ratio: false + transform_bbox: true + - class_path: otx.data.transform_libs.torchvision.YOLOXHSVRandomAug + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 
103.530] + std: [58.395, 57.120, 57.375] + strong_aug_2: + to_tv_image: true + transforms: + - class_path: torchvision.transforms.v2.ToImage + - class_path: torchvision.transforms.v2.RandomPhotometricDistort + init_args: + p: 0.5 + - class_path: torchvision.transforms.v2.RandomZoomOut # Can't be used when using CachedMosaic + init_args: + fill: 0 + p: 0.5 + - class_path: otx.data.transform_libs.torchvision.RandomIoUCrop # Can't be used when using CachedMosaic + init_args: + probability: 0.8 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.CachedMixUp + init_args: + img_scale: [640, 640] # (H, W) + ratio_range: + - 1.0 + - 1.0 + probability: 0.5 + random_pop: true + max_cached_images: 10 + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] + transform_bbox: true + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] +data: ../_base_/data/torchvision_base.yaml +overrides: + callbacks: + - class_path: otx.backend.native.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling + init_args: + max_interval: 1 + min_lrschedule_patience: 3 + - class_path: otx.backend.native.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup + init_args: + warmup_iters: 100 + warmup_epochs: 7 + + data: + input_size: + - 640 + - 640 + task: DETECTION + data_format: coco_instances + train_subset: + batch_size: 8 + num_workers: 4 + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: $(input_size) + keep_ratio: false + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + sampler: + class_path: otx.data.samplers.balanced_sampler.BalancedSampler + + val_subset: + batch_size: 8 + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: $(input_size) + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + test_subset: + batch_size: 8 + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: $(input_size) + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] diff --git a/library/src/otx/recipe/detection/deimv2_s.yaml b/library/src/otx/recipe/detection/deimv2_s.yaml new file mode 100644 index 00000000000..3f00221a5b9 --- /dev/null +++ b/library/src/otx/recipe/detection/deimv2_s.yaml @@ -0,0 +1,267 @@ +task: DETECTION +model: + class_path: 
otx.backend.native.models.detection.DEIMV2 + init_args: + model_name: deimv2_s + label_info: 80 + multi_scale: false + gradient_checkpointing: false # Set to true to reduce memory usage at cost of ~20% slower training + + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.0005 + betas: [0.9, 0.999] + weight_decay: 0.0001 + + scheduler: + class_path: otx.backend.native.schedulers.LinearWarmupSchedulerCallable + init_args: + num_warmup_steps: 30 + main_scheduler_callable: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + mode: max + factor: 0.5 + patience: 6 + monitor: val/f1-score + +engine: + device: auto + +callback_monitor: val/f1-score + +callbacks: + - class_path: otx.backend.native.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling + init_args: + max_interval: 1 + min_lrschedule_patience: 3 + - class_path: otx.backend.native.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup + init_args: + mode: max + patience: 10 + min_delta: 0.001 + warmup_iters: 50 + warmup_epochs: 10 + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + dirpath: "" # use engine.work_dir + monitor: val/f1-score + mode: max + save_top_k: 1 + save_last: true + auto_insert_metric_name: false + filename: "checkpoints/epoch_{epoch:03d}" + - class_path: otx.backend.native.callbacks.aug_scheduler.AugmentationSchedulerCallback + init_args: + data_aug_switch: + class_path: otx.backend.native.callbacks.aug_scheduler.DataAugSwitch + init_args: + policy_epochs: [4, 40, 70] + policies: + no_aug: + to_tv_image: false + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] + keep_ratio: false + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + light_aug: + to_tv_image: true + transforms: + - class_path: torchvision.transforms.v2.ToImage + - class_path: torchvision.transforms.v2.RandomPhotometricDistort + init_args: + p: 0.5 + - class_path: torchvision.transforms.v2.RandomZoomOut + init_args: + fill: 0 + p: 0.5 + - class_path: otx.data.transform_libs.torchvision.RandomIoUCrop + init_args: + probability: 0.8 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] + transform_bbox: true + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + strong_aug_1: + to_tv_image: false + transforms: + - class_path: otx.data.transform_libs.torchvision.CachedMosaic + init_args: + random_pop: true + max_cached_images: 20 + img_scale: [640, 640] + - class_path: otx.data.transform_libs.torchvision.CachedMixUp + init_args: + img_scale: [640, 640] # (H, W) + ratio_range: + - 1.0 + - 1.0 + probability: 0.5 + random_pop: true + max_cached_images: 10 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + 
init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] # (H, W) + keep_ratio: false + transform_bbox: true + - class_path: otx.data.transform_libs.torchvision.YOLOXHSVRandomAug + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + strong_aug_2: + to_tv_image: true + transforms: + - class_path: torchvision.transforms.v2.ToImage + - class_path: torchvision.transforms.v2.RandomPhotometricDistort + init_args: + p: 0.5 + - class_path: torchvision.transforms.v2.RandomZoomOut # Can't be used when using CachedMosaic + init_args: + fill: 0 + p: 0.5 + - class_path: otx.data.transform_libs.torchvision.RandomIoUCrop # Can't be used when using CachedMosaic + init_args: + probability: 0.8 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.CachedMixUp + init_args: + img_scale: [640, 640] # (H, W) + ratio_range: + - 1.0 + - 1.0 + probability: 0.5 + random_pop: true + max_cached_images: 10 + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] + transform_bbox: true + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] +data: ../_base_/data/torchvision_base.yaml +overrides: + callbacks: + - class_path: otx.backend.native.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling + init_args: + max_interval: 1 + min_lrschedule_patience: 3 + - class_path: otx.backend.native.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup + init_args: + warmup_iters: 100 + warmup_epochs: 7 + + data: + input_size: + - 640 + - 640 + task: DETECTION + data_format: coco_instances + train_subset: + batch_size: 8 + num_workers: 4 + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: $(input_size) + keep_ratio: false + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + sampler: + class_path: otx.data.samplers.balanced_sampler.BalancedSampler + + val_subset: + batch_size: 8 + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: $(input_size) + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + test_subset: + batch_size: 8 + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: $(input_size) + keep_ratio: false + - class_path: 
torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] diff --git a/library/src/otx/recipe/detection/deimv2_x.yaml b/library/src/otx/recipe/detection/deimv2_x.yaml new file mode 100644 index 00000000000..2d722a41303 --- /dev/null +++ b/library/src/otx/recipe/detection/deimv2_x.yaml @@ -0,0 +1,272 @@ +task: DETECTION +model: + class_path: otx.backend.native.models.detection.DEIMV2 + init_args: + model_name: deimv2_x + label_info: 80 + multi_scale: false + gradient_checkpointing: false # Set to true to reduce memory usage at cost of ~20% slower training + + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.0005 + betas: [0.9, 0.999] + weight_decay: 0.000125 + + scheduler: + class_path: otx.backend.native.schedulers.LinearWarmupSchedulerCallable + init_args: + num_warmup_steps: 30 + main_scheduler_callable: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + mode: max + factor: 0.5 + patience: 6 + monitor: val/f1-score + +engine: + device: auto + +callback_monitor: val/f1-score + +callbacks: + - class_path: otx.backend.native.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling + init_args: + max_interval: 1 + min_lrschedule_patience: 3 + - class_path: otx.backend.native.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup + init_args: + mode: max + patience: 10 + min_delta: 0.001 + warmup_iters: 50 + warmup_epochs: 10 + - class_path: otx.backend.native.callbacks.cuda_cache_cleaner.CUDACacheCleaner + init_args: + clean_on_validation_end: true + clean_on_epoch_end: true + log_memory: false + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + dirpath: "" # use engine.work_dir + monitor: val/f1-score + mode: max + save_top_k: 1 + save_last: true + auto_insert_metric_name: false + filename: "checkpoints/epoch_{epoch:03d}" + - class_path: otx.backend.native.callbacks.aug_scheduler.AugmentationSchedulerCallback + init_args: + data_aug_switch: + class_path: otx.backend.native.callbacks.aug_scheduler.DataAugSwitch + init_args: + policy_epochs: [0, 23, 40] + policies: + no_aug: + to_tv_image: false + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] + keep_ratio: false + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + light_aug: + to_tv_image: true + transforms: + - class_path: torchvision.transforms.v2.ToImage + - class_path: torchvision.transforms.v2.RandomPhotometricDistort + init_args: + p: 0.5 + - class_path: torchvision.transforms.v2.RandomZoomOut + init_args: + fill: 0 + p: 0.5 + - class_path: otx.data.transform_libs.torchvision.RandomIoUCrop + init_args: + probability: 0.8 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] + transform_bbox: true + keep_ratio: false + - class_path: 
torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + strong_aug_1: + to_tv_image: false + transforms: + - class_path: otx.data.transform_libs.torchvision.CachedMosaic + init_args: + random_pop: true + max_cached_images: 20 + img_scale: [640, 640] + - class_path: otx.data.transform_libs.torchvision.CachedMixUp + init_args: + img_scale: [640, 640] # (H, W) + ratio_range: + - 1.0 + - 1.0 + probability: 0.5 + random_pop: true + max_cached_images: 10 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] # (H, W) + keep_ratio: false + transform_bbox: true + - class_path: otx.data.transform_libs.torchvision.YOLOXHSVRandomAug + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + strong_aug_2: + to_tv_image: true + transforms: + - class_path: torchvision.transforms.v2.ToImage + - class_path: torchvision.transforms.v2.RandomPhotometricDistort + init_args: + p: 0.5 + - class_path: torchvision.transforms.v2.RandomZoomOut # Can't be used when using CachedMosaic + init_args: + fill: 0 + p: 0.5 + - class_path: otx.data.transform_libs.torchvision.RandomIoUCrop # Can't be used when using CachedMosaic + init_args: + probability: 0.8 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.CachedMixUp + init_args: + img_scale: [640, 640] # (H, W) + ratio_range: + - 1.0 + - 1.0 + probability: 0.5 + random_pop: true + max_cached_images: 10 + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.SanitizeBoundingBoxes + init_args: + min_area: 1 + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: [640, 640] + transform_bbox: true + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] +data: ../_base_/data/torchvision_base.yaml +overrides: + callbacks: + - class_path: otx.backend.native.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling + init_args: + max_interval: 1 + min_lrschedule_patience: 3 + - class_path: otx.backend.native.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup + init_args: + warmup_iters: 100 + warmup_epochs: 7 + + data: + input_size: + - 640 + - 640 + task: DETECTION + data_format: coco_instances + train_subset: + batch_size: 8 + num_workers: 4 + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: $(input_size) + keep_ratio: false + - class_path: otx.data.transform_libs.torchvision.RandomFlip + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 
103.530] + std: [58.395, 57.120, 57.375] + sampler: + class_path: otx.data.samplers.balanced_sampler.BalancedSampler + + val_subset: + batch_size: 8 + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: $(input_size) + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] + test_subset: + batch_size: 8 + transforms: + - class_path: otx.data.transform_libs.torchvision.Resize + init_args: + scale: $(input_size) + keep_ratio: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.280, 103.530] + std: [58.395, 57.120, 57.375] diff --git a/library/tests/perf_v2/benchmark.py b/library/tests/perf_v2/benchmark.py index 790ee094b53..504dda4aedc 100644 --- a/library/tests/perf_v2/benchmark.py +++ b/library/tests/perf_v2/benchmark.py @@ -97,7 +97,6 @@ def __init__( num_epoch: int = 0, eval_upto: str = "train", tags: dict[str, str] | None = None, - dry_run: bool = False, deterministic: bool = False, accelerator: str = "gpu", reference_results: pd.DataFrame | None = None, @@ -109,7 +108,6 @@ def __init__( self.num_epoch = num_epoch self.eval_upto = eval_upto self.tags = tags or {} - self.dry_run = dry_run self.deterministic = deterministic self.accelerator = accelerator self.reference_results = reference_results @@ -127,6 +125,7 @@ def train( dataset_info: DatasetInfo, sub_work_dir: Path, seed: int, + num_devices: int = 1, ) -> float: """Train model with given dataset and return the total time. @@ -145,7 +144,7 @@ def train( dataset_info=dataset_info, work_dir=sub_work_dir / SubCommand.TRAIN.value, ) - + engine.num_devices = num_devices kwargs = {} if dataset_info.extra_overrides: kwargs.update(dataset_info.extra_overrides.get("train", {})) @@ -338,6 +337,7 @@ def run( dataset_info: DatasetInfo, seed: int, criteria: list[Criterion], + num_devices: int = 1, ) -> pd.DataFrame | None: """Run configured benchmark with given dataset and model and return the result. 
@@ -381,6 +381,7 @@ def run( dataset_info=dataset_info, sub_work_dir=sub_work_dir, seed=seed, + num_devices=num_devices, ) self._log_metrics( @@ -638,7 +639,6 @@ def check(self, result: pd.DataFrame, criteria: list[Criterion]): num_epoch=args.num_epoch, eval_upto=args.eval_upto, tags=tags, - dry_run=args.dry_run, deterministic=( False if args.deterministic is None else {"true": True, "false": False, "warn": "warn"}[args.deterministic] ), @@ -653,6 +653,7 @@ def check(self, result: pd.DataFrame, criteria: list[Criterion]): dataset_info=dataset_info, seed=args.seed, criteria=criteria, + num_devices=args.num_devices, ) benchmark.check( result=result, diff --git a/library/tests/perf_v2/run.py b/library/tests/perf_v2/run.py index 85ce3c47c04..79db7a77efe 100644 --- a/library/tests/perf_v2/run.py +++ b/library/tests/perf_v2/run.py @@ -99,7 +99,6 @@ def load_failed_jobs(file_path: Path) -> list[list[str]]: if (output_root / model.name / dataset.name / str(seed)).exists(): logger.info(f"Skipping existing job for {model.name} on {dataset.name} with seed {seed}") continue - cmd = [ "python", "-m", @@ -116,8 +115,12 @@ def load_failed_jobs(file_path: Path) -> list[list[str]]: str(output_root), "--seed", str(seed), + "--eval-upto", + str(args.eval_upto), "--num-epoch", str(args.num_epoch), + "--deterministic", + str(args.deterministic), "--device", args.device, "--user-name", diff --git a/library/tests/perf_v2/summary.py b/library/tests/perf_v2/summary.py index 61b1e11926e..8422f740b97 100644 --- a/library/tests/perf_v2/summary.py +++ b/library/tests/perf_v2/summary.py @@ -170,7 +170,7 @@ def summarize_table(history: pd.DataFrame, task: OTXTaskType) -> list[pd.DataFra score_metric = TASK_METRIC_MAP[task] # Metrics to summarize in aggregated table - metrics = [ + expected_metrics = [ "training:e2e_time", "training:epoch", "training:train/iter_time", @@ -185,7 +185,15 @@ def summarize_table(history: pd.DataFrame, task: OTXTaskType) -> list[pd.DataFra ] raw_task_data = history.query(f"task == '{task.value}'") - dataset_dfs = aggregate(raw_task_data, metrics) + valid_metrics = [] + for metric in expected_metrics: + if metric not in raw_task_data.columns: + msg = f"Metric {metric} not found in raw data" + logger.warning(msg) + else: + valid_metrics.append(metric) + + dataset_dfs = aggregate(raw_task_data, valid_metrics) # Round all numeric columns to 4 decimal places for df in dataset_dfs: diff --git a/library/tests/perf_v2/tasks/detection.py b/library/tests/perf_v2/tasks/detection.py index 6c2b5a90a1c..939fe58f595 100644 --- a/library/tests/perf_v2/tasks/detection.py +++ b/library/tests/perf_v2/tasks/detection.py @@ -21,6 +21,13 @@ ModelInfo(task=TASK_TYPE.value, name="atss_mobilenetv2", category="default"), ModelInfo(task=TASK_TYPE.value, name="yolox_s", category="speed"), ModelInfo(task=TASK_TYPE.value, name="dfine_x", category="accuracy"), + ModelInfo(task=TASK_TYPE.value, name="deim_dfine_x", category="other"), + ModelInfo(task=TASK_TYPE.value, name="deimv2_x", category="other"), + ModelInfo(task=TASK_TYPE.value, name="deim_dfine_l", category="other"), + ModelInfo(task=TASK_TYPE.value, name="deimv2_l", category="other"), + ModelInfo(task=TASK_TYPE.value, name="deim_dfine_m", category="other"), + ModelInfo(task=TASK_TYPE.value, name="deimv2_m", category="other"), + ModelInfo(task=TASK_TYPE.value, name="deimv2_s", category="other"), ModelInfo(task=TASK_TYPE.value, name="atss_resnext101", category="other"), ModelInfo(task=TASK_TYPE.value, name="rtdetr_101", category="other"), 
ModelInfo(task=TASK_TYPE.value, name="rtdetr_18", category="other"), @@ -55,11 +62,6 @@ path=Path("detection/wgisd_merged_coco_small"), group="small", ), - DatasetInfo( - name="skindetect", - path=Path("detection/skindetect-roboflow"), - group="small", - ), DatasetInfo( name="diopsis", path=Path("detection/diopsis_coco"), @@ -70,11 +72,6 @@ path=Path("detection/bdd_medium"), group="medium", ), - DatasetInfo( - name="Vitens-Aeromonas", - path=Path("detection/Vitens-Aeromonas-coco"), - group="medium", - ), DatasetInfo( name="visdrone", path=Path("detection/visdrone_coco_custom_split"), diff --git a/library/tests/perf_v2/utils.py b/library/tests/perf_v2/utils.py index 2307f922b23..1093dc0bcad 100644 --- a/library/tests/perf_v2/utils.py +++ b/library/tests/perf_v2/utils.py @@ -352,4 +352,10 @@ def get_parser() -> ArgumentParser: default="gpu", help="Which device to use.", ) + parser.add_argument( + "--num-devices", + type=int, + default=1, + help="How much devices to use during training.", + ) return parser diff --git a/library/tests/unit/backend/native/models/classification/utils/test_embed.py b/library/tests/unit/backend/native/models/classification/utils/test_embed.py deleted file mode 100644 index ef89d0f7466..00000000000 --- a/library/tests/unit/backend/native/models/classification/utils/test_embed.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from otx.backend.native.models.classification.utils.embed import resize_pos_embed - - -def test_resize_pos_embed(): - pos_embed = torch.randn(1, 32 * 32 + 1, 256) - src_shape = (32, 32) - dst_shape = (64, 64) - mode = "bicubic" - num_extra_tokens = 1 - - resized_pos_embed = resize_pos_embed(pos_embed, src_shape, dst_shape, mode, num_extra_tokens) - - assert resized_pos_embed.shape == (1, 4097, 256) - assert resized_pos_embed.dtype == pos_embed.dtype - assert resized_pos_embed[:, :num_extra_tokens].equal(pos_embed[:, :num_extra_tokens]) diff --git a/library/tests/unit/backend/native/models/common/backbones/test_dinov3.py b/library/tests/unit/backend/native/models/common/backbones/test_dinov3.py new file mode 100644 index 00000000000..4b7bad287fe --- /dev/null +++ b/library/tests/unit/backend/native/models/common/backbones/test_dinov3.py @@ -0,0 +1,459 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for DINOv3 Vision Transformer backbone.""" + +import pytest +import torch + +from otx.backend.native.models.common.backbones.dinov3 import ( + DinoVisionTransformer, + Weights, + configs, + dtype_dict, + ffn_layer_dict, + init_weights_vit, + named_apply, + norm_layer_dict, +) + + +class TestDinoVisionTransformer: + """Test class for DinoVisionTransformer.""" + + @pytest.fixture + def vit_small(self): + """Create a ViT-S/16 model.""" + return DinoVisionTransformer(name="dinov3_vits16") + + @pytest.fixture + def vit_small_plus(self): + """Create a ViT-S/16+ model with SwiGLU.""" + return DinoVisionTransformer(name="dinov3_vits16plus") + + def test_init_vits16(self, vit_small): + """Test ViT-S/16 initialization.""" + assert isinstance(vit_small, DinoVisionTransformer) + assert vit_small.embed_dim == 384 + assert vit_small.num_features == 384 + assert vit_small.n_blocks == 12 + assert vit_small.num_heads == 6 + assert vit_small.patch_size == 16 + assert vit_small.n_storage_tokens == 4 + + def test_init_vits16plus(self, vit_small_plus): + """Test ViT-S/16+ initialization with SwiGLU.""" + assert isinstance(vit_small_plus, 
DinoVisionTransformer) + assert vit_small_plus.embed_dim == 384 + assert vit_small_plus.n_blocks == 12 + # Check SwiGLU FFN is used + from otx.backend.native.models.classification.utils.swiglu_ffn import SwiGLUFFNV2 + + assert isinstance(vit_small_plus.blocks[0].mlp, SwiGLUFFNV2) + + def test_model_components(self, vit_small): + """Test that all model components are properly initialized.""" + # Check patch embedding + assert hasattr(vit_small, "patch_embed") + assert vit_small.patch_embed.patch_size == (16, 16) + + # Check cls token + assert hasattr(vit_small, "cls_token") + assert vit_small.cls_token.shape == (1, 1, 384) + + # Check storage tokens + assert hasattr(vit_small, "storage_tokens") + assert vit_small.storage_tokens.shape == (1, 4, 384) + + # Check RoPE embedding + assert hasattr(vit_small, "rope_embed") + + # Check transformer blocks + assert hasattr(vit_small, "blocks") + assert len(vit_small.blocks) == 12 + + # Check normalization + assert hasattr(vit_small, "norm") + + # Check mask token + assert hasattr(vit_small, "mask_token") + assert vit_small.mask_token.shape == (1, 384) + + def test_forward_single_image(self, vit_small): + """Test forward pass with a single image.""" + vit_small.eval() + x = torch.randn(1, 3, 224, 224) + + with torch.no_grad(): + output = vit_small(x) + + # In eval mode, returns cls token features + assert output.shape == (1, 384) + + def test_forward_batch(self, vit_small): + """Test forward pass with a batch of images.""" + vit_small.eval() + batch_size = 4 + x = torch.randn(batch_size, 3, 224, 224) + + with torch.no_grad(): + output = vit_small(x) + + assert output.shape == (batch_size, 384) + + def test_forward_features(self, vit_small): + """Test forward_features method.""" + vit_small.eval() + x = torch.randn(2, 3, 224, 224) + + with torch.no_grad(): + output = vit_small.forward_features(x) + + assert isinstance(output, dict) + assert "x_norm_clstoken" in output + assert "x_storage_tokens" in output + assert "x_norm_patchtokens" in output + assert "x_prenorm" in output + assert "masks" in output + + # Check shapes + assert output["x_norm_clstoken"].shape == (2, 384) + assert output["x_storage_tokens"].shape == (2, 4, 384) + # 224/16 = 14, so 14*14 = 196 patch tokens + assert output["x_norm_patchtokens"].shape == (2, 196, 384) + + def test_forward_features_list(self, vit_small): + """Test forward_features with list of images.""" + vit_small.eval() + x1 = torch.randn(2, 3, 224, 224) + x2 = torch.randn(2, 3, 224, 224) + + with torch.no_grad(): + output = vit_small.forward_features([x1, x2]) + + assert isinstance(output, list) + assert len(output) == 2 + for out in output: + assert isinstance(out, dict) + assert "x_norm_clstoken" in out + + def test_forward_training_mode(self, vit_small): + """Test forward pass in training mode.""" + vit_small.train() + x = torch.randn(2, 3, 224, 224) + + output = vit_small(x, is_training=True) + + assert isinstance(output, dict) + assert "x_norm_clstoken" in output + + def test_prepare_tokens_with_masks(self, vit_small): + """Test token preparation with masks.""" + x = torch.randn(2, 3, 224, 224) + + tokens, (h, w) = vit_small.prepare_tokens_with_masks(x) + + # Should have cls_token + storage_tokens + patch_tokens + # 1 + 4 + 196 = 201 + assert tokens.shape == (2, 201, 384) + assert h == 14 + assert w == 14 + + def test_prepare_tokens_with_mask_tokens(self, vit_small): + """Test token preparation with mask tokens applied.""" + x = torch.randn(2, 3, 224, 224) + # Create a mask for some patch positions + masks = 
torch.zeros(2, 196, dtype=torch.bool) + masks[:, :50] = True # Mask first 50 patches + + tokens, (h, w) = vit_small.prepare_tokens_with_masks(x, masks) + + assert tokens.shape == (2, 201, 384) + assert h == 14 + assert w == 14 + + def test_get_intermediate_layers(self, vit_small): + """Test getting intermediate layer outputs.""" + vit_small.eval() + x = torch.randn(1, 3, 224, 224) + + with torch.no_grad(): + # Get last 2 layers + outputs = vit_small.get_intermediate_layers(x, n=2) + + assert len(outputs) == 2 + for out in outputs: + assert out.shape == (1, 196, 384) + + def test_get_intermediate_layers_specific_indices(self, vit_small): + """Test getting specific intermediate layers by index.""" + vit_small.eval() + x = torch.randn(1, 3, 224, 224) + + with torch.no_grad(): + # Get layers 5 and 10 + outputs = vit_small.get_intermediate_layers(x, n=[5, 10]) + + assert len(outputs) == 2 + + def test_get_intermediate_layers_reshape(self, vit_small): + """Test getting intermediate layers with spatial reshape.""" + vit_small.eval() + x = torch.randn(1, 3, 224, 224) + + with torch.no_grad(): + outputs = vit_small.get_intermediate_layers(x, n=1, reshape=True) + + assert len(outputs) == 1 + # Should be reshaped to (B, C, H, W) + assert outputs[0].shape == (1, 384, 14, 14) + + def test_get_intermediate_layers_with_cls_token(self, vit_small): + """Test getting intermediate layers with class token.""" + vit_small.eval() + x = torch.randn(1, 3, 224, 224) + + with torch.no_grad(): + outputs = vit_small.get_intermediate_layers(x, n=1, return_class_token=True) + + assert len(outputs) == 1 + assert len(outputs[0]) == 2 # (features, cls_token) + features, cls_token = outputs[0] + assert features.shape == (1, 196, 384) + assert cls_token.shape == (1, 384) + + def test_get_intermediate_layers_with_extra_tokens(self, vit_small): + """Test getting intermediate layers with extra/storage tokens.""" + vit_small.eval() + x = torch.randn(1, 3, 224, 224) + + with torch.no_grad(): + outputs = vit_small.get_intermediate_layers(x, n=1, return_extra_tokens=True) + + assert len(outputs) == 1 + assert len(outputs[0]) == 2 # (features, extra_tokens) + features, extra_tokens = outputs[0] + assert features.shape == (1, 196, 384) + assert extra_tokens.shape == (1, 4, 384) + + def test_get_intermediate_layers_with_both_tokens(self, vit_small): + """Test getting intermediate layers with both cls and extra tokens.""" + vit_small.eval() + x = torch.randn(1, 3, 224, 224) + + with torch.no_grad(): + outputs = vit_small.get_intermediate_layers(x, n=1, return_class_token=True, return_extra_tokens=True) + + assert len(outputs) == 1 + assert len(outputs[0]) == 3 # (features, cls_token, extra_tokens) + features, cls_token, extra_tokens = outputs[0] + assert features.shape == (1, 196, 384) + assert cls_token.shape == (1, 384) + assert extra_tokens.shape == (1, 4, 384) + + def test_get_intermediate_layers_no_norm(self, vit_small): + """Test getting intermediate layers without normalization.""" + vit_small.eval() + x = torch.randn(1, 3, 224, 224) + + with torch.no_grad(): + outputs_normed = vit_small.get_intermediate_layers(x, n=1, norm=True) + outputs_raw = vit_small.get_intermediate_layers(x, n=1, norm=False) + + # Outputs should be different + assert not torch.allclose(outputs_normed[0], outputs_raw[0]) + + def test_different_input_sizes(self, vit_small): + """Test forward pass with different input sizes.""" + vit_small.eval() + + # Test with 448x448 (larger than training size) + x_large = torch.randn(1, 3, 448, 448) + with 
torch.no_grad(): + output_large = vit_small.forward_features(x_large) + + # 448/16 = 28, so 28*28 = 784 patches + assert output_large["x_norm_patchtokens"].shape == (1, 784, 384) + + def test_init_weights(self, vit_small): + """Test weight initialization.""" + # Simply verify init_weights runs without error + vit_small.init_weights() + + # Check that cls_token has non-zero values (initialized) + assert not torch.all(vit_small.cls_token == 0) + + # Check that mask_token is zero-initialized + assert torch.all(vit_small.mask_token == 0) + + +class TestHelperFunctions: + """Test helper functions and configurations.""" + + def test_configs_exist(self): + """Test that configurations exist.""" + assert "dinov3_vits16" in configs + assert "dinov3_vits16plus" in configs + + def test_configs_structure(self): + """Test configuration structure.""" + config = configs["dinov3_vits16"] + + required_keys = [ + "img_size", + "patch_size", + "embed_dim", + "depth", + "num_heads", + "ffn_ratio", + "norm_layer", + "ffn_layer", + ] + for key in required_keys: + assert key in config + + def test_weights_enum(self): + """Test Weights enum.""" + assert Weights.LVD1689M.value == "LVD1689M" + assert Weights.SAT493M.value == "SAT493M" + + def test_ffn_layer_dict(self): + """Test FFN layer mapping.""" + assert "mlp" in ffn_layer_dict + assert "swiglu" in ffn_layer_dict + assert "swiglu32" in ffn_layer_dict + assert "swiglu64" in ffn_layer_dict + assert "swiglu128" in ffn_layer_dict + + def test_norm_layer_dict(self): + """Test norm layer mapping.""" + assert "layernorm" in norm_layer_dict + assert "layernormbf16" in norm_layer_dict + + def test_dtype_dict(self): + """Test dtype mapping.""" + assert dtype_dict["fp32"] == torch.float32 + assert dtype_dict["fp16"] == torch.float16 + assert dtype_dict["bf16"] == torch.bfloat16 + + def test_named_apply(self): + """Test named_apply function.""" + model = torch.nn.Sequential( + torch.nn.Linear(10, 20), + torch.nn.ReLU(), + torch.nn.Linear(20, 5), + ) + + applied_names = [] + + def track_fn(module, name) -> None: + applied_names.append(name) + + named_apply(track_fn, model, include_root=True) + + # Should have applied to root and all children + assert "" in applied_names # Root + assert "0" in applied_names + assert "1" in applied_names + assert "2" in applied_names + + def test_named_apply_depth_first(self): + """Test named_apply with depth-first ordering.""" + model = torch.nn.Sequential( + torch.nn.Linear(10, 20), + torch.nn.Linear(20, 5), + ) + + order = [] + + def track_order(module, name) -> None: + order.append(name) + + # Depth-first: children before parent + named_apply(track_order, model, depth_first=True, include_root=True) + + # Children should come before root + root_idx = order.index("") + child_idx = order.index("0") + assert child_idx < root_idx + + def test_named_apply_breadth_first(self): + """Test named_apply with breadth-first ordering.""" + model = torch.nn.Sequential( + torch.nn.Linear(10, 20), + torch.nn.Linear(20, 5), + ) + + order = [] + + def track_order(module, name) -> None: + order.append(name) + + # Breadth-first: parent before children + named_apply(track_order, model, depth_first=False, include_root=True) + + # Root should come before children + root_idx = order.index("") + child_idx = order.index("0") + assert root_idx < child_idx + + def test_init_weights_vit_linear(self): + """Test init_weights_vit for Linear layers.""" + linear = torch.nn.Linear(10, 20) + original_weight = linear.weight.clone() + + init_weights_vit(linear) + + # Weight 
should be reinitialized + assert not torch.allclose(linear.weight, original_weight) + # Bias should be zeros + assert torch.all(linear.bias == 0) + + def test_init_weights_vit_layernorm(self): + """Test init_weights_vit for LayerNorm.""" + ln = torch.nn.LayerNorm(256) + + # Should not raise error + init_weights_vit(ln) + + # LayerNorm should have default initialization + assert torch.allclose(ln.weight, torch.ones(256)) + assert torch.allclose(ln.bias, torch.zeros(256)) + + +class TestGradients: + """Test gradient flow through the model.""" + + @pytest.fixture + def vit_small(self): + """Create a ViT-S/16 model for gradient testing.""" + return DinoVisionTransformer(name="dinov3_vits16") + + def test_gradient_flow(self, vit_small): + """Test that gradients flow through the model.""" + vit_small.train() + x = torch.randn(2, 3, 224, 224, requires_grad=True) + + output = vit_small(x, is_training=True) + loss = output["x_norm_clstoken"].sum() + loss.backward() + + # Check gradients exist + assert x.grad is not None + assert not torch.all(x.grad == 0) + + def test_gradient_flow_to_parameters(self, vit_small): + """Test that gradients flow to all trainable parameters.""" + vit_small.train() + x = torch.randn(2, 3, 224, 224) + + output = vit_small(x, is_training=True) + loss = output["x_norm_clstoken"].sum() + loss.backward() + + # Check key parameters have gradients + assert vit_small.cls_token.grad is not None + assert vit_small.storage_tokens.grad is not None + + # Check some block parameters + assert vit_small.blocks[0].attn.qkv.weight.grad is not None diff --git a/library/tests/unit/backend/native/models/detection/backbones/test_vit_tiny.py b/library/tests/unit/backend/native/models/detection/backbones/test_vit_tiny.py new file mode 100644 index 00000000000..a35a0b71723 --- /dev/null +++ b/library/tests/unit/backend/native/models/detection/backbones/test_vit_tiny.py @@ -0,0 +1,485 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Vision Transformer (ViT) Tiny backbone.""" + +import pytest +import torch + +from otx.backend.native.models.detection.backbones.vit_tiny import ( + Attention, + Block, + DropPath, + SimplifiedPatchEmbed, + VisionTransformer, + apply_rope, + drop_path, + rotate_half, +) + + +class TestHelperFunctions: + """Test helper functions for ViT.""" + + def test_rotate_half(self): + """Test rotate_half function.""" + x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.float32) + result = rotate_half(x) + + # First half becomes negated second half + expected = torch.tensor([[-3, -4, 1, 2], [-7, -8, 5, 6]], dtype=torch.float32) + assert torch.allclose(result, expected) + + def test_rotate_half_batched(self): + """Test rotate_half with batched input.""" + x = torch.randn(2, 4, 8, 64) # (B, H, N, D) + result = rotate_half(x) + + assert result.shape == x.shape + # Check the rotation pattern + assert torch.allclose(result[..., :32], -x[..., 32:]) + assert torch.allclose(result[..., 32:], x[..., :32]) + + def test_apply_rope(self): + """Test apply_rope function.""" + x = torch.randn(2, 4, 196, 64) + sin = torch.randn(1, 1, 196, 64) + cos = torch.randn(1, 1, 196, 64) + + result = apply_rope(x, sin, cos) + + assert result.shape == x.shape + assert not torch.isnan(result).any() + + def test_drop_path_no_drop(self): + """Test drop_path with 0 probability.""" + x = torch.randn(2, 100, 256) + result = drop_path(x, drop_prob=0.0, training=True) + assert torch.allclose(result, x) + + def test_drop_path_eval_mode(self): + """Test 
drop_path in eval mode (training=False).""" + x = torch.randn(2, 100, 256) + result = drop_path(x, drop_prob=0.5, training=False) + assert torch.allclose(result, x) + + def test_drop_path_training(self): + """Test drop_path during training.""" + torch.manual_seed(42) + x = torch.ones(10, 100, 256) + result = drop_path(x, drop_prob=0.5, training=True) + + # Some samples may be zeroed out, others scaled + # Check that the shape is preserved + assert result.shape == x.shape + + +class TestDropPath: + """Test DropPath module.""" + + def test_droppath_init(self): + """Test DropPath initialization.""" + dp = DropPath(drop_prob=0.1) + assert dp.drop_prob == 0.1 + + def test_droppath_init_none(self): + """Test DropPath with None probability.""" + dp = DropPath(drop_prob=None) + assert dp.drop_prob is None + + def test_droppath_forward_training(self): + """Test DropPath forward in training mode.""" + dp = DropPath(drop_prob=0.5) + dp.train() + x = torch.randn(4, 100, 256) + result = dp(x) + assert result.shape == x.shape + + def test_droppath_forward_eval(self): + """Test DropPath forward in eval mode.""" + dp = DropPath(drop_prob=0.5) + dp.eval() + x = torch.randn(4, 100, 256) + result = dp(x) + assert torch.allclose(result, x) + + +class TestSimplifiedPatchEmbed: + """Test SimplifiedPatchEmbed module.""" + + @pytest.fixture + def patch_embed(self): + """Create patch embed layer.""" + return SimplifiedPatchEmbed( + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=192, + ) + + def test_init(self, patch_embed): + """Test patch embed initialization.""" + assert patch_embed.grid_size == (14, 14) + assert patch_embed.num_patches == 196 + assert patch_embed.proj.kernel_size == (16, 16) + assert patch_embed.proj.stride == (16, 16) + + def test_init_tuple_sizes(self): + """Test patch embed with tuple sizes.""" + pe = SimplifiedPatchEmbed( + img_size=(224, 224), + patch_size=(16, 16), + in_chans=3, + embed_dim=192, + ) + assert pe.grid_size == (14, 14) + + def test_forward(self, patch_embed): + """Test patch embed forward pass.""" + x = torch.randn(2, 3, 224, 224) + output = patch_embed(x) + + # Output shape: (B, num_patches, embed_dim) + assert output.shape == (2, 196, 192) + + def test_forward_different_sizes(self): + """Test patch embed with different image sizes.""" + pe = SimplifiedPatchEmbed( + img_size=448, + patch_size=16, + in_chans=3, + embed_dim=256, + ) + x = torch.randn(1, 3, 448, 448) + output = pe(x) + + # 448/16 = 28, so 28*28 = 784 patches + assert output.shape == (1, 784, 256) + + +class TestAttention: + """Test Attention module.""" + + @pytest.fixture + def attention(self): + """Create attention module.""" + return Attention( + dim=192, + num_heads=3, + qkv_bias=True, + attn_drop=0.0, + proj_drop=0.0, + ) + + def test_init(self, attention): + """Test attention initialization.""" + assert attention.num_heads == 3 + assert attention.scale == (192 // 3) ** -0.5 + assert attention.qkv.in_features == 192 + assert attention.qkv.out_features == 192 * 3 + assert attention.proj.in_features == 192 + assert attention.proj.out_features == 192 + + def test_forward_without_rope(self, attention): + """Test attention forward without RoPE.""" + x = torch.randn(2, 197, 192) # (B, N, C) with cls token + output = attention(x) + + assert output.shape == x.shape + + def test_forward_with_rope(self, attention): + """Test attention forward with RoPE.""" + x = torch.randn(2, 197, 192) # (B, N, C) with cls token + + # Create RoPE sin/cos for patches (excluding cls token) + sin = torch.randn(1, 1, 196, 64) 
# 192 / 3 = 64 head dim + cos = torch.randn(1, 1, 196, 64) + + output = attention(x, rope_sincos=(sin, cos)) + + assert output.shape == x.shape + + def test_forward_batched(self, attention): + """Test attention with different batch sizes.""" + for batch_size in [1, 2, 4]: + x = torch.randn(batch_size, 197, 192) + output = attention(x) + assert output.shape == x.shape + + +class TestBlock: + """Test Transformer Block module.""" + + @pytest.fixture + def block(self): + """Create transformer block.""" + return Block( + dim=192, + num_heads=3, + mlp_ratio=4.0, + qkv_bias=True, + attn_drop=0.0, + drop_path=0.0, + drop=0.0, + ) + + @pytest.fixture + def block_with_droppath(self): + """Create transformer block with drop path.""" + return Block( + dim=192, + num_heads=3, + mlp_ratio=4.0, + qkv_bias=True, + drop_path=0.1, + ) + + def test_init(self, block): + """Test block initialization.""" + assert hasattr(block, "norm1") + assert hasattr(block, "attn") + assert hasattr(block, "norm2") + assert hasattr(block, "mlp") + assert hasattr(block, "drop_path") + + def test_init_with_droppath(self, block_with_droppath): + """Test block with drop path initialization.""" + assert isinstance(block_with_droppath.drop_path, DropPath) + assert block_with_droppath.drop_path.drop_prob == 0.1 + + def test_forward_without_rope(self, block): + """Test block forward without RoPE.""" + x = torch.randn(2, 197, 192) + output = block(x) + + assert output.shape == x.shape + assert not torch.isnan(output).any() + + def test_forward_with_rope(self, block): + """Test block forward with RoPE.""" + x = torch.randn(2, 197, 192) + sin = torch.randn(1, 1, 196, 64) + cos = torch.randn(1, 1, 196, 64) + + output = block(x, rope_sincos=(sin, cos)) + + assert output.shape == x.shape + + +class TestVisionTransformer: + """Test VisionTransformer module.""" + + @pytest.fixture + def vit(self): + """Create ViT model.""" + return VisionTransformer( + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=192, + depth=12, + num_heads=3, + mlp_ratio=4.0, + qkv_bias=True, + return_layers=[3, 7, 11], + ) + + @pytest.fixture + def vit_small(self): + """Create smaller ViT for faster testing.""" + return VisionTransformer( + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=128, + depth=4, + num_heads=4, + mlp_ratio=4.0, + return_layers=[1, 2, 3], + ) + + def test_init(self, vit): + """Test ViT initialization.""" + assert vit.num_features == 192 + assert vit.embed_dim == 192 + assert vit.num_tokens == 1 + assert vit.return_layers == [3, 7, 11] + assert vit.patch_size == 16 + + def test_init_components(self, vit): + """Test ViT components initialization.""" + model = vit.get_model() + + assert hasattr(model, "patch_embed") + assert hasattr(model, "cls_token") + assert hasattr(model, "blocks") + assert hasattr(model, "rope_embed") + + assert len(model.blocks) == 12 + assert model.cls_token.shape == (1, 1, 192) + + def test_forward(self, vit_small): + """Test ViT forward pass.""" + x = torch.randn(2, 3, 224, 224) + outputs = vit_small(x) + + assert len(outputs) == 3 # 3 return layers + for patch_features, cls_token in outputs: + assert patch_features.shape == (2, 196, 128) + assert cls_token.shape == (2, 128) + + def test_forward_different_batch_sizes(self, vit_small): + """Test ViT with different batch sizes.""" + for batch_size in [1, 2, 4]: + x = torch.randn(batch_size, 3, 224, 224) + outputs = vit_small(x) + + assert len(outputs) == 3 + for patch_features, cls_token in outputs: + assert patch_features.shape[0] == batch_size + assert 
cls_token.shape[0] == batch_size + + def test_forward_different_image_sizes(self, vit_small): + """Test ViT with different image sizes.""" + # ViT can handle different input sizes + for img_size in [224, 448, 640]: + x = torch.randn(1, 3, img_size, img_size) + outputs = vit_small(x) + + num_patches = (img_size // 16) ** 2 + for patch_features, cls_token in outputs: + assert patch_features.shape == (1, num_patches, 128) + assert cls_token.shape == (1, 128) + + def test_feature_dim(self, vit_small): + """Test feature_dim method.""" + assert vit_small.feature_dim() == 128 + + def test_get_model(self, vit_small): + """Test get_model method.""" + model = vit_small.get_model() + assert isinstance(model, torch.nn.Module) + assert hasattr(model, "blocks") + + def test_no_weight_decay(self, vit_small): + """Test no_weight_decay method.""" + no_wd = vit_small.no_weight_decay() + assert "cls_token" in no_wd + + def test_init_weights(self, vit_small): + """Test weight initialization.""" + # Simply verify init_weights runs without error + vit_small.init_weights() + + model = vit_small.get_model() + # Check cls_token is not all zeros after init + assert not torch.all(model.cls_token == 0) + + def test_custom_return_layers(self): + """Test ViT with custom return layers.""" + vit = VisionTransformer( + img_size=224, + patch_size=16, + embed_dim=128, + depth=6, + num_heads=4, + return_layers=[0, 2, 5], + ) + + x = torch.randn(1, 3, 224, 224) + outputs = vit(x) + + assert len(outputs) == 3 + + def test_default_return_layers(self): + """Test ViT with default return layers.""" + vit = VisionTransformer( + img_size=224, + patch_size=16, + embed_dim=128, + depth=12, + num_heads=4, + # return_layers defaults to [3, 7, 11] + ) + + assert vit.return_layers == [3, 7, 11] + + def test_with_drop_path(self): + """Test ViT with drop path rate.""" + vit = VisionTransformer( + img_size=224, + patch_size=16, + embed_dim=128, + depth=4, + num_heads=4, + drop_path_rate=0.1, + return_layers=[1, 2, 3], + ) + + x = torch.randn(1, 3, 224, 224) + outputs = vit(x) + + assert len(outputs) == 3 + + def test_with_dropout(self): + """Test ViT with dropout.""" + vit = VisionTransformer( + img_size=224, + patch_size=16, + embed_dim=128, + depth=4, + num_heads=4, + drop_rate=0.1, + attn_drop_rate=0.1, + return_layers=[1, 2, 3], + ) + + x = torch.randn(1, 3, 224, 224) + outputs = vit(x) + + assert len(outputs) == 3 + + +class TestGradientFlow: + """Test gradient flow through ViT.""" + + @pytest.fixture + def vit(self): + """Create ViT for gradient testing.""" + return VisionTransformer( + img_size=224, + patch_size=16, + embed_dim=128, + depth=4, + num_heads=4, + return_layers=[1, 2, 3], + ) + + def test_gradient_flow(self, vit): + """Test that gradients flow through the model.""" + vit.train() + x = torch.randn(2, 3, 224, 224, requires_grad=True) + + outputs = vit(x) + + # Sum all outputs for loss + loss = sum(patch_feat.sum() + cls_token.sum() for patch_feat, cls_token in outputs) + loss.backward() + + assert x.grad is not None + assert not torch.all(x.grad == 0) + + def test_gradient_flow_to_parameters(self, vit): + """Test that gradients flow to parameters.""" + vit.train() + x = torch.randn(2, 3, 224, 224) + + outputs = vit(x) + loss = sum(patch_feat.sum() + cls_token.sum() for patch_feat, cls_token in outputs) + loss.backward() + + model = vit.get_model() + + # Check key parameters have gradients + assert model.cls_token.grad is not None + assert model.blocks[0].attn.qkv.weight.grad is not None diff --git 
a/library/tests/unit/backend/native/models/detection/heads/test_deim_decoder.py b/library/tests/unit/backend/native/models/detection/heads/test_deim_decoder.py new file mode 100644 index 00000000000..fd74eb7dfe0 --- /dev/null +++ b/library/tests/unit/backend/native/models/detection/heads/test_deim_decoder.py @@ -0,0 +1,533 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for DEIM Transformer Decoder.""" + +import pytest +import torch + +from otx.backend.native.models.detection.heads.deim_decoder import ( + DEIMTransformer, + DEIMTransformerModule, + TransformerDecoder, + TransformerDecoderLayer, +) + + +class TestTransformerDecoderLayer: + """Test class for TransformerDecoderLayer.""" + + @pytest.fixture + def decoder_layer(self): + """Create a basic decoder layer.""" + return TransformerDecoderLayer( + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0.0, + n_levels=3, + n_points=4, + ) + + @pytest.fixture + def decoder_layer_with_gateway(self): + """Create a decoder layer with gateway enabled.""" + return TransformerDecoderLayer( + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0.0, + n_levels=3, + n_points=4, + use_gateway=True, + ) + + @pytest.fixture + def decoder_layer_with_scale(self): + """Create a decoder layer with layer scale.""" + return TransformerDecoderLayer( + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0.0, + n_levels=3, + n_points=[3, 6, 3], + layer_scale=2.0, + ) + + def test_decoder_layer_init(self, decoder_layer): + """Test decoder layer initialization.""" + assert isinstance(decoder_layer, TransformerDecoderLayer) + assert decoder_layer.use_gateway is False + # Check for memory-efficient self-attention components + assert hasattr(decoder_layer, "qkv_proj") + assert hasattr(decoder_layer, "out_proj") + assert hasattr(decoder_layer, "cross_attn") + assert hasattr(decoder_layer, "swish_ffn") + + def test_decoder_layer_with_gateway_init(self, decoder_layer_with_gateway): + """Test decoder layer with gateway initialization.""" + assert decoder_layer_with_gateway.use_gateway is True + assert hasattr(decoder_layer_with_gateway, "gateway") + + def test_decoder_layer_forward(self, decoder_layer): + """Test decoder layer forward pass.""" + batch_size = 2 + num_queries = 300 + hidden_dim = 256 + spatial_shapes = [[40, 40], [20, 20], [10, 10]] + + target = torch.randn(batch_size, num_queries, hidden_dim) + reference_points = torch.rand(batch_size, num_queries, 1, 4) + + # Create value tuple for each level + values = tuple(torch.randn(batch_size, 8, hidden_dim // 8, h * w) for h, w in spatial_shapes) + + output = decoder_layer( + target=target, + reference_points=reference_points, + value=values, + spatial_shapes=spatial_shapes, + ) + + assert output.shape == target.shape + assert not torch.isnan(output).any() + + def test_with_pos_embed(self, decoder_layer): + """Test position embedding addition.""" + tensor = torch.randn(2, 100, 256) + pos = torch.randn(2, 100, 256) + + # With position embedding + result = decoder_layer.with_pos_embed(tensor, pos) + assert torch.allclose(result, tensor + pos) + + # Without position embedding + result_no_pos = decoder_layer.with_pos_embed(tensor, None) + assert torch.allclose(result_no_pos, tensor) + + +class TestDEIMTransformerModule: + """Test class for DEIMTransformerModule.""" + + @pytest.fixture + def deim_transformer(self): + """Create a basic DEIM transformer module.""" + return DEIMTransformerModule( + num_classes=10, + hidden_dim=256, + 
num_queries=100, + feat_channels=[256, 256, 256], + feat_strides=[8, 16, 32], + num_levels=3, + num_points=[3, 6, 3], + nhead=8, + num_layers=2, + dim_feedforward=512, + dropout=0.0, + num_denoising=50, + eval_spatial_size=(640, 640), + reg_max=32, + ) + + @pytest.fixture + def deim_transformer_minimal(self): + """Create a minimal DEIM transformer for faster testing.""" + return DEIMTransformerModule( + num_classes=5, + hidden_dim=128, + num_queries=50, + feat_channels=[128, 128, 128], + feat_strides=[8, 16, 32], + num_levels=3, + num_points=[2, 4, 2], + nhead=4, + num_layers=1, + dim_feedforward=256, + dropout=0.0, + num_denoising=0, + eval_spatial_size=(320, 320), + reg_max=16, + ) + + @pytest.fixture + def targets(self): + """Create sample targets for training.""" + return [ + { + "boxes": torch.tensor([[0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9]]), + "labels": torch.tensor([1, 0]), + }, + { + "boxes": torch.tensor([[0.1, 0.2, 0.3, 0.4]]), + "labels": torch.tensor([2]), + }, + ] + + def test_deim_transformer_init(self, deim_transformer): + """Test DEIM transformer initialization.""" + assert isinstance(deim_transformer, DEIMTransformerModule) + assert deim_transformer.num_classes == 10 + assert deim_transformer.hidden_dim == 256 + assert deim_transformer.num_queries == 100 + assert deim_transformer.num_levels == 3 + assert deim_transformer.reg_max == 32 + assert deim_transformer.aux_loss is True + + def test_deim_transformer_components(self, deim_transformer): + """Test that all components are properly initialized.""" + # Check input projection + assert hasattr(deim_transformer, "input_proj") + assert len(deim_transformer.input_proj) == 3 + + # Check decoder + assert hasattr(deim_transformer, "decoder") + assert isinstance(deim_transformer.decoder, TransformerDecoder) + + # Check heads + assert hasattr(deim_transformer, "enc_score_head") + assert hasattr(deim_transformer, "enc_bbox_head") + assert hasattr(deim_transformer, "dec_score_head") + assert hasattr(deim_transformer, "dec_bbox_head") + assert hasattr(deim_transformer, "pre_bbox_head") + assert hasattr(deim_transformer, "query_pos_head") + + # Check denoising embedding + assert hasattr(deim_transformer, "denoising_class_embed") + + # Check integral for FDR + assert hasattr(deim_transformer, "integral") + + def test_deim_transformer_forward_training(self, deim_transformer_minimal, targets): + """Test DEIM transformer forward pass in training mode.""" + deim_transformer_minimal.train() + + feats = [ + torch.randn(2, 128, 40, 40), + torch.randn(2, 128, 20, 20), + torch.randn(2, 128, 10, 10), + ] + + output = deim_transformer_minimal(feats, targets) + + assert isinstance(output, dict) + assert "pred_logits" in output + assert "pred_boxes" in output + assert "pred_corners" in output + assert "ref_points" in output + + # Check output shapes + num_queries = deim_transformer_minimal.num_queries + num_classes = deim_transformer_minimal.num_classes + assert output["pred_logits"].shape == (2, num_queries, num_classes) + assert output["pred_boxes"].shape == (2, num_queries, 4) + + def test_deim_transformer_forward_eval(self, deim_transformer_minimal): + """Test DEIM transformer forward pass in eval mode.""" + deim_transformer_minimal.eval() + + feats = [ + torch.randn(1, 128, 40, 40), + torch.randn(1, 128, 20, 20), + torch.randn(1, 128, 10, 10), + ] + + output = deim_transformer_minimal(feats) + + assert isinstance(output, dict) + assert "pred_logits" in output + assert "pred_boxes" in output + + # In eval mode, should not have 
training-specific outputs + assert "pred_corners" not in output + assert "aux_outputs" not in output + + def test_deim_transformer_forward_explain_mode(self, deim_transformer_minimal): + """Test DEIM transformer forward pass with explain mode.""" + deim_transformer_minimal.eval() + + feats = [ + torch.randn(1, 128, 40, 40), + torch.randn(1, 128, 20, 20), + torch.randn(1, 128, 10, 10), + ] + + output = deim_transformer_minimal(feats, explain_mode=True) + + assert isinstance(output, dict) + assert "raw_logits" in output + + def test_deim_transformer_aux_loss(self, deim_transformer_minimal, targets): + """Test auxiliary loss outputs.""" + deim_transformer_minimal.train() + + feats = [ + torch.randn(2, 128, 40, 40), + torch.randn(2, 128, 20, 20), + torch.randn(2, 128, 10, 10), + ] + + output = deim_transformer_minimal(feats, targets) + + # Check auxiliary outputs exist + if deim_transformer_minimal.aux_loss: + assert "aux_outputs" in output or deim_transformer_minimal.num_layers == 1 + assert "enc_aux_outputs" in output + assert "pre_outputs" in output + + def test_generate_anchors(self, deim_transformer_minimal): + """Test anchor generation.""" + spatial_shapes = [[40, 40], [20, 20], [10, 10]] + + anchors, valid_mask = deim_transformer_minimal._generate_anchors( + spatial_shapes=spatial_shapes, + device="cpu", + ) + + # Check anchor shape: should be [1, total_anchors, 4] + total_anchors = sum(h * w for h, w in spatial_shapes) + assert anchors.shape == (1, total_anchors, 4) + assert valid_mask.shape == (1, total_anchors, 1) + + def test_get_encoder_input(self, deim_transformer_minimal): + """Test encoder input preparation.""" + feats = [ + torch.randn(2, 128, 40, 40), + torch.randn(2, 128, 20, 20), + torch.randn(2, 128, 10, 10), + ] + + feat_flatten, spatial_shapes = deim_transformer_minimal._get_encoder_input(feats) + + # Check flattened features + total_tokens = 40 * 40 + 20 * 20 + 10 * 10 + assert feat_flatten.shape == (2, total_tokens, 128) + + # Check spatial shapes + assert spatial_shapes == [[40, 40], [20, 20], [10, 10]] + + def test_select_topk(self, deim_transformer_minimal): + """Test top-k query selection.""" + batch_size = 2 + num_tokens = 1000 + hidden_dim = 128 + num_classes = 5 + topk = 50 + + memory = torch.randn(batch_size, num_tokens, hidden_dim) + outputs_logits = torch.randn(batch_size, num_tokens, num_classes) + outputs_anchors = torch.randn(batch_size, num_tokens, 4) + + topk_memory, topk_logits, topk_anchors = deim_transformer_minimal._select_topk( + memory, outputs_logits, outputs_anchors, topk + ) + + assert topk_memory.shape == (batch_size, topk, hidden_dim) + assert topk_anchors.shape == (batch_size, topk, 4) + # topk_logits is None in eval mode + if deim_transformer_minimal.training: + assert topk_logits.shape == (batch_size, topk, num_classes) + + def test_convert_to_deploy(self, deim_transformer_minimal): + """Test deployment conversion.""" + deim_transformer_minimal.convert_to_deploy() + + # After conversion, some heads should be Identity + eval_idx = deim_transformer_minimal.eval_idx + for i, head in enumerate(deim_transformer_minimal.dec_score_head): + if i < eval_idx: + assert isinstance(head, torch.nn.Identity) + + def test_input_proj_identity(self): + """Test input projection with matching dimensions.""" + transformer = DEIMTransformerModule( + num_classes=5, + hidden_dim=256, + feat_channels=[256, 256, 256], # Same as hidden_dim + num_layers=1, + num_denoising=0, + ) + + # When feat_channels == hidden_dim, should use Identity + for proj in 
transformer.input_proj: + assert isinstance(proj, torch.nn.Identity) + + def test_input_proj_conv(self): + """Test input projection with different dimensions.""" + transformer = DEIMTransformerModule( + num_classes=5, + hidden_dim=256, + feat_channels=[128, 128, 128], # Different from hidden_dim + num_layers=1, + num_denoising=0, + ) + + # When feat_channels != hidden_dim, should use Conv projection + for proj in transformer.input_proj: + assert isinstance(proj, torch.nn.Sequential) + + def test_validation_errors(self): + """Test that validation errors are raised correctly.""" + # feat_channels > num_levels should raise error + with pytest.raises(ValueError, match="feat_channels.*must be <= num_levels"): + DEIMTransformerModule( + num_classes=5, + feat_channels=[256, 256, 256, 256], + num_levels=3, + ) + + # feat_strides length mismatch should raise error + with pytest.raises(ValueError, match="feat_strides.*must match feat_channels"): + DEIMTransformerModule( + num_classes=5, + feat_channels=[256, 256, 256], + feat_strides=[8, 16], # Mismatch + ) + + +class TestDEIMTransformerFactory: + """Test class for DEIMTransformer factory.""" + + @pytest.mark.parametrize( + "model_name", + [ + "deimv2_x", + "deimv2_l", + "deimv2_m", + "deimv2_s", + ], + ) + def test_factory_creates_correct_model(self, model_name): + """Test that factory creates correct model variants.""" + transformer = DEIMTransformer( + model_name=model_name, + num_classes=80, + eval_spatial_size=(640, 640), + ) + + assert isinstance(transformer, DEIMTransformerModule) + assert transformer.num_classes == 80 + assert transformer.eval_spatial_size == (640, 640) + + def test_factory_config_deimv2_x(self): + """Test DEIMv2-X configuration.""" + transformer = DEIMTransformer( + model_name="deimv2_x", + num_classes=10, + ) + + assert transformer.hidden_dim == 256 + assert transformer.num_layers == 6 + + def test_factory_config_deimv2_l(self): + """Test DEIMv2-L configuration.""" + transformer = DEIMTransformer( + model_name="deimv2_l", + num_classes=10, + ) + + assert transformer.hidden_dim == 224 + assert transformer.num_layers == 4 + + def test_factory_config_deimv2_m(self): + """Test DEIMv2-M configuration.""" + transformer = DEIMTransformer( + model_name="deimv2_m", + num_classes=10, + ) + + assert transformer.hidden_dim == 256 + assert transformer.num_layers == 4 + + def test_factory_config_deimv2_s(self): + """Test DEIMv2-S configuration.""" + transformer = DEIMTransformer( + model_name="deimv2_s", + num_classes=10, + ) + + assert transformer.hidden_dim == 192 + assert transformer.num_layers == 4 + + def test_factory_invalid_model_name(self): + """Test that invalid model name raises error.""" + with pytest.raises(KeyError): + DEIMTransformer( + model_name="invalid_model", + num_classes=10, + ) + + +class TestTransformerDecoder: + """Test class for TransformerDecoder.""" + + @pytest.fixture + def decoder(self): + """Create a basic transformer decoder.""" + decoder_layer = TransformerDecoderLayer( + d_model=128, + n_head=4, + dim_feedforward=256, + n_levels=3, + n_points=4, + ) + decoder_layer_wide = TransformerDecoderLayer( + d_model=128, + n_head=4, + dim_feedforward=256, + n_levels=3, + n_points=4, + layer_scale=2.0, + ) + up = torch.nn.Parameter(torch.tensor([0.5]), requires_grad=False) + reg_scale = torch.nn.Parameter(torch.tensor([4.0]), requires_grad=False) + + return TransformerDecoder( + hidden_dim=128, + decoder_layer=decoder_layer, + decoder_layer_wide=decoder_layer_wide, + num_layers=2, + num_head=4, + reg_max=16, + 
reg_scale=reg_scale, + up=up, + eval_idx=-1, + layer_scale=2, + ) + + def test_decoder_init(self, decoder): + """Test decoder initialization.""" + assert isinstance(decoder, TransformerDecoder) + assert decoder.hidden_dim == 128 + assert decoder.num_layers == 2 + assert len(decoder.layers) == 2 + assert len(decoder.lqe_layers) == 2 + + def test_decoder_convert_to_deploy(self, decoder): + """Test decoder deployment conversion.""" + original_num_layers = len(decoder.layers) + decoder.convert_to_deploy() + + # After conversion, only layers up to eval_idx should remain + assert len(decoder.layers) <= original_num_layers + assert hasattr(decoder, "project") + + def test_value_op(self, decoder): + """Test value operation for attention.""" + batch_size = 2 + seq_len = 100 + hidden_dim = 128 + spatial_shapes = [[10, 10]] + + memory = torch.randn(batch_size, seq_len, hidden_dim) + + values = decoder.value_op( + memory=memory, + value_proj=None, + value_scale=None, + memory_mask=None, + memory_spatial_shapes=spatial_shapes, + ) + + assert isinstance(values, tuple) + assert len(values) == len(spatial_shapes) diff --git a/library/tests/unit/backend/native/models/detection/necks/test_hybrid_encoder.py b/library/tests/unit/backend/native/models/detection/necks/test_hybrid_encoder.py index a26157816cd..a621e82597a 100644 --- a/library/tests/unit/backend/native/models/detection/necks/test_hybrid_encoder.py +++ b/library/tests/unit/backend/native/models/detection/necks/test_hybrid_encoder.py @@ -3,31 +3,331 @@ # """Test of HybridEncoder.""" +import pytest import torch -from otx.backend.native.models.detection.necks.hybrid_encoder import HybridEncoderModule +from otx.backend.native.models.detection.necks.hybrid_encoder import HybridEncoder, HybridEncoderModule -def test_hybrid_encoder_forward(): - hidden_dim = 256 - feat_strides = [8, 16, 32] - in_channels = [128, 256, 512] - encoder = HybridEncoderModule(in_channels=in_channels, hidden_dim=hidden_dim, feat_strides=feat_strides) +class TestHybridEncoderModule: + """Test class for HybridEncoderModule.""" - # Create dummy input - batch_size = 2 - input_sizes = [(128, 64, 64), (256, 32, 32), (512, 16, 16)] - dummy_input = [ - torch.randn(batch_size, *input_sizes[0]), - torch.randn(batch_size, *input_sizes[1]), - torch.randn(batch_size, *input_sizes[2]), - ] + @pytest.fixture + def encoder_default(self): + """Create encoder with default settings.""" + return HybridEncoderModule( + in_channels=[512, 1024, 2048], + hidden_dim=256, + feat_strides=[8, 16, 32], + ) - # Forward pass - outputs = encoder(dummy_input) + @pytest.fixture + def encoder_small(self): + """Create a smaller encoder for faster testing.""" + return HybridEncoderModule( + in_channels=[128, 256, 512], + hidden_dim=128, + feat_strides=[8, 16, 32], + dim_feedforward=512, + num_encoder_layers=1, + ) - # Check output shapes - assert len(outputs) == 3 - for i, output in enumerate(outputs): - expected_shape = (batch_size, hidden_dim, input_sizes[i][1], input_sizes[i][2]) - assert output.shape == expected_shape + @pytest.fixture + def encoder_with_eval_size(self): + """Create encoder with evaluation spatial size.""" + return HybridEncoderModule( + in_channels=[128, 256, 512], + hidden_dim=128, + feat_strides=[8, 16, 32], + eval_spatial_size=(640, 640), + ) + + def test_init_default(self, encoder_default): + """Test default initialization.""" + assert isinstance(encoder_default, HybridEncoderModule) + assert encoder_default.hidden_dim == 256 + assert encoder_default.in_channels == [512, 1024, 2048] + 
assert encoder_default.feat_strides == [8, 16, 32]
+        assert encoder_default.use_encoder_idx == [2]
+        assert encoder_default.num_encoder_layers == 1
+
+    def test_init_components(self, encoder_small):
+        """Test that all components are initialized."""
+        # Input projection
+        assert hasattr(encoder_small, "input_proj")
+        assert len(encoder_small.input_proj) == 3
+
+        # Encoder
+        assert hasattr(encoder_small, "encoder")
+        assert len(encoder_small.encoder) == 1  # Only index 2 by default
+
+        # FPN components
+        assert hasattr(encoder_small, "lateral_convs")
+        assert hasattr(encoder_small, "fpn_blocks")
+        assert len(encoder_small.lateral_convs) == 2
+        assert len(encoder_small.fpn_blocks) == 2
+
+        # PAN components
+        assert hasattr(encoder_small, "downsample_convs")
+        assert hasattr(encoder_small, "pan_blocks")
+        assert len(encoder_small.downsample_convs) == 2
+        assert len(encoder_small.pan_blocks) == 2
+
+    def test_out_channels(self, encoder_small):
+        """Test output channels property."""
+        assert encoder_small.out_channels == [128, 128, 128]
+        assert encoder_small.out_strides == [8, 16, 32]
+
+    def test_forward(self, encoder_small):
+        """Test forward pass."""
+        batch_size = 2
+        feats = [
+            torch.randn(batch_size, 128, 80, 80),
+            torch.randn(batch_size, 256, 40, 40),
+            torch.randn(batch_size, 512, 20, 20),
+        ]
+
+        outputs = encoder_small(feats)
+
+        assert len(outputs) == 3
+        assert outputs[0].shape == (batch_size, 128, 80, 80)
+        assert outputs[1].shape == (batch_size, 128, 40, 40)
+        assert outputs[2].shape == (batch_size, 128, 20, 20)
+
+    def test_forward_different_batch_sizes(self, encoder_small):
+        """Test forward with different batch sizes."""
+        for batch_size in [1, 2, 4]:
+            feats = [
+                torch.randn(batch_size, 128, 80, 80),
+                torch.randn(batch_size, 256, 40, 40),
+                torch.randn(batch_size, 512, 20, 20),
+            ]
+
+            outputs = encoder_small(feats)
+
+            assert len(outputs) == 3
+            assert outputs[0].shape[0] == batch_size
+
+    def test_forward_mismatched_features_raises(self, encoder_small):
+        """Test that mismatched feature count raises error."""
+        feats = [
+            torch.randn(2, 128, 80, 80),
+            torch.randn(2, 256, 40, 40),
+            # Missing third feature
+        ]
+
+        with pytest.raises(ValueError, match="Input feature size"):
+            encoder_small(feats)
+
+    def test_forward_training_mode(self, encoder_small):
+        """Test forward in training mode."""
+        encoder_small.train()
+        feats = [
+            torch.randn(2, 128, 80, 80),
+            torch.randn(2, 256, 40, 40),
+            torch.randn(2, 512, 20, 20),
+        ]
+
+        outputs = encoder_small(feats)
+        assert len(outputs) == 3
+
+    def test_forward_eval_mode(self, encoder_small):
+        """Test forward in eval mode."""
+        encoder_small.eval()
+        feats = [
+            torch.randn(2, 128, 80, 80),
+            torch.randn(2, 256, 40, 40),
+            torch.randn(2, 512, 20, 20),
+        ]
+
+        with torch.no_grad():
+            outputs = encoder_small(feats)
+
+        assert len(outputs) == 3
+
+    def test_init_weights_with_eval_size(self, encoder_with_eval_size):
+        """Test weight initialization with eval spatial size."""
+        encoder_with_eval_size.init_weights()
+
+        # Check that position embedding is created for encoder index
+        assert hasattr(encoder_with_eval_size, "pos_embed2")
+
+    def test_build_2d_sincos_position_embedding(self):
+        """Test 2D sin-cos position embedding generation."""
+        w, h = 20, 20
+        embed_dim = 256
+        temperature = 10000.0
+
+        pos_embed = HybridEncoderModule.build_2d_sincos_position_embedding(w, h, embed_dim, temperature)
+
+        assert pos_embed.shape == (1, w * h, embed_dim)
+        assert not torch.isnan(pos_embed).any()
+
+    def test_build_2d_sincos_position_embedding_invalid_dim(self):
+        """Test that invalid embed_dim raises error."""
+        with pytest.raises(ValueError, match="divisible by 4"):
+            HybridEncoderModule.build_2d_sincos_position_embedding(10, 10, 255, 10000.0)
+
+    def test_no_encoder_layers(self):
+        """Test encoder with no encoder layers."""
+        encoder = HybridEncoderModule(
+            in_channels=[128, 256, 512],
+            hidden_dim=128,
+            num_encoder_layers=0,
+        )
+
+        feats = [
+            torch.randn(2, 128, 80, 80),
+            torch.randn(2, 256, 40, 40),
+            torch.randn(2, 512, 20, 20),
+        ]
+
+        outputs = encoder(feats)
+        assert len(outputs) == 3
+
+    def test_multiple_encoder_indices(self):
+        """Test encoder with multiple encoder indices."""
+        encoder = HybridEncoderModule(
+            in_channels=[128, 256, 512],
+            hidden_dim=128,
+            use_encoder_idx=[1, 2],
+            num_encoder_layers=1,
+        )
+
+        assert len(encoder.encoder) == 2
+
+        feats = [
+            torch.randn(2, 128, 80, 80),
+            torch.randn(2, 256, 40, 40),
+            torch.randn(2, 512, 20, 20),
+        ]
+
+        outputs = encoder(feats)
+        assert len(outputs) == 3
+
+    def test_custom_activation(self):
+        """Test encoder with custom activation."""
+        encoder = HybridEncoderModule(
+            in_channels=[128, 256, 512],
+            hidden_dim=128,
+            activation=torch.nn.ReLU,
+            enc_activation=torch.nn.ReLU,
+        )
+
+        feats = [
+            torch.randn(2, 128, 80, 80),
+            torch.randn(2, 256, 40, 40),
+            torch.randn(2, 512, 20, 20),
+        ]
+
+        outputs = encoder(feats)
+        assert len(outputs) == 3
+
+    def test_depth_mult(self):
+        """Test encoder with depth multiplier."""
+        encoder = HybridEncoderModule(
+            in_channels=[128, 256, 512],
+            hidden_dim=128,
+            depth_mult=0.5,
+        )
+
+        feats = [
+            torch.randn(2, 128, 80, 80),
+            torch.randn(2, 256, 40, 40),
+            torch.randn(2, 512, 20, 20),
+        ]
+
+        outputs = encoder(feats)
+        assert len(outputs) == 3
+
+    def test_expansion(self):
+        """Test encoder with expansion factor."""
+        encoder = HybridEncoderModule(
+            in_channels=[128, 256, 512],
+            hidden_dim=128,
+            expansion=0.5,
+        )
+
+        feats = [
+            torch.randn(2, 128, 80, 80),
+            torch.randn(2, 256, 40, 40),
+            torch.randn(2, 512, 20, 20),
+        ]
+
+        outputs = encoder(feats)
+        assert len(outputs) == 3
+
+
+class TestHybridEncoderFactory:
+    """Test class for HybridEncoder factory."""
+
+    @pytest.mark.parametrize(
+        "model_name",
+        [
+            "rtdetr_18",
+            "rtdetr_50",
+            "rtdetr_101",
+        ],
+    )
+    def test_factory_creates_module(self, model_name):
+        """Test that factory creates HybridEncoderModule."""
+        encoder = HybridEncoder(model_name)
+        assert isinstance(encoder, HybridEncoderModule)
+
+    def test_factory_rtdetr_18_config(self):
+        """Test RTDETR-18 configuration."""
+        encoder = HybridEncoder("rtdetr_18")
+        assert encoder.in_channels == [128, 256, 512]
+
+    def test_factory_rtdetr_50_config(self):
+        """Test RTDETR-50 configuration (defaults)."""
+        encoder = HybridEncoder("rtdetr_50")
+        # Uses default values
+        assert encoder.hidden_dim == 256
+
+    def test_factory_rtdetr_101_config(self):
+        """Test RTDETR-101 configuration."""
+        encoder = HybridEncoder("rtdetr_101")
+        assert encoder.hidden_dim == 384
+        assert encoder.in_channels == [512, 1024, 2048]
+
+    def test_factory_with_eval_size(self):
+        """Test factory with evaluation spatial size."""
+        encoder = HybridEncoder("rtdetr_18", eval_spatial_size=(640, 640))
+        assert encoder.eval_spatial_size == (640, 640)
+
+    def test_factory_invalid_model_raises(self):
+        """Test that invalid model name raises error."""
+        with pytest.raises(KeyError, match="not supported"):
+            HybridEncoder("invalid_model")
+
+
+class TestGradientFlow:
+    """Test gradient flow through HybridEncoder."""
+
+    @pytest.fixture
+    def encoder(self):
+        """Create encoder for gradient testing."""
+        return HybridEncoderModule(
+            in_channels=[128, 256, 512],
+            hidden_dim=128,
+            num_encoder_layers=1,
+        )
+
+    def test_gradient_flow(self, encoder):
+        """Test that gradients flow through the encoder."""
+        encoder.train()
+        feats = [
+            torch.randn(2, 128, 40, 40, requires_grad=True),
+            torch.randn(2, 256, 20, 20, requires_grad=True),
+            torch.randn(2, 512, 10, 10, requires_grad=True),
+        ]
+
+        outputs = encoder(feats)
+        loss = sum(out.sum() for out in outputs)
+        loss.backward()
+
+        # Check gradients exist for all inputs
+        for feat in feats:
+            assert feat.grad is not None
+            assert not torch.all(feat.grad == 0)
diff --git a/library/tests/unit/backend/native/models/detection/test_deimv2.py b/library/tests/unit/backend/native/models/detection/test_deimv2.py
new file mode 100644
index 00000000000..37801701fd5
--- /dev/null
+++ b/library/tests/unit/backend/native/models/detection/test_deimv2.py
@@ -0,0 +1,443 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for DEIMV2 detection model."""
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from otx.backend.native.models.base import DataInputParams
+from otx.backend.native.models.detection.deimv2 import DEIMV2
+from otx.data.entity.torch import OTXPredBatch
+
+
+class TestDEIMV2:
+    """Test class for DEIMV2 detection model."""
+
+    @pytest.mark.parametrize(
+        "model_name",
+        [
+            "deimv2_s",
+            "deimv2_m",
+            "deimv2_l",
+            "deimv2_x",
+        ],
+    )
+    def test_init(self, model_name: str) -> None:
+        """Test DEIMV2 model initialization."""
+        model = DEIMV2(
+            model_name=model_name,
+            label_info=3,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+        )
+        assert model.model_name == model_name
+        assert model.num_classes == 3
+        assert model.data_input_params.input_size == (640, 640)
+        assert model.input_size_multiplier == 32
+        assert model_name in model._pretrained_weights
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    def test_create_model(self, mock_load_checkpoint: MagicMock) -> None:
+        """Test DEIMV2 model creation."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=10,
+        )
+        created_model = model._create_model()
+        assert created_model is not None
+        assert isinstance(created_model, torch.nn.Module)
+
+        # Check if the model has the expected components
+        assert hasattr(created_model, "backbone")
+        assert hasattr(created_model, "encoder")
+        assert hasattr(created_model, "decoder")
+        assert hasattr(created_model, "criterion")
+        assert hasattr(created_model, "num_classes")
+        assert created_model.num_classes == 10
+
+        # Verify load_checkpoint was called (may be called multiple times for backbone and model)
+        assert mock_load_checkpoint.call_count >= 1
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    def test_backbone_lr_mapping(self, mock_load_checkpoint: MagicMock) -> None:
+        """Test that backbone learning rate mapping works correctly."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=5,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+        )
+        created_model = model._create_model()
+
+        # Check optimizer configuration exists
+        assert hasattr(created_model, "optimizer_configuration")
+        assert len(created_model.optimizer_configuration) == 3
+
+    @pytest.mark.parametrize(
+        ("model_name", "expected_lr"),
+        [
+            ("deimv2_x", 0.00001),
+            ("deimv2_l", 0.0000125),
+            ("deimv2_m", 0.000025),
+            ("deimv2_s", 0.000025),
+        ],
+    )
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    def test_backbone_lr_values(self, mock_load_checkpoint: MagicMock, model_name: str, expected_lr: float) -> None:
+        """Test that backbone learning rates are correctly set for each model variant."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name=model_name,
+            label_info=5,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+        )
+        created_model = model._create_model()
+
+        # Check that the first optimizer config has the expected backbone lr
+        assert created_model.optimizer_configuration[0]["lr"] == expected_lr
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    def test_loss_computation(self, mock_load_checkpoint: MagicMock, fxt_data_module) -> None:
+        """Test DEIMV2 loss computation in training mode."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=10,
+        )
+
+        # Get data batch
+        data = next(iter(fxt_data_module.train_dataloader()))
+        data.images = torch.randn(2, 3, 640, 640)
+
+        # Set model to training mode
+        model.train()
+
+        # Forward pass should return loss dictionary
+        output = model(data)
+
+        # Check that output contains expected DEIM loss components
+        assert isinstance(output, dict)
+        expected_losses = ["loss_vfl", "loss_bbox", "loss_giou", "loss_fgl", "loss_mal"]
+
+        for loss_name in expected_losses:
+            assert loss_name in output
+            assert isinstance(output[loss_name], torch.Tensor)
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    @pytest.mark.parametrize(
+        "model_name",
+        [
+            "deimv2_s",
+            "deimv2_m",
+            "deimv2_l",
+            "deimv2_x",
+        ],
+    )
+    def test_predict(self, mock_load_checkpoint: MagicMock, model_name: str, fxt_data_module) -> None:
+        """Test DEIMV2 prediction in evaluation mode."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name=model_name,
+            label_info=3,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+        )
+
+        # Get data batch
+        data = next(iter(fxt_data_module.train_dataloader()))
+        data.images = torch.randn(2, 3, 640, 640)
+
+        # Set model to evaluation mode
+        model.eval()
+
+        # Forward pass should return predictions
+        output = model(data)
+
+        # Check that output is OTXPredBatch
+        assert isinstance(output, OTXPredBatch)
+        assert output.batch_size == 2
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    @pytest.mark.parametrize(
+        "model_name",
+        [
+            "deimv2_s",
+            "deimv2_m",
+            "deimv2_l",
+            "deimv2_x",
+        ],
+    )
+    def test_export(self, mock_load_checkpoint: MagicMock, model_name: str) -> None:
+        """Test DEIMV2 export functionality."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name=model_name,
+            label_info=3,
+        )
+
+        # Set model to evaluation mode
+        model.eval()
+
+        # Test export forward pass
+        output = model.forward_for_tracing(torch.randn(1, 3, 640, 640))
+        assert len(output) == 3  # Should return boxes, scores, labels
+
+        # Test with explain mode
+        model.explain_mode = True
+        output = model.forward_for_tracing(torch.randn(1, 3, 640, 640))
+        assert len(output) == 5  # Should return boxes, scores, labels, saliency_map, feature_vector
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    def test_multi_scale_training(self, mock_load_checkpoint: MagicMock) -> None:
+        """Test DEIMV2 with multi-scale training enabled."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=3,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+            multi_scale=True,
+        )
+
+        # Multi-scale should be created in the model
+        created_model = model._create_model()
+        assert isinstance(created_model.multi_scale, list)
+        assert len(created_model.multi_scale) > 0
+
+    def test_torch_compile_integration(self) -> None:
+        """Test DEIMV2 with torch compile enabled."""
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=3,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+            torch_compile=True,
+        )
+
+        # Check that torch compile is enabled
+        assert model.torch_compile is True
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    def test_gradient_checkpointing(self, mock_load_checkpoint: MagicMock) -> None:
+        """Test DEIMV2 with gradient checkpointing enabled."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=3,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+            gradient_checkpointing=True,
+        )
+
+        # Check that gradient_checkpointing is enabled
+        assert model.gradient_checkpointing is True
+
+        # Create model and verify backbone has gradient checkpointing
+        created_model = model._create_model()
+        assert hasattr(created_model.backbone, "dinov3")
+        assert created_model.backbone.dinov3.gradient_checkpointing is True
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    def test_weight_dict_configuration(self, mock_load_checkpoint: MagicMock) -> None:
+        """Test that the weight dictionary is properly configured."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=5,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+        )
+
+        created_model = model._create_model()
+        criterion = created_model.criterion
+
+        # Check that weight dict contains expected keys
+        expected_weights = ["loss_vfl", "loss_bbox", "loss_giou", "loss_fgl", "loss_ddf", "loss_mal"]
+        for weight_key in expected_weights:
+            assert weight_key in criterion.weight_dict
+
+        # Check specific weight values
+        assert criterion.weight_dict["loss_vfl"] == 1
+        assert criterion.weight_dict["loss_bbox"] == 5
+        assert criterion.weight_dict["loss_giou"] == 2
+        assert criterion.weight_dict["loss_fgl"] == 0.15
+        assert criterion.weight_dict["loss_ddf"] == 1.5
+        assert criterion.weight_dict["loss_mal"] == 1.0
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    def test_criterion_parameters(self, mock_load_checkpoint: MagicMock) -> None:
+        """Test that the criterion is configured with correct parameters."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=10,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+        )
+
+        created_model = model._create_model()
+        criterion = created_model.criterion
+
+        # Check criterion parameters
+        assert criterion.alpha == 0.75
+        assert criterion.gamma == 1.5
+        assert criterion.reg_max == 32
+        assert criterion.num_classes == 10
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    def test_dummy_input_generation(self, mock_load_checkpoint: MagicMock) -> None:
+        """Test dummy input generation for different batch sizes."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=3,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+        )
+
+        # Test with different batch sizes
+        for batch_size in [1, 2, 4]:
+            dummy_input = model.get_dummy_input(batch_size)
+            assert len(dummy_input.images) == batch_size
+            assert dummy_input.images[0].shape == (3, 640, 640)
+
+    def test_model_properties(self) -> None:
+        """Test various model properties."""
+        model = DEIMV2(
+            model_name="deimv2_m",
+            label_info=20,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+        )
+
+        # Test input size multiplier
+        assert model.input_size_multiplier == 32
+
+        # Test pretrained weights availability
+        assert model.model_name in model._pretrained_weights
+        assert isinstance(model._pretrained_weights[model.model_name], str)
+
+    def test_default_preprocessing_params(self) -> None:
+        """Test default preprocessing parameters."""
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=3,
+        )
+
+        default_params = model._default_preprocessing_params
+        assert isinstance(default_params, DataInputParams)
+        assert default_params.input_size == (640, 640)
+        assert default_params.mean == (123.675, 116.280, 103.530)
+        assert default_params.std == (58.395, 57.120, 57.375)
+
+    def test_inheritance_from_deim_dfine(self) -> None:
+        """Test that DEIMV2 properly inherits from DEIMDFine."""
+        from otx.backend.native.models.detection.deim import DEIMDFine
+
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=3,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+        )
+
+        # Check inheritance
+        assert isinstance(model, DEIMDFine)
+
+        # Check that it has inherited methods
+        assert hasattr(model, "forward")
+        assert hasattr(model, "training_step")
+        assert hasattr(model, "validation_step")
+        assert hasattr(model, "predict_step")
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    def test_dinov3_backbone(self, mock_load_checkpoint: MagicMock) -> None:
+        """Test that DEIMV2 uses DINOv3STA backbone."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=5,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+        )
+
+        created_model = model._create_model()
+
+        # Check that backbone is DINOv3STAsModule
+        from otx.backend.native.models.detection.backbones.dinov3sta import DINOv3STAsModule
+
+        assert isinstance(created_model.backbone, DINOv3STAsModule)
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    def test_hybrid_encoder(self, mock_load_checkpoint: MagicMock) -> None:
+        """Test that DEIMV2 uses HybridEncoder."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=5,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+        )
+
+        created_model = model._create_model()
+
+        # Check that encoder is HybridEncoderModule
+        from otx.backend.native.models.detection.necks.dfine_hybrid_encoder import HybridEncoderModule
+
+        assert isinstance(created_model.encoder, HybridEncoderModule)
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    def test_deim_transformer_decoder(self, mock_load_checkpoint: MagicMock) -> None:
+        """Test that DEIMV2 uses DEIMTransformer decoder."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=5,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+        )
+
+        created_model = model._create_model()
+
+        # Check that decoder is DEIMTransformerModule
+        from otx.backend.native.models.detection.heads.deim_decoder import DEIMTransformerModule
+
+        assert isinstance(created_model.decoder, DEIMTransformerModule)
+
+    @patch("otx.backend.native.models.detection.deimv2.load_checkpoint")
+    def test_optimizer_configuration_structure(self, mock_load_checkpoint: MagicMock) -> None:
+        """Test optimizer configuration has proper structure."""
+        mock_load_checkpoint.return_value = None
+
+        model = DEIMV2(
+            model_name="deimv2_s",
+            label_info=5,
+            data_input_params=DataInputParams((640, 640), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
+        )
+
+        created_model = model._create_model()
+        opt_config = created_model.optimizer_configuration
+
+        # Should have 3 configurations
+        assert len(opt_config) == 3
+
+        # First config: dinov3 params excluding norm/bn/bias
+        assert "params" in opt_config[0]
+        assert "lr" in opt_config[0]
+        assert "dinov3" in opt_config[0]["params"]
+
+        # Second config: dinov3 norm/bn/bias with weight_decay=0
+        assert "params" in opt_config[1]
+        assert "lr" in opt_config[1]
+        assert opt_config[1].get("weight_decay") == 0.0
+
+        # Third config: sta/encoder/decoder norm/bn/bias with weight_decay=0
+        assert "params" in opt_config[2]
+        assert opt_config[2].get("weight_decay") == 0.0