
Commit 8c9bf6d

fixed dfine hybrid decoder
1 parent 1b4fa5c commit 8c9bf6d

File tree

8 files changed: +1129 -1325 lines changed


library/src/otx/backend/native/models/common/backbones/dinov3.py

Lines changed: 298 additions & 307 deletions
Large diffs are not rendered by default.

library/src/otx/backend/native/models/common/layers/transformer_layers.py

Lines changed: 285 additions & 162 deletions
Large diffs are not rendered by default.

library/src/otx/backend/native/models/detection/heads/deim_decoder.py

Lines changed: 481 additions & 271 deletions
Large diffs are not rendered by default.

library/src/otx/backend/native/models/detection/heads/dfine_decoder.py

Lines changed: 1 addition & 125 deletions
@@ -15,9 +15,8 @@
 from torch import Tensor, nn
 from torch.nn import init

-from otx.backend.native.models.common.layers.transformer_layers import MLP, MSDeformableAttentionV2
+from otx.backend.native.models.common.layers.transformer_layers import MLP, MSDeformableAttentionV2, get_contrastive_denoising_training_group, LQE, Gate, Integral
 from otx.backend.native.models.common.utils.utils import inverse_sigmoid
-from otx.backend.native.models.detection.heads.rtdetr_decoder import get_contrastive_denoising_training_group
 from otx.backend.native.models.detection.utils.utils import dfine_distance2bbox, dfine_weighting_function
 from otx.backend.native.models.utils.weight_init import bias_init_with_prob

@@ -137,129 +136,6 @@ def forward(
         return self.norm3(target.clamp(min=-65504, max=65504))


-class RMSNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.dim = dim
-        self.eps = eps
-        self.scale = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        output = output * self.scale
-        return output
-
-    def extra_repr(self) -> str:
-        return f'dim={self.dim}, eps={self.eps}'
-
-
-class Gate(nn.Module):
-    """Target Gating Layers.
-
-    Args:
-        d_model (int): The number of expected features in the input.
-        use_rmsnorm (bool, optional): Whether to use RMSNorm. Defaults to False.
-    """
-
-    def __init__(self, d_model: int, use_rmsnorm: bool = False) -> None:
-        super().__init__()
-        self.gate = nn.Linear(2 * d_model, 2 * d_model)
-        bias = bias_init_with_prob(0.5)
-        init.constant_(self.gate.bias, bias)
-        init.constant_(self.gate.weight, 0)
-        self.norm = RMSNorm(d_model) if use_rmsnorm else nn.LayerNorm(d_model)
-
-    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
-        """Forward function of the gate.
-
-        Args:
-            x1 (Tensor): first target input tensor.
-            x2 (Tensor): second target input tensor.
-
-        Returns:
-            Tensor: gated target tensor.
-        """
-        gate_input = torch.cat([x1, x2], dim=-1)
-        gates = torch.sigmoid(self.gate(gate_input))
-        gate1, gate2 = gates.chunk(2, dim=-1)
-        return self.norm(gate1 * x1 + gate2 * x2)
-
-
-class Integral(nn.Module):
-    """A static layer that calculates integral results from a distribution.
-
-    This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`,
-    where Pr(n) is the softmax probability vector representing the discrete
-    distribution, and W(n) is the non-uniform Weighting Function.
-
-    Args:
-        reg_max (int): Max number of the discrete bins. Default is 32.
-            It can be adjusted based on the dataset or task requirements.
-    """
-
-    def __init__(self, reg_max: int = 32):
-        super().__init__()
-        self.reg_max = reg_max
-
-    def forward(self, x: Tensor, box_distance_weight: Tensor) -> Tensor:
-        """Forward function of the Integral layer."""
-        shape = x.shape
-        x = f.softmax(x.reshape(-1, self.reg_max + 1), dim=1)
-        x = f.linear(x, box_distance_weight).reshape(-1, 4)
-        return x.reshape([*list(shape[:-1]), -1])
-
-
-class LQE(nn.Module):
-    """Localization Quality Estimation.
-
-    Args:
-        k (int): number of edge points.
-        hidden_dim (int): The number of expected features in the input.
-        num_layers (int): The number of layers in the MLP.
-        reg_max (int): Max number of the discrete bins.
-    """
-
-    def __init__(
-        self,
-        k: int,
-        hidden_dim: int,
-        num_layers: int,
-        reg_max: int,
-    ):
-        super().__init__()
-        self.k = k
-        self.reg_max = reg_max
-        self.reg_conf = MLP(
-            input_dim=4 * (k + 1),
-            hidden_dim=hidden_dim,
-            output_dim=1,
-            num_layers=num_layers,
-            activation=partial(nn.ReLU, inplace=True),
-        )
-        init.constant_(self.reg_conf.layers[-1].bias, 0)
-        init.constant_(self.reg_conf.layers[-1].weight, 0)
-
-    def forward(self, scores: Tensor, pred_corners: Tensor) -> Tensor:
-        """Forward function of the LQE layer.
-
-        Args:
-            scores (Tensor): Prediction scores.
-            pred_corners (Tensor): Predicted bounding box corners.
-
-        Returns:
-            Tensor: Updated scores.
-        """
-        b, num_pred, _ = pred_corners.size()
-        prob = f.softmax(pred_corners.reshape(b, num_pred, 4, self.reg_max + 1), dim=-1)
-        prob_topk, _ = prob.topk(self.k, dim=-1)
-        stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1)
-        quality_score = self.reg_conf(stat.reshape(b, num_pred, -1))
-        return scores + quality_score
-
-
 class TransformerDecoder(nn.Module):
     """Transformer Decoder implementing Fine-grained Distribution Refinement (FDR).
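The Gate, Integral, and LQE layers deleted here move into transformer_layers.py and are re-imported above, so the decoder keeps the same behavior. As a worked example of the `sum{Pr(n) * W(n)}` integral that the relocated Integral layer computes, here is a minimal standalone sketch; the linspace weights are a made-up stand-in for the non-uniform W(n) that dfine_weighting_function actually produces:

import torch
import torch.nn.functional as F

reg_max = 32  # number of discrete bins, matching the Integral default
b, num_queries = 2, 300

# Hypothetical stand-in for W(n); the real weights come from dfine_weighting_function.
box_distance_weight = torch.linspace(-reg_max / 2, reg_max / 2, reg_max + 1)

# Raw corner logits: one (reg_max + 1)-bin distribution per box edge.
pred_corners = torch.randn(b, num_queries, 4 * (reg_max + 1))

# sum{Pr(n) * W(n)}: softmax over the bins, then a dot product with W(n).
prob = F.softmax(pred_corners.reshape(-1, reg_max + 1), dim=1)
distances = F.linear(prob, box_distance_weight).reshape(b, num_queries, 4)
print(distances.shape)  # torch.Size([2, 300, 4])

In the full decoder these per-edge distances are then decoded into boxes via dfine_distance2bbox, imported in the hunk above.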

library/src/otx/backend/native/models/detection/heads/rtdetr_decoder.py

Lines changed: 1 addition & 108 deletions
@@ -11,123 +11,16 @@
 from typing import Any, Callable, ClassVar

 import torch
-import torchvision
 from torch import nn
 from torch.nn import init

-from otx.backend.native.models.common.layers.transformer_layers import MLP, MSDeformableAttention
+from otx.backend.native.models.common.layers.transformer_layers import MLP, MSDeformableAttention, get_contrastive_denoising_training_group
 from otx.backend.native.models.common.utils.utils import inverse_sigmoid
 from otx.backend.native.models.modules.base_module import BaseModule

 __all__ = ["RTDETRTransformer"]


-def get_contrastive_denoising_training_group(
-    targets: list[dict[str, torch.Tensor]],
-    num_classes: int,
-    num_queries: int,
-    class_embed: torch.nn.Module,
-    num_denoising: int = 100,
-    label_noise_ratio: float = 0.5,
-    box_noise_scale: float = 1.0,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]] | tuple[None, None, None, None]:
-    """Generate contrastive denoising training group.
-
-    Args:
-        targets (List[Dict[str, torch.Tensor]]): List of target dictionaries.
-        num_classes (int): Number of classes.
-        num_queries (int): Number of queries.
-        class_embed (torch.nn.Module): Class embedding module.
-        num_denoising (int, optional): Number of denoising queries. Defaults to 100.
-        label_noise_ratio (float, optional): Ratio of label noise. Defaults to 0.5.
-        box_noise_scale (float, optional): Scale of box noise. Defaults to 1.0.
-
-    Returns:
-        Tuple[Tensor,Tensor,Tensor, dict[str, Tensor]] | tuple[None,None,None,None]:
-            Tuple containing input query class, input query bbox, attention mask, and denoising metadata.
-    """
-    num_gts = [len(t["labels"]) for t in targets]
-    device = targets[0]["labels"].device
-
-    max_gt_num = max(num_gts)
-    if max_gt_num == 0:
-        return None, None, None, None
-
-    num_group = num_denoising // max_gt_num
-    num_group = 1 if num_group == 0 else num_group
-    # pad gt to max_num of a batch
-    bs = len(num_gts)
-
-    input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device)
-    input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device)
-    pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device)
-
-    for i in range(bs):
-        num_gt = num_gts[i]
-        if num_gt > 0:
-            input_query_class[i, :num_gt] = targets[i]["labels"]
-            input_query_bbox[i, :num_gt] = targets[i]["boxes"]
-            pad_gt_mask[i, :num_gt] = 1
-    # each group has positive and negative queries.
-    input_query_class = input_query_class.tile([1, 2 * num_group])
-    input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1])
-    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group])
-    # positive and negative mask
-    negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=device)
-    negative_gt_mask[:, max_gt_num:] = 1
-    negative_gt_mask = negative_gt_mask.tile([1, num_group, 1])
-    positive_gt_mask = 1 - negative_gt_mask
-    # contrastive denoising training positive index
-    positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
-    dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
-    dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts])
-    # total denoising queries
-    num_denoising = int(max_gt_num * 2 * num_group)
-
-    if label_noise_ratio > 0:
-        mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
-        # randomly put a new one here
-        new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
-        input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
-
-    if box_noise_scale > 0:
-        known_bbox = torchvision.ops.box_convert(input_query_bbox, in_fmt="cxcywh", out_fmt="xyxy")
-        diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
-        rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
-        rand_part = torch.rand_like(input_query_bbox)
-        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
-        rand_part *= rand_sign
-        known_bbox += rand_part * diff
-        known_bbox.clip_(min=0.0, max=1.0)
-        input_query_bbox = torchvision.ops.box_convert(known_bbox, in_fmt="xyxy", out_fmt="cxcywh")
-        input_query_bbox = inverse_sigmoid(input_query_bbox)
-
-    input_query_class = class_embed(input_query_class)
-
-    tgt_size = num_denoising + num_queries
-    attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device)
-    # match query cannot see the reconstruction
-    attn_mask[num_denoising:, :num_denoising] = True
-
-    # reconstruct cannot see each other
-    for i in range(num_group):
-        if i == 0:
-            attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1) : num_denoising] = True
-        if i == num_group - 1:
-            attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * i * 2] = True
-        else:
-            attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1) : num_denoising] = True
-            attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * 2 * i] = True
-
-    dn_meta = {
-        "dn_positive_idx": dn_positive_idx,
-        "dn_num_group": num_group,
-        "dn_num_split": [num_denoising, num_queries],
-    }
-
-    return input_query_class, input_query_bbox, attn_mask, dn_meta
-
-
 class TransformerDecoderLayer(nn.Module):
     """TransformerDecoderLayer.
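With this change both decoders share a single copy of get_contrastive_denoising_training_group from transformer_layers.py. A minimal usage sketch, assuming the relocated function keeps the signature shown above; the class count, query count, and target boxes are illustrative:

import torch
from torch import nn
from otx.backend.native.models.common.layers.transformer_layers import get_contrastive_denoising_training_group

num_classes, num_queries, d_model = 80, 300, 256
# Padding slots are filled with the index num_classes, so the class
# embedding needs num_classes + 1 rows.
class_embed = nn.Embedding(num_classes + 1, d_model)

targets = [
    {
        "labels": torch.tensor([3, 17]),
        # normalized cxcywh boxes, as the box_convert calls above assume
        "boxes": torch.tensor([[0.50, 0.50, 0.20, 0.30], [0.10, 0.20, 0.05, 0.10]]),
    },
]

dn_class, dn_bbox, attn_mask, dn_meta = get_contrastive_denoising_training_group(
    targets,
    num_classes=num_classes,
    num_queries=num_queries,
    class_embed=class_embed,
)
# dn_class: (bs, num_dn, d_model) embedded, possibly label-noised queries;
# dn_bbox: (bs, num_dn, 4) box-noised queries in inverse-sigmoid space;
# attn_mask keeps matching queries from attending to the denoising groups;
# dn_meta["dn_num_split"] = [num_dn, num_queries] for splitting decoder outputs.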
