Skip to content

Commit 1b4fa5c

Browse files
committed
clean up 2
1 parent 48ece32 commit 1b4fa5c

File tree

1 file changed

+328
-0
lines changed

1 file changed

+328
-0
lines changed
Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Common components shared between D-FINE and DEIM transformer decoders."""
5+
6+
from __future__ import annotations
7+
8+
from functools import partial
9+
from typing import Callable
10+
11+
import torch
12+
import torch.nn.functional as f
13+
from torch import Tensor, nn
14+
from torch.nn import init
15+
import torchvision
16+
17+
from otx.backend.native.models.common.layers.transformer_layers import MLP
18+
from otx.backend.native.models.utils.weight_init import bias_init_with_prob
19+
from otx.backend.native.models.common.utils.utils import inverse_sigmoid
20+
21+
22+
class RMSNorm(nn.Module):
    """Layer normalization based on the root mean square of the features.

    The input is divided by the RMS taken over its last dimension and the
    result is multiplied by a learnable per-feature scale.

    Args:
        dim (int): Size of the last (feature) dimension of the input.
        eps (float, optional): Constant added to the mean square before the
            reciprocal square root is taken. Defaults to 1e-6.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.dim = dim
        # Learnable per-feature gain, initialised to the identity.
        self.scale = nn.Parameter(torch.ones(dim))

    def _norm(self, x: Tensor) -> Tensor:
        """Divide ``x`` by the root mean square of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` and apply the learnable scale.

        Args:
            x (Tensor): Input tensor; statistics are taken over the last dimension.

        Returns:
            Tensor: Normalized tensor multiplied by ``self.scale``.
        """
        # The statistics are computed in float32; the result is cast back to
        # the dtype of ``x`` before the scale is applied.
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.scale

    def extra_repr(self) -> str:
        """Return the constructor arguments shown in the module repr."""
        return f"dim={self.dim}, eps={self.eps}"

    def reset_parameters(self) -> None:
        """Reset the learnable scale to all ones."""
        nn.init.ones_(self.scale)
58+
59+
60+
class Gate(nn.Module):
    """Gating layer that fuses two tensors with learned element-wise weights.

    Both inputs are concatenated, projected to two sets of sigmoid gates
    (one per input), combined as a gated sum and finally normalized.

    Args:
        d_model (int): Feature size of each input tensor.
        use_rmsnorm (bool, optional): Normalize the fused result with RMSNorm
            instead of LayerNorm. Defaults to False.
    """

    def __init__(self, d_model: int, use_rmsnorm: bool = False) -> None:
        super().__init__()
        fused_dim = 2 * d_model
        self.gate = nn.Linear(fused_dim, fused_dim)
        # With zero weights the projection outputs only its bias, so every gate
        # starts at the same input-independent value (presumably 0.5, given the
        # argument passed to bias_init_with_prob - confirm against that helper).
        init.constant_(self.gate.weight, 0)
        init.constant_(self.gate.bias, bias_init_with_prob(0.5))
        if use_rmsnorm:
            self.norm = RMSNorm(d_model)
        else:
            self.norm = nn.LayerNorm(d_model)

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Fuse ``x1`` and ``x2`` through the learned gates.

        Args:
            x1 (Tensor): First input tensor.
            x2 (Tensor): Second input tensor; concatenated with ``x1`` along
                the last dimension.

        Returns:
            Tensor: Normalized gated sum of the two inputs.
        """
        stacked = torch.cat((x1, x2), dim=-1)
        # First half of the gate vector weights x1, second half weights x2.
        weight1, weight2 = torch.sigmoid(self.gate(stacked)).chunk(2, dim=-1)
        fused = weight1 * x1 + weight2 * x2
        return self.norm(fused)
93+
94+
95+
class Integral(nn.Module):
    """Turn per-edge bin distributions into continuous box values.

    Each group of ``reg_max + 1`` logits is converted to probabilities with a
    softmax and reduced to a single value as ``sum{Pr(n) * W(n)}``, where
    ``W(n)`` is the weighting function supplied at call time.

    Args:
        reg_max (int, optional): Maximum number of discrete bins. Defaults to 32.
    """

    def __init__(self, reg_max: int = 32) -> None:
        super().__init__()
        self.reg_max = reg_max

    def forward(self, x: Tensor, box_distance_weight: Tensor) -> Tensor:
        """Integrate the predicted distributions into coordinates.

        Args:
            x (Tensor): Distribution predictions of shape (..., 4*(reg_max+1)).
            box_distance_weight (Tensor): Weighting function applied to the
                bin probabilities.

        Returns:
            Tensor: Continuous bounding box coordinates of shape (..., 4).
        """
        leading_dims = x.shape[:-1]
        num_bins = self.reg_max + 1
        # One row per box edge, one column per bin.
        bin_probs = f.softmax(x.reshape(-1, num_bins), dim=1)
        edge_values = f.linear(bin_probs, box_distance_weight).reshape(-1, 4)
        # Restore the caller's leading dimensions.
        return edge_values.reshape(*leading_dims, -1)
124+
125+
126+
class LQE(nn.Module):
    """Localization Quality Estimation module.

    Summarizes the per-edge bin distributions of a predicted box (top-k
    probabilities plus their mean) and maps that summary through an MLP to a
    scalar that is added to the classification scores.

    Args:
        k (int): Number of top-k bin probabilities kept per box edge.
        hidden_dim (int): Hidden dimension for the MLP.
        num_layers (int): Number of MLP layers.
        reg_max (int): Maximum number of discrete bins for bbox regression.
        activation (Callable[..., nn.Module], optional): Factory for the MLP
            activation. Defaults to ``partial(nn.ReLU, inplace=True)``.
    """

    def __init__(
        self,
        k: int,
        hidden_dim: int,
        num_layers: int,
        reg_max: int,
        activation: Callable[..., nn.Module] = partial(nn.ReLU, inplace=True),
    ) -> None:
        super().__init__()
        self.k = k
        self.reg_max = reg_max
        # Four edges per box, each described by k top probabilities and their mean.
        self.reg_conf = MLP(
            input_dim=4 * (k + 1),
            hidden_dim=hidden_dim,
            output_dim=1,
            num_layers=num_layers,
            activation=activation,
        )
        # Zero the last entry of the MLP's layer list (presumably its output
        # projection, in which case the initial quality offset is 0 - confirm
        # against the MLP implementation).
        last_layer = self.reg_conf.layers[-1]
        init.constant_(last_layer.bias, 0)
        init.constant_(last_layer.weight, 0)

    def forward(self, scores: Tensor, pred_corners: Tensor) -> Tensor:
        """Add the estimated localization quality to the scores.

        Args:
            scores (Tensor): Predicted classification scores of shape (B, N, C).
            pred_corners (Tensor): Predicted bounding box corners of shape (B, N, 4*(reg_max+1)).

        Returns:
            Tensor: Quality-adjusted scores of shape (B, N, C).
        """
        batch_size, num_boxes, _ = pred_corners.size()
        edge_logits = pred_corners.reshape(batch_size, num_boxes, 4, self.reg_max + 1)
        edge_probs = f.softmax(edge_logits, dim=-1)
        top_probs = edge_probs.topk(self.k, dim=-1)[0]
        top_mean = top_probs.mean(dim=-1, keepdim=True)
        # Flatten the four per-edge summaries into one feature vector per box.
        features = torch.cat([top_probs, top_mean], dim=-1).reshape(batch_size, num_boxes, -1)
        return scores + self.reg_conf(features)
177+
178+
179+
class SwiGLUFFN(nn.Module):
    """Feed-forward block with a SiLU-gated hidden layer.

    A single linear layer produces two hidden projections at once; the first
    is passed through SiLU and multiplied element-wise with the second, and
    the product is projected to the output size.

    Args:
        in_features (int): Number of input features.
        hidden_features (int): Number of hidden features. A falsy value
            (``0`` or ``None``) falls back to ``in_features``.
        out_features (int): Number of output features. A falsy value
            (``0`` or ``None``) falls back to ``in_features``.
        bias (bool, optional): Whether to use bias in linear layers. Defaults to True.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        bias: bool = True,
    ) -> None:
        """Build the two linear layers and initialize their parameters."""
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # w12 holds both hidden projections side by side, hence the factor 2.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Xavier-initialize the weights and zero the biases, when present."""
        for layer in (self.w12, self.w3):
            init.xavier_uniform_(layer.weight)
            # nn.Linear(bias=False) stores ``bias`` as None; there is nothing to zero then.
            if layer.bias is not None:
                init.constant_(layer.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is ``in_features``.

        Returns:
            torch.Tensor: Output tensor whose last dimension is ``out_features``.
        """
        x12 = self.w12(x)
        # Split the fused projection into the activated branch and the linear branch.
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = f.silu(x1) * x2
        return self.w3(hidden)
223+
224+
225+
def get_contrastive_denoising_training_group(
    targets: list[dict[str, torch.Tensor]],
    num_classes: int,
    num_queries: int,
    class_embed: torch.nn.Module,
    num_denoising: int = 100,
    label_noise_ratio: float = 0.5,
    box_noise_scale: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]] | tuple[None, None, None, None]:
    """Generate contrastive denoising training group.

    Ground-truth labels and boxes are padded to the largest per-image count,
    replicated into ``num_group`` groups of one positive and one negative copy
    each, optionally perturbed with label and box noise, and returned together
    with an attention mask that isolates the groups from each other and hides
    them from the matching queries.

    Args:
        targets (List[Dict[str, torch.Tensor]]): List of target dictionaries; the
            ``"labels"`` and ``"boxes"`` entries of each one are read. Boxes are
            handled as ``cxcywh`` and clipped to ``[0, 1]`` after noising.
        num_classes (int): Number of classes. Also used as the padding label.
        num_queries (int): Number of matching queries.
        class_embed (torch.nn.Module): Class embedding module applied to the
            (noised) integer labels.
        num_denoising (int, optional): Budget of denoising queries used to derive
            the number of groups. Values ``<= 0`` disable denoising. Defaults to 100.
        label_noise_ratio (float, optional): Ratio of label noise. Defaults to 0.5.
        box_noise_scale (float, optional): Scale of box noise. Defaults to 1.0.

    Returns:
        Tuple[Tensor,Tensor,Tensor, dict[str, Tensor]] | tuple[None,None,None,None]:
            Tuple containing input query class, input query bbox, attention mask, and
            denoising metadata. Four ``None`` values are returned when denoising is
            disabled (``num_denoising <= 0``), ``targets`` is empty, or no image in
            the batch has any ground-truth object.
    """
    # Nothing to denoise: denoising switched off or an empty batch.
    if num_denoising <= 0 or not targets:
        return None, None, None, None

    num_gts = [len(t["labels"]) for t in targets]
    max_gt_num = max(num_gts)
    if max_gt_num == 0:
        return None, None, None, None

    device = targets[0]["labels"].device
    bs = len(num_gts)

    # At least one group is always built, even when the budget is smaller than max_gt_num.
    num_group = max(num_denoising // max_gt_num, 1)
    group_size = max_gt_num * 2  # positive half followed by negative half
    total_denoising = int(group_size * num_group)

    # Pad the ground truth of every image to the largest count in the batch.
    input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device)
    input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device)
    pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device)
    for i, (num_gt, target) in enumerate(zip(num_gts, targets)):
        if num_gt > 0:
            input_query_class[i, :num_gt] = target["labels"]
            input_query_bbox[i, :num_gt] = target["boxes"]
            pad_gt_mask[i, :num_gt] = True

    # Each group has positive and negative queries.
    input_query_class = input_query_class.tile([1, 2 * num_group])
    input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1])
    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group])

    # 1 marks the negative half of every group, 0 the positive half.
    negative_gt_mask = torch.zeros([bs, group_size, 1], device=device)
    negative_gt_mask[:, max_gt_num:] = 1
    negative_gt_mask = negative_gt_mask.tile([1, num_group, 1])

    # Query indices of the positive, non-padded entries, split per image.
    positive_gt_mask = (1 - negative_gt_mask).squeeze(-1) * pad_gt_mask
    dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
    dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts])

    if label_noise_ratio > 0:
        # Each non-padded query gets a random label in [0, num_classes)
        # with probability label_noise_ratio * 0.5.
        mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
        new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
        input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)

    if box_noise_scale > 0:
        known_bbox = torchvision.ops.box_convert(input_query_bbox, in_fmt="cxcywh", out_fmt="xyxy")
        # Maximum shift per corner coordinate: half the box width/height, scaled.
        diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
        rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
        rand_part = torch.rand_like(input_query_bbox)
        # Positives move by a factor in [0, 1), negatives by a factor in [1, 2).
        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        input_query_bbox = torchvision.ops.box_convert(known_bbox, in_fmt="xyxy", out_fmt="cxcywh")
        input_query_bbox = inverse_sigmoid(input_query_bbox)

    input_query_class = class_embed(input_query_class)

    tgt_size = total_denoising + num_queries
    attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device)
    # Matching queries cannot see the denoising (reconstruction) queries.
    attn_mask[total_denoising:, :total_denoising] = True
    # Denoising groups cannot see each other: for every group, mask all
    # denoising columns that lie before or after the group's own block.
    for group in range(num_group):
        start = group_size * group
        end = start + group_size
        attn_mask[start:end, :start] = True
        attn_mask[start:end, end:total_denoising] = True

    dn_meta = {
        "dn_positive_idx": dn_positive_idx,
        "dn_num_group": num_group,
        "dn_num_split": [total_denoising, num_queries],
    }

    return input_query_class, input_query_bbox, attn_mask, dn_meta

0 commit comments

Comments
 (0)