|
| 1 | +# Copyright (C) 2025 Intel Corporation |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +"""Common components shared between D-FINE and DEIM transformer decoders.""" |
| 5 | + |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +from functools import partial |
| 9 | +from typing import Callable |
| 10 | + |
| 11 | +import torch |
| 12 | +import torch.nn.functional as f |
| 13 | +from torch import Tensor, nn |
| 14 | +from torch.nn import init |
| 15 | +import torchvision |
| 16 | + |
| 17 | +from otx.backend.native.models.common.layers.transformer_layers import MLP |
| 18 | +from otx.backend.native.models.utils.weight_init import bias_init_with_prob |
| 19 | +from otx.backend.native.models.common.utils.utils import inverse_sigmoid |
| 20 | + |
| 21 | + |
| 22 | +class RMSNorm(nn.Module): |
| 23 | + """Root Mean Square Layer Normalization. |
| 24 | +
|
| 25 | + Args: |
| 26 | + dim (int): The number of features in the input. |
| 27 | + eps (float, optional): A value added for numerical stability. Defaults to 1e-6. |
| 28 | + """ |
| 29 | + |
| 30 | + def __init__(self, dim: int, eps: float = 1e-6) -> None: |
| 31 | + super().__init__() |
| 32 | + self.dim = dim |
| 33 | + self.eps = eps |
| 34 | + self.scale = nn.Parameter(torch.ones(dim)) |
| 35 | + |
| 36 | + def _norm(self, x: Tensor) -> Tensor: |
| 37 | + """Compute RMS normalization.""" |
| 38 | + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) |
| 39 | + |
| 40 | + def forward(self, x: Tensor) -> Tensor: |
| 41 | + """Forward pass of RMSNorm. |
| 42 | +
|
| 43 | + Args: |
| 44 | + x (Tensor): Input tensor. |
| 45 | +
|
| 46 | + Returns: |
| 47 | + Tensor: Normalized and scaled tensor. |
| 48 | + """ |
| 49 | + output = self._norm(x.float()).type_as(x) |
| 50 | + return output * self.scale |
| 51 | + |
| 52 | + def extra_repr(self) -> str: |
| 53 | + """Extra representation string.""" |
| 54 | + return f"dim={self.dim}, eps={self.eps}" |
| 55 | + |
| 56 | + def reset_parameters(self) -> None: |
| 57 | + nn.init.constant_(self.scale, 1) |
| 58 | + |
| 59 | + |
| 60 | +class Gate(nn.Module): |
| 61 | + """Target Gating Layer with learnable fusion weights. |
| 62 | +
|
| 63 | + This module combines two input tensors using learnable gating weights, |
| 64 | + allowing the model to dynamically control information flow. |
| 65 | +
|
| 66 | + Args: |
| 67 | + d_model (int): The number of expected features in the input. |
| 68 | + use_rmsnorm (bool, optional): Whether to use RMSNorm instead of LayerNorm. Defaults to False. |
| 69 | + """ |
| 70 | + |
| 71 | + def __init__(self, d_model: int, use_rmsnorm: bool = False) -> None: |
| 72 | + super().__init__() |
| 73 | + self.gate = nn.Linear(2 * d_model, 2 * d_model) |
| 74 | + bias = bias_init_with_prob(0.5) |
| 75 | + init.constant_(self.gate.bias, bias) |
| 76 | + init.constant_(self.gate.weight, 0) |
| 77 | + self.norm = RMSNorm(d_model) if use_rmsnorm else nn.LayerNorm(d_model) |
| 78 | + |
| 79 | + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: |
| 80 | + """Forward pass of gating mechanism. |
| 81 | +
|
| 82 | + Args: |
| 83 | + x1 (Tensor): First input tensor. |
| 84 | + x2 (Tensor): Second input tensor. |
| 85 | +
|
| 86 | + Returns: |
| 87 | + Tensor: Gated and normalized output tensor. |
| 88 | + """ |
| 89 | + gate_input = torch.cat([x1, x2], dim=-1) |
| 90 | + gates = torch.sigmoid(self.gate(gate_input)) |
| 91 | + gate1, gate2 = gates.chunk(2, dim=-1) |
| 92 | + return self.norm(gate1 * x1 + gate2 * x2) |
| 93 | + |
| 94 | + |
| 95 | +class Integral(nn.Module): |
| 96 | + """Convert distribution predictions to continuous bounding box coordinates. |
| 97 | +
|
| 98 | + This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`, |
| 99 | + where Pr(n) is the softmax probability vector representing the discrete |
| 100 | + distribution, and W(n) is the non-uniform Weighting Function. |
| 101 | +
|
| 102 | + Args: |
| 103 | + reg_max (int, optional): Maximum number of discrete bins. Defaults to 32. |
| 104 | + """ |
| 105 | + |
| 106 | + def __init__(self, reg_max: int = 32) -> None: |
| 107 | + super().__init__() |
| 108 | + self.reg_max = reg_max |
| 109 | + |
| 110 | + def forward(self, x: Tensor, box_distance_weight: Tensor) -> Tensor: |
| 111 | + """Convert distribution to coordinates. |
| 112 | +
|
| 113 | + Args: |
| 114 | + x (Tensor): Distribution predictions of shape (..., 4*(reg_max+1)). |
| 115 | + box_distance_weight (Tensor): Weighting function for integration. |
| 116 | +
|
| 117 | + Returns: |
| 118 | + Tensor: Continuous bounding box coordinates of shape (..., 4). |
| 119 | + """ |
| 120 | + shape = x.shape |
| 121 | + x = f.softmax(x.reshape(-1, self.reg_max + 1), dim=1) |
| 122 | + x = f.linear(x, box_distance_weight).reshape(-1, 4) |
| 123 | + return x.reshape([*list(shape[:-1]), -1]) |
| 124 | + |
| 125 | + |
| 126 | +class LQE(nn.Module): |
| 127 | + """Localization Quality Estimation module. |
| 128 | +
|
| 129 | + Estimates the quality of predicted bounding boxes by analyzing the |
| 130 | + distribution statistics of corner predictions. |
| 131 | +
|
| 132 | + Args: |
| 133 | + k (int): Number of top-k edge points to consider. |
| 134 | + hidden_dim (int): Hidden dimension for the MLP. |
| 135 | + num_layers (int): Number of MLP layers. |
| 136 | + reg_max (int): Maximum number of discrete bins for bbox regression. |
| 137 | + act (Callable[..., nn.Module], optional): Activation function. Defaults to ReLU. |
| 138 | + """ |
| 139 | + |
| 140 | + def __init__( |
| 141 | + self, |
| 142 | + k: int, |
| 143 | + hidden_dim: int, |
| 144 | + num_layers: int, |
| 145 | + reg_max: int, |
| 146 | + activation: Callable[..., nn.Module] = partial(nn.ReLU, inplace=True), |
| 147 | + ) -> None: |
| 148 | + super().__init__() |
| 149 | + self.k = k |
| 150 | + self.reg_max = reg_max |
| 151 | + self.reg_conf = MLP( |
| 152 | + input_dim=4 * (k + 1), |
| 153 | + hidden_dim=hidden_dim, |
| 154 | + output_dim=1, |
| 155 | + num_layers=num_layers, |
| 156 | + activation=activation, |
| 157 | + ) |
| 158 | + init.constant_(self.reg_conf.layers[-1].bias, 0) |
| 159 | + init.constant_(self.reg_conf.layers[-1].weight, 0) |
| 160 | + |
| 161 | + def forward(self, scores: Tensor, pred_corners: Tensor) -> Tensor: |
| 162 | + """Estimate localization quality and adjust scores. |
| 163 | +
|
| 164 | + Args: |
| 165 | + scores (Tensor): Predicted classification scores of shape (B, N, C). |
| 166 | + pred_corners (Tensor): Predicted bounding box corners of shape (B, N, 4*(reg_max+1)). |
| 167 | +
|
| 168 | + Returns: |
| 169 | + Tensor: Quality-adjusted scores of shape (B, N, C). |
| 170 | + """ |
| 171 | + b, num_pred, _ = pred_corners.size() |
| 172 | + prob = f.softmax(pred_corners.reshape(b, num_pred, 4, self.reg_max + 1), dim=-1) |
| 173 | + prob_topk, _ = prob.topk(self.k, dim=-1) |
| 174 | + stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1) |
| 175 | + quality_score = self.reg_conf(stat.reshape(b, num_pred, -1)) |
| 176 | + return scores + quality_score |
| 177 | + |
| 178 | + |
| 179 | +class SwiGLUFFN(nn.Module): |
| 180 | + def __init__( |
| 181 | + self, |
| 182 | + in_features: int, |
| 183 | + hidden_features: int, |
| 184 | + out_features: int, |
| 185 | + bias: bool = True, |
| 186 | + ) -> None: |
| 187 | + """ |
| 188 | + Initializes SwiGLUFFN module. |
| 189 | +
|
| 190 | + Args: |
| 191 | + in_features (int): Number of input features. |
| 192 | + hidden_features (int): Number of hidden features. |
| 193 | + out_features (int): Number of output features. |
| 194 | + bias (bool, optional): Whether to use bias in linear layers. Defaults to True. |
| 195 | + """ |
| 196 | + super().__init__() |
| 197 | + out_features = out_features or in_features |
| 198 | + hidden_features = hidden_features or in_features |
| 199 | + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) |
| 200 | + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) |
| 201 | + self._reset_parameters() |
| 202 | + |
| 203 | + def _reset_parameters(self): |
| 204 | + init.xavier_uniform_(self.w12.weight) |
| 205 | + init.constant_(self.w12.bias, 0) |
| 206 | + init.xavier_uniform_(self.w3.weight) |
| 207 | + init.constant_(self.w3.bias, 0) |
| 208 | + |
| 209 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 210 | + """ |
| 211 | + Forward pass of the SwiGLUFFN module. |
| 212 | +
|
| 213 | + Args: |
| 214 | + x (torch.Tensor): Input tensor. |
| 215 | +
|
| 216 | + Returns: |
| 217 | + torch.Tensor: Output tensor. |
| 218 | + """ |
| 219 | + x12 = self.w12(x) |
| 220 | + x1, x2 = x12.chunk(2, dim=-1) |
| 221 | + hidden = f.silu(x1) * x2 |
| 222 | + return self.w3(hidden) |
| 223 | + |
| 224 | + |
| 225 | +def get_contrastive_denoising_training_group( |
| 226 | + targets: list[dict[str, torch.Tensor]], |
| 227 | + num_classes: int, |
| 228 | + num_queries: int, |
| 229 | + class_embed: torch.nn.Module, |
| 230 | + num_denoising: int = 100, |
| 231 | + label_noise_ratio: float = 0.5, |
| 232 | + box_noise_scale: float = 1.0, |
| 233 | +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]] | tuple[None, None, None, None]: |
| 234 | + """Generate contrastive denoising training group. |
| 235 | +
|
| 236 | + Args: |
| 237 | + targets (List[Dict[str, torch.Tensor]]): List of target dictionaries. |
| 238 | + num_classes (int): Number of classes. |
| 239 | + num_queries (int): Number of queries. |
| 240 | + class_embed (torch.nn.Module): Class embedding module. |
| 241 | + num_denoising (int, optional): Number of denoising queries. Defaults to 100. |
| 242 | + label_noise_ratio (float, optional): Ratio of label noise. Defaults to 0.5. |
| 243 | + box_noise_scale (float, optional): Scale of box noise. Defaults to 1.0. |
| 244 | +
|
| 245 | + Returns: |
| 246 | + Tuple[Tensor,Tensor,Tensor, dict[str, Tensor]] | tuple[None,None,None,None]: |
| 247 | + Tuple containing input query class, input query bbox, attention mask, and denoising metadata. |
| 248 | + """ |
| 249 | + num_gts = [len(t["labels"]) for t in targets] |
| 250 | + device = targets[0]["labels"].device |
| 251 | + |
| 252 | + max_gt_num = max(num_gts) |
| 253 | + if max_gt_num == 0: |
| 254 | + return None, None, None, None |
| 255 | + |
| 256 | + num_group = num_denoising // max_gt_num |
| 257 | + num_group = 1 if num_group == 0 else num_group |
| 258 | + # pad gt to max_num of a batch |
| 259 | + bs = len(num_gts) |
| 260 | + |
| 261 | + input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device) |
| 262 | + input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device) |
| 263 | + pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device) |
| 264 | + |
| 265 | + for i in range(bs): |
| 266 | + num_gt = num_gts[i] |
| 267 | + if num_gt > 0: |
| 268 | + input_query_class[i, :num_gt] = targets[i]["labels"] |
| 269 | + input_query_bbox[i, :num_gt] = targets[i]["boxes"] |
| 270 | + pad_gt_mask[i, :num_gt] = 1 |
| 271 | + # each group has positive and negative queries. |
| 272 | + input_query_class = input_query_class.tile([1, 2 * num_group]) |
| 273 | + input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1]) |
| 274 | + pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group]) |
| 275 | + # positive and negative mask |
| 276 | + negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=device) |
| 277 | + negative_gt_mask[:, max_gt_num:] = 1 |
| 278 | + negative_gt_mask = negative_gt_mask.tile([1, num_group, 1]) |
| 279 | + positive_gt_mask = 1 - negative_gt_mask |
| 280 | + # contrastive denoising training positive index |
| 281 | + positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask |
| 282 | + dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1] |
| 283 | + dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts]) |
| 284 | + # total denoising queries |
| 285 | + num_denoising = int(max_gt_num * 2 * num_group) |
| 286 | + |
| 287 | + if label_noise_ratio > 0: |
| 288 | + mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5) |
| 289 | + # randomly put a new one here |
| 290 | + new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype) |
| 291 | + input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class) |
| 292 | + |
| 293 | + if box_noise_scale > 0: |
| 294 | + known_bbox = torchvision.ops.box_convert(input_query_bbox, in_fmt="cxcywh", out_fmt="xyxy") |
| 295 | + diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale |
| 296 | + rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 |
| 297 | + rand_part = torch.rand_like(input_query_bbox) |
| 298 | + rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask) |
| 299 | + rand_part *= rand_sign |
| 300 | + known_bbox += rand_part * diff |
| 301 | + known_bbox.clip_(min=0.0, max=1.0) |
| 302 | + input_query_bbox = torchvision.ops.box_convert(known_bbox, in_fmt="xyxy", out_fmt="cxcywh") |
| 303 | + input_query_bbox = inverse_sigmoid(input_query_bbox) |
| 304 | + |
| 305 | + input_query_class = class_embed(input_query_class) |
| 306 | + |
| 307 | + tgt_size = num_denoising + num_queries |
| 308 | + attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device) |
| 309 | + # match query cannot see the reconstruction |
| 310 | + attn_mask[num_denoising:, :num_denoising] = True |
| 311 | + |
| 312 | + # reconstruct cannot see each other |
| 313 | + for i in range(num_group): |
| 314 | + if i == 0: |
| 315 | + attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1) : num_denoising] = True |
| 316 | + if i == num_group - 1: |
| 317 | + attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * i * 2] = True |
| 318 | + else: |
| 319 | + attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1) : num_denoising] = True |
| 320 | + attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * 2 * i] = True |
| 321 | + |
| 322 | + dn_meta = { |
| 323 | + "dn_positive_idx": dn_positive_idx, |
| 324 | + "dn_num_group": num_group, |
| 325 | + "dn_num_split": [num_denoising, num_queries], |
| 326 | + } |
| 327 | + |
| 328 | + return input_query_class, input_query_bbox, attn_mask, dn_meta |
0 commit comments