|
11 | 11 | from typing import Any, Callable, ClassVar |
12 | 12 |
|
13 | 13 | import torch |
14 | | -import torchvision |
15 | 14 | from torch import nn |
16 | 15 | from torch.nn import init |
17 | 16 |
|
18 | | -from otx.backend.native.models.common.layers.transformer_layers import MLP, MSDeformableAttention |
| 17 | +from otx.backend.native.models.common.layers.transformer_layers import MLP, MSDeformableAttention, get_contrastive_denoising_training_group |
19 | 18 | from otx.backend.native.models.common.utils.utils import inverse_sigmoid |
20 | 19 | from otx.backend.native.models.modules.base_module import BaseModule |
21 | 20 |
|
22 | 21 | __all__ = ["RTDETRTransformer"] |
23 | 22 |
|
24 | 23 |
|
def get_contrastive_denoising_training_group(
    targets: list[dict[str, torch.Tensor]],
    num_classes: int,
    num_queries: int,
    class_embed: torch.nn.Module,
    num_denoising: int = 100,
    label_noise_ratio: float = 0.5,
    box_noise_scale: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]] | tuple[None, None, None, None]:
    """Generate contrastive denoising training group.

    Ground-truth labels and boxes are padded to the largest per-image count in the
    batch, replicated into ``num_group`` groups of (positive, negative) copies,
    optionally perturbed with label and box noise, and returned together with an
    attention mask that isolates the denoising groups from each other and hides
    them from the matching queries.

    Args:
        targets (List[Dict[str, torch.Tensor]]): List of target dictionaries. Each entry is
            read through its ``"labels"`` and ``"boxes"`` keys; boxes are converted
            with ``in_fmt="cxcywh"`` and clipped to ``[0, 1]`` on the noise path.
        num_classes (int): Number of classes. Also used as the padding label.
        num_queries (int): Number of queries.
        class_embed (torch.nn.Module): Class embedding module, called on the label tensor.
        num_denoising (int, optional): Number of denoising queries. Defaults to 100.
        label_noise_ratio (float, optional): Ratio of label noise. Defaults to 0.5.
        box_noise_scale (float, optional): Scale of box noise. Defaults to 1.0.

    Returns:
        Tuple[Tensor,Tensor,Tensor, dict[str, Tensor]] | tuple[None,None,None,None]:
        Tuple containing input query class, input query bbox, attention mask, and denoising metadata
        (keys ``"dn_positive_idx"``, ``"dn_num_group"``, ``"dn_num_split"``).
        All four entries are ``None`` when ``targets`` is empty or no target has any label.
    """
    # An empty batch has nothing to denoise; treat it like "no ground truth" instead of
    # failing on targets[0] / max([]) below.
    if not targets:
        return None, None, None, None

    num_gts = [len(t["labels"]) for t in targets]
    max_gt_num = max(num_gts)
    if max_gt_num == 0:
        return None, None, None, None

    device = targets[0]["labels"].device
    bs = len(num_gts)

    # At least one group is always built, even if that exceeds the requested budget.
    num_group = max(num_denoising // max_gt_num, 1)
    # Each group holds one positive and one negative copy of every (padded) ground truth.
    group_size = max_gt_num * 2
    # Total denoising queries actually produced (kept separate from the requested budget).
    total_denoising = int(group_size * num_group)

    # Pad gt to the max number in the batch; padded labels use the out-of-range id num_classes.
    input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device)
    input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device)
    pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device)

    for idx, (target, num_gt) in enumerate(zip(targets, num_gts)):
        if num_gt > 0:
            input_query_class[idx, :num_gt] = target["labels"]
            input_query_bbox[idx, :num_gt] = target["boxes"]
            pad_gt_mask[idx, :num_gt] = True

    # Replicate for every group; each group has positive and negative queries.
    input_query_class = input_query_class.tile([1, 2 * num_group])
    input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1])
    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group])

    # Within a group the first half is positive (0) and the second half negative (1).
    negative_gt_mask = torch.zeros([bs, group_size, 1], device=device)
    negative_gt_mask[:, max_gt_num:] = 1
    negative_gt_mask = negative_gt_mask.tile([1, num_group, 1])

    # Contrastive denoising training positive index: positive slots that hold a real gt.
    positive_gt_mask = (1 - negative_gt_mask).squeeze(-1) * pad_gt_mask
    dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
    dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts])

    if label_noise_ratio > 0:
        flip_mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
        # Randomly put a new label on the selected, non-padded slots.
        new_label = torch.randint_like(flip_mask, 0, num_classes, dtype=input_query_class.dtype)
        input_query_class = torch.where(flip_mask & pad_gt_mask, new_label, input_query_class)

    if box_noise_scale > 0:
        known_bbox = torchvision.ops.box_convert(input_query_bbox, in_fmt="cxcywh", out_fmt="xyxy")
        # Maximum shift per corner coordinate: half the box width/height, scaled.
        diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
        rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
        rand_part = torch.rand_like(input_query_bbox)
        # Negative queries draw their magnitude from [1, 2), positives from [0, 1).
        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        input_query_bbox = torchvision.ops.box_convert(known_bbox, in_fmt="xyxy", out_fmt="cxcywh")
        # NOTE(review): kept inside the box-noise branch, so with box_noise_scale <= 0 the boxes
        # are returned without inverse_sigmoid applied - confirm this is the intended contract.
        input_query_bbox = inverse_sigmoid(input_query_bbox)

    input_query_class = class_embed(input_query_class)

    tgt_size = total_denoising + num_queries
    attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device)
    # Match queries cannot see the reconstruction (denoising) queries.
    attn_mask[total_denoising:, :total_denoising] = True

    # Reconstruction groups cannot see each other: for every group, mask all denoising
    # columns that lie before or after the group's own block.
    for group in range(num_group):
        start = group_size * group
        end = start + group_size
        attn_mask[start:end, :start] = True
        attn_mask[start:end, end:total_denoising] = True

    dn_meta = {
        "dn_positive_idx": dn_positive_idx,
        "dn_num_group": num_group,
        "dn_num_split": [total_denoising, num_queries],
    }

    return input_query_class, input_query_bbox, attn_mask, dn_meta
130 | | - |
131 | 24 | class TransformerDecoderLayer(nn.Module): |
132 | 25 | """TransformerDecoderLayer. |
133 | 26 |
|
|
0 commit comments