diff --git a/configs/rtdetrv3/_base_/optimizer_6x.yml b/configs/rtdetrv3/_base_/optimizer_6x.yml new file mode 100644 index 0000000000..c253ca3b3b --- /dev/null +++ b/configs/rtdetrv3/_base_/optimizer_6x.yml @@ -0,0 +1,19 @@ +epoch: 72 + +LearningRate: + base_lr: 0.0004 + schedulers: + - !PiecewiseDecay + gamma: 1.0 + milestones: [100] + use_warmup: true + - !LinearWarmup + start_factor: 0.001 + steps: 2000 + +OptimizerBuilder: + clip_grad_by_norm: 0.1 + regularizer: false + optimizer: + type: AdamW + weight_decay: 0.0001 diff --git a/configs/rtdetrv3/_base_/rtdetr_reader.yml b/configs/rtdetrv3/_base_/rtdetr_reader.yml new file mode 100644 index 0000000000..4f90a9b8fa --- /dev/null +++ b/configs/rtdetrv3/_base_/rtdetr_reader.yml @@ -0,0 +1,44 @@ +worker_num: 4 +TrainReader: + sample_transforms: + - Decode: {} + - RandomDistort: {prob: 0.8} + - RandomExpand: {fill_value: [123.675, 116.28, 103.53]} + - RandomCrop: {prob: 0.8} + - RandomFlip: {} + batch_transforms: + - BatchRandomResize: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - NormalizeBox: {retain_origin_box: true} + - BboxXYXY2XYWH: {} + - Permute: {} + - PadGT: {only_origin_box: true} + batch_size: 16 + shuffle: true + drop_last: true + collate_batch: false + use_shared_memory: true + + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 16 + shuffle: false + drop_last: false + + +TestReader: + inputs_def: + image_shape: [3, 640, 640] + sample_transforms: + - Decode: {} + - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 1 + shuffle: false + drop_last: false diff --git a/configs/rtdetrv3/_base_/rtdetrv3_r50vd.yml b/configs/rtdetrv3/_base_/rtdetrv3_r50vd.yml new file mode 100644 index 0000000000..b0cd5a7eff --- /dev/null +++ b/configs/rtdetrv3/_base_/rtdetrv3_r50vd.yml @@ -0,0 +1,100 @@ +architecture: RTDETRV3 +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams +norm_type: sync_bn +use_ema: True +ema_decay: 0.9999 +ema_decay_type: "exponential" +ema_filter_no_grad: True +hidden_dim: 256 +use_focal_loss: True +eval_size: [640, 640] + + +RTDETRV3: + backbone: ResNet + neck: HybridEncoder + transformer: RTDETRTransformerv3 + detr_head: DINOv3Head + aux_o2m_head: PPYOLOEHead + post_process: DETRPostProcess + +ResNet: + # index 0 stands for res2 + depth: 50 + variant: d + norm_type: bn + freeze_at: 0 + return_idx: [1, 2, 3] + lr_mult_list: [0.1, 0.1, 0.1, 0.1] + num_stages: 4 + freeze_stem_only: True + +HybridEncoder: + hidden_dim: 256 + use_encoder_idx: [2] + num_encoder_layers: 1 + encoder_layer: + name: TransformerLayer + d_model: 256 + nhead: 8 + dim_feedforward: 1024 + dropout: 0. 
+    activation: 'gelu'
+  expansion: 1.0
+
+
+RTDETRTransformerv3:
+  num_queries: 300
+  position_embed_type: sine
+  feat_strides: [8, 16, 32]
+  num_levels: 3
+  nhead: 8
+  num_decoder_layers: 6
+  dim_feedforward: 1024
+  dropout: 0.0
+  activation: relu
+  num_denoising: 100
+  label_noise_ratio: 0.5
+  box_noise_scale: 1.0
+  learnt_init_query: False
+  num_noises: 0
+  num_noise_queries: []
+  num_noise_denoising: 100
+  # num_noises / num_noise_queries are overridden in each model config
+
+
+DINOv3Head:
+  o2m: 4
+  loss:
+    name: DINOv3Loss
+    loss_coeff: {class: 1, bbox: 5, giou: 2}
+    aux_loss: True
+    use_vfl: True
+    matcher:
+      name: HungarianMatcher
+      matcher_coeff: {class: 2, bbox: 5, giou: 2}
+
+PPYOLOEHead:
+  fpn_strides: [8, 16, 32]
+  grid_cell_scale: 5.0
+  grid_cell_offset: 0.5
+  static_assigner_epoch: 30
+  use_varifocal_loss: True
+  loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
+  static_assigner:
+    name: ATSSAssigner
+    topk: 9
+  assigner:
+    name: TaskAlignedAssigner
+    topk: 13
+    alpha: 1.0
+    beta: 6.0
+  nms:
+    name: MultiClassNMS
+    nms_top_k: 1000
+    keep_top_k: 300
+    score_threshold: 0.01
+    nms_threshold: 0.7
+
+DETRPostProcess:
+  num_top_queries: 300
diff --git a/configs/rtdetrv3/rtdetrv3_r18vd_6x_coco.yml b/configs/rtdetrv3/rtdetrv3_r18vd_6x_coco.yml
new file mode 100644
index 0000000000..6b1d39fa2b
--- /dev/null
+++ b/configs/rtdetrv3/rtdetrv3_r18vd_6x_coco.yml
@@ -0,0 +1,54 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/optimizer_6x.yml',
+  '_base_/rtdetrv3_r50vd.yml',
+  '_base_/rtdetr_reader.yml',
+]
+
+weights: output/rtdetrv3_r18vd_6x_coco/model_final
+find_unused_parameters: True
+log_iter: 200
+
+o2m_branch: True
+num_queries_o2m: 450
+
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet18_vd_pretrained.pdparams
+
+RTDETRV3:
+  backbone: ResNet
+  neck: HybridEncoder
+  transformer: RTDETRTransformerv3
+  detr_head: DINOv3Head
+  aux_o2m_head: PPYOLOEHead
+  post_process: DETRPostProcess
+
+ResNet:
+  depth: 18
+  variant: d
+  return_idx: [1, 2, 3]
+  freeze_at: -1
+  freeze_norm: false
+  norm_decay: 0.
+
+HybridEncoder:
+  hidden_dim: 256
+  use_encoder_idx: [2]
+  num_encoder_layers: 1
+  encoder_layer:
+    name: TransformerLayer
+    d_model: 256
+    nhead: 8
+    dim_feedforward: 1024
+    dropout: 0.
+    activation: 'gelu'
+  expansion: 0.5
+  depth_mult: 1.0
+
+RTDETRTransformerv3:
+  eval_idx: -1
+  num_decoder_layers: 3
+  num_noises: 3
+  num_noise_queries: [300, 300, 300]
+  num_noise_denoising: 100
+  learnt_init_query: False
diff --git a/configs/rtdetrv3/rtdetrv3_r34vd_6x_coco.yml b/configs/rtdetrv3/rtdetrv3_r34vd_6x_coco.yml
new file mode 100644
index 0000000000..fde28c5024
--- /dev/null
+++ b/configs/rtdetrv3/rtdetrv3_r34vd_6x_coco.yml
@@ -0,0 +1,54 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/optimizer_6x.yml',
+  '_base_/rtdetrv3_r50vd.yml',
+  '_base_/rtdetr_reader.yml',
+]
+
+weights: output/rtdetrv3_r34vd_6x_coco/model_final
+find_unused_parameters: True
+log_iter: 200
+
+o2m_branch: True
+num_queries_o2m: 450
+
+pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ResNet34_vd_pretrained.pdparams
+
+RTDETRV3:
+  backbone: ResNet
+  neck: HybridEncoder
+  transformer: RTDETRTransformerv3
+  detr_head: DINOv3Head
+  aux_o2m_head: PPYOLOEHead
+  post_process: DETRPostProcess
+
+ResNet:
+  depth: 34
+  variant: d
+  return_idx: [1, 2, 3]
+  freeze_at: -1
+  freeze_norm: false
+  norm_decay: 0.
+ +HybridEncoder: + hidden_dim: 256 + use_encoder_idx: [2] + num_encoder_layers: 1 + encoder_layer: + name: TransformerLayer + d_model: 256 + nhead: 8 + dim_feedforward: 1024 + dropout: 0. + activation: 'gelu' + expansion: 0.5 + depth_mult: 1.0 + +RTDETRTransformerv3: + eval_idx: -1 + num_decoder_layers: 4 + num_noises: 3 + num_noise_queries: [300, 300, 300] + num_noise_denoising: 100 + learnt_init_query: False diff --git a/configs/rtdetrv3/rtdetrv3_r50vd_6x_coco.yml b/configs/rtdetrv3/rtdetrv3_r50vd_6x_coco.yml new file mode 100644 index 0000000000..664e95cd12 --- /dev/null +++ b/configs/rtdetrv3/rtdetrv3_r50vd_6x_coco.yml @@ -0,0 +1,25 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/optimizer_6x.yml', + '_base_/rtdetrv3_r50vd.yml', + '_base_/rtdetr_reader.yml', +] + +weights: output/rtdetrv3_r50vd_6x_coco/model_final +find_unused_parameters: True +log_iter: 200 + +o2m_branch: True +num_queries_o2m: 450 + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams + + +RTDETRTransformerv3: + eval_idx: -1 + num_decoder_layers: 6 + num_noises: 2 + num_noise_queries: [300, 300] + num_noise_denoising: 100 + learnt_init_query: False diff --git a/ppdet/engine/naive_sync_bn.py b/ppdet/engine/naive_sync_bn.py index 17f0007a91..f65bc40bc8 100644 --- a/ppdet/engine/naive_sync_bn.py +++ b/ppdet/engine/naive_sync_bn.py @@ -104,4 +104,4 @@ def convert_bn(model): bn.set_dict(m.state_dict()) setattr(model, n, bn) else: - convert_bn(m) \ No newline at end of file + convert_bn(m) diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index 4803b97e42..1e1384fb4e 100644 --- a/ppdet/modeling/architectures/__init__.py +++ b/ppdet/modeling/architectures/__init__.py @@ -83,3 +83,6 @@ from .detr_ssod import * from .multi_stream_detector import * from .clrnet import * + +from . import rtdetrv3 +from .rtdetrv3 import * diff --git a/ppdet/modeling/architectures/rtdetrv3.py b/ppdet/modeling/architectures/rtdetrv3.py new file mode 100644 index 0000000000..0af4258e90 --- /dev/null +++ b/ppdet/modeling/architectures/rtdetrv3.py @@ -0,0 +1,138 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+# Borrowed from RT-DETRv3 (https://github.com/clxia12/RT-DETRv3)
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from .meta_arch import BaseArch
+from ppdet.core.workspace import register, create
+
+__all__ = ['RTDETRV3']
+# RT-DETRv3 keeps the DETR-style composition: backbone -> neck -> transformer -> head
+
+
+@register
+class RTDETRV3(BaseArch):
+    __category__ = 'architecture'
+    __inject__ = ['post_process', 'post_process_semi']
+    __shared__ = ['with_mask', 'exclude_post_process']
+
+    def __init__(self,
+                 backbone,
+                 transformer='DETRTransformer',
+                 detr_head='DETRHead',
+                 neck=None,
+                 aux_o2m_head=None,
+                 post_process='DETRPostProcess',
+                 post_process_semi=None,
+                 with_mask=False,
+                 exclude_post_process=False):
+        super(RTDETRV3, self).__init__()
+        self.backbone = backbone
+        self.transformer = transformer
+        self.detr_head = detr_head
+        self.neck = neck
+        self.aux_o2m_head = aux_o2m_head
+        self.post_process = post_process
+        self.with_mask = with_mask
+        self.exclude_post_process = exclude_post_process
+        self.post_process_semi = post_process_semi
+
+    @classmethod
+    def from_config(cls, cfg, *args, **kwargs):
+        # backbone
+        backbone = create(cfg['backbone'])
+        # neck
+        kwargs = {'input_shape': backbone.out_shape}
+        neck = create(cfg['neck'], **kwargs) if cfg['neck'] else None
+
+        # transformer
+        if neck is not None:
+            kwargs = {'input_shape': neck.out_shape}
+        transformer = create(cfg['transformer'], **kwargs)
+        # head
+        kwargs = {
+            'hidden_dim': transformer.hidden_dim,
+            'nhead': transformer.nhead,
+            'input_shape': backbone.out_shape
+        }
+        detr_head = create(cfg['detr_head'], **kwargs)
+
+        kwargs = {'input_shape': neck.out_shape}
+        aux_o2m_head = create(cfg['aux_o2m_head'], **kwargs)
+
+        return {
+            'backbone': backbone,
+            'transformer': transformer,
+            'detr_head': detr_head,
+            'neck': neck,
+            'aux_o2m_head': aux_o2m_head
+        }
+
+    def _forward(self):
+        # Backbone
+        body_feats = self.backbone(self.inputs)
+
+        # Neck
+        if self.neck is not None:
+            body_feats = self.neck(body_feats)
+
+        # Transformer
+        pad_mask = self.inputs.get('pad_mask', None)
+        out_transformer = self.transformer(body_feats, pad_mask, self.inputs)
+
+        # DETR Head
+        if self.training:
+            detr_losses = self.detr_head(out_transformer, body_feats,
+                                         self.inputs)
+            detr_losses.update({
+                'loss': paddle.add_n(
+                    [v for k, v in detr_losses.items() if 'log' not in k])
+            })
+            if self.aux_o2m_head is not None:
+                aux_o2m_losses = self.aux_o2m_head(body_feats, self.inputs)
+                for k, v in aux_o2m_losses.items():
+                    if k == 'loss':
+                        detr_losses[k] += v
+                    k = k + '_aux_o2m'
+                    detr_losses[k] = v
+            return detr_losses
+        else:
+            preds = self.detr_head(out_transformer, body_feats)
+            if self.exclude_post_process:
+                bbox, bbox_num, mask = preds
+            else:
+                bbox, bbox_num, mask = self.post_process(
+                    preds, self.inputs['im_shape'], self.inputs['scale_factor'],
+                    self.inputs['image'].shape[2:])
+
+            # aux_o2m_outs = self.aux_o2m_head(body_feats)
+            # bbox, bbox_num, nms_keep_idx = self.aux_o2m_head.post_process(
+            #     aux_o2m_outs, self.inputs['scale_factor'])
+
+            output = {'bbox': bbox, 'bbox_num': bbox_num}
+            if self.with_mask:
+                output['mask'] = mask
+            return output
+
+    def get_loss(self):
+        return self._forward()
+
+    def get_pred(self):
+        return self._forward()
diff --git a/ppdet/modeling/heads/detr_head.py b/ppdet/modeling/heads/detr_head.py
index 50ed14ae61..2836f7f788 100644
--- a/ppdet/modeling/heads/detr_head.py
+++ b/ppdet/modeling/heads/detr_head.py
@@ -24,7 +24,7 @@ from ..initializer
import linear_init_, constant_ from ..transformers.utils import inverse_sigmoid -__all__ = ['DETRHead', 'DeformableDETRHead', 'DINOHead', 'MaskDINOHead'] +__all__ = ['DETRHead', 'DeformableDETRHead', 'DINOHead', 'MaskDINOHead', 'DINOv3Head'] class MLP(nn.Layer): @@ -537,3 +537,109 @@ def forward(self, out_transformer, body_feats, inputs=None): dn_meta=dn_meta) else: return (dec_out_bboxes[-1], dec_out_logits[-1], dec_out_masks[-1]) + +@register +class DINOv3Head(nn.Layer): + __inject__ = ['loss'] + __shared__ = ['o2m_branch', 'num_queries_o2m'] + + + def __init__(self, loss='DINOLoss', eval_idx=-1, o2m=4, o2m_branch=False, num_queries_o2m=450): + super(DINOv3Head, self).__init__() + self.loss = loss + self.eval_idx = eval_idx + self.o2m = o2m + self.o2m_branch = o2m_branch + self.num_queries_o2m = num_queries_o2m + + def forward(self, out_transformer, body_feats, inputs=None): + (dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, + dn_meta) = out_transformer + if self.training: + assert inputs is not None + assert 'gt_bbox' in inputs and 'gt_class' in inputs + + if dn_meta is not None: + num_groups = len(dn_meta) + total_dec_queries = dec_out_bboxes.shape[2] + total_enc_queries = enc_topk_bboxes.shape[1] + loss = {} + if self.o2m_branch: + dec_out_bboxes, dec_out_bboxes_o2m = paddle.split(dec_out_bboxes, [total_dec_queries - self.num_queries_o2m, self.num_queries_o2m], axis=2) + dec_out_logits, dec_out_logits_o2m = paddle.split(dec_out_logits, [total_dec_queries - self.num_queries_o2m, self.num_queries_o2m], axis=2) + enc_topk_bboxes, enc_topk_bboxes_o2m = paddle.split(enc_topk_bboxes, [total_enc_queries - self.num_queries_o2m, self.num_queries_o2m], axis=1) + enc_topk_logits, enc_topk_logits_o2m = paddle.split(enc_topk_logits, [total_enc_queries - self.num_queries_o2m, self.num_queries_o2m], axis=1) + + out_bboxes_o2m = paddle.concat([enc_topk_bboxes_o2m.unsqueeze(0), dec_out_bboxes_o2m]) + out_logits_o2m = paddle.concat([enc_topk_logits_o2m.unsqueeze(0), dec_out_logits_o2m]) + loss_o2m = self.loss( + out_bboxes_o2m, + out_logits_o2m, + inputs['gt_bbox'], + inputs['gt_class'], + dn_out_bboxes=None, + dn_out_logits=None, + dn_meta=None, + o2m=self.o2m) + for key, value in loss_o2m.items(): + key = key + '_o2m_branch' + loss.update({ + key: loss.get(key, paddle.zeros([1])) + value + }) + + split_dec_num = [sum(dn['dn_num_split']) for dn in dn_meta] + split_enc_num = [dn['dn_num_split'][1] for dn in dn_meta] + dec_out_bboxes = paddle.split(dec_out_bboxes, split_dec_num, axis=2) + dec_out_logits = paddle.split(dec_out_logits, split_dec_num, axis=2) + enc_topk_bboxes = paddle.split(enc_topk_bboxes, split_enc_num, axis=1) + enc_topk_logits = paddle.split(enc_topk_logits, split_enc_num, axis=1) + + for g_id in range(num_groups): + dn_out_bboxes_gid, dec_out_bboxes_gid = paddle.split( + dec_out_bboxes[g_id], dn_meta[g_id]['dn_num_split'], axis=2) + dn_out_logits_gid, dec_out_logits_gid = paddle.split( + dec_out_logits[g_id], dn_meta[g_id]['dn_num_split'], axis=2) + out_bboxes_gid = paddle.concat([ + enc_topk_bboxes[g_id].unsqueeze(0), dec_out_bboxes_gid]) + out_logits_gid = paddle.concat([ + enc_topk_logits[g_id].unsqueeze(0), dec_out_logits_gid]) + + loss_gid = self.loss( + out_bboxes_gid, + out_logits_gid, + inputs['gt_bbox'], + inputs['gt_class'], + dn_out_bboxes=dn_out_bboxes_gid, + dn_out_logits=dn_out_logits_gid, + dn_meta=dn_meta[g_id]) + # sum loss + for key, value in loss_gid.items(): + loss.update({ + key: loss.get(key, paddle.zeros([1])) + value + }) + + # average across 
(dual_groups + 1) + for key, value in loss.items(): + if '_o2m_branch' not in key: + loss.update({key: value / num_groups}) + return loss + else: + dn_out_bboxes, dn_out_logits = None, None + + out_bboxes = paddle.concat( + [enc_topk_bboxes.unsqueeze(0), dec_out_bboxes]) + out_logits = paddle.concat( + [enc_topk_logits.unsqueeze(0), dec_out_logits]) + + return self.loss( + out_bboxes, + out_logits, + inputs['gt_bbox'], + inputs['gt_class'], + dn_out_bboxes=dn_out_bboxes, + dn_out_logits=dn_out_logits, + dn_meta=dn_meta, + gt_score=inputs.get('gt_score', None)) + else: + return (dec_out_bboxes[self.eval_idx], + dec_out_logits[self.eval_idx], None) diff --git a/ppdet/modeling/losses/detr_loss.py b/ppdet/modeling/losses/detr_loss.py index 6712f928d5..b5314525f4 100644 --- a/ppdet/modeling/losses/detr_loss.py +++ b/ppdet/modeling/losses/detr_loss.py @@ -24,7 +24,7 @@ from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss, varifocal_loss_with_logits from ..bbox_utils import bbox_iou -__all__ = ['DETRLoss', 'DINOLoss'] +__all__ = ['DETRLoss', 'DINOLoss', 'DINOv3Loss'] @register @@ -400,6 +400,7 @@ def forward(self, gt_mask=None, postfix="", gt_score=None, + o2m=1, **kwargs): r""" Args: @@ -515,6 +516,101 @@ def get_dn_match_indices(labels, dn_positive_idx, dn_num_group): [0], dtype="int64"))) return dn_match_indices +@register +class DINOv3Loss(DETRLoss): + def forward(self, + boxes, + logits, + gt_bbox, + gt_class, + masks=None, + gt_mask=None, + postfix="", + dn_out_bboxes=None, + dn_out_logits=None, + dn_meta=None, + gt_score=None, + o2m=1, + **kwargs): + + if o2m != 1: + gt_boxes_copy = [box.tile([o2m, 1]) for box in gt_bbox] + gt_class_copy = [label.tile([o2m, 1]) for label in gt_class] + else: + gt_boxes_copy = gt_bbox + gt_class_copy = gt_class + num_gts_copy = self._get_num_gts(gt_class_copy) + total_loss = self._get_prediction_loss( + boxes[-1], + logits[-1], + gt_boxes_copy, + gt_class_copy, + masks=masks[-1] if masks is not None else None, + gt_mask=gt_mask, + postfix=postfix, + dn_match_indices=None, + num_gts=num_gts_copy, + gt_score=gt_score if gt_score is not None else None) + + if self.aux_loss: + total_loss.update( + self._get_loss_aux( + boxes[:-1], + logits[:-1], + gt_boxes_copy, + gt_class_copy, + self.num_classes, + num_gts_copy, + dn_match_indices=None, + postfix=postfix, + masks=masks[:-1] if masks is not None else None, + gt_mask=gt_mask, + gt_score=gt_score if gt_score is not None else None)) + + if dn_meta is not None: + num_gts = self._get_num_gts(gt_class) + dn_positive_idx, dn_num_group = \ + dn_meta["dn_positive_idx"], dn_meta["dn_num_group"] + assert len(gt_class) == len(dn_positive_idx) + + # denoising match indices + dn_match_indices = self.get_dn_match_indices( + gt_class, dn_positive_idx, dn_num_group) + + # compute denoising training loss + num_gts *= dn_num_group + dn_loss = super(DINOv3Loss, self).forward( + dn_out_bboxes, + dn_out_logits, + gt_bbox, + gt_class, + postfix="_dn", + dn_match_indices=dn_match_indices, + num_gts=num_gts, + gt_score=gt_score) + total_loss.update(dn_loss) + else: + total_loss.update( + {k + '_dn': paddle.to_tensor([0.]) + for k in total_loss.keys()}) + + return total_loss + + @staticmethod + def get_dn_match_indices(labels, dn_positive_idx, dn_num_group): + dn_match_indices = [] + for i in range(len(labels)): + num_gt = len(labels[i]) + if num_gt > 0: + gt_idx = paddle.arange(end=num_gt, dtype="int64") + gt_idx = gt_idx.tile([dn_num_group]) + assert len(dn_positive_idx[i]) == len(gt_idx) + 
dn_match_indices.append((dn_positive_idx[i], gt_idx)) + else: + dn_match_indices.append((paddle.zeros( + [0], dtype="int64"), paddle.zeros( + [0], dtype="int64"))) + return dn_match_indices @register class MaskDINOLoss(DETRLoss): diff --git a/ppdet/modeling/transformers/__init__.py b/ppdet/modeling/transformers/__init__.py index d11f3d69e6..fcbcc96a7c 100644 --- a/ppdet/modeling/transformers/__init__.py +++ b/ppdet/modeling/transformers/__init__.py @@ -26,6 +26,7 @@ from . import rtdetr_transformerv2 from . import co_deformable_detr_transformer from . import co_dino_transformer +from . import rtdetr_transformerv3 from .detr_transformer import * from .utils import * @@ -42,3 +43,4 @@ from .rtdetr_transformerv2 import * from .co_deformable_detr_transformer import * from .co_dino_transformer import * +from .rtdetr_transformerv3 import * diff --git a/ppdet/modeling/transformers/rtdetr_transformerv3.py b/ppdet/modeling/transformers/rtdetr_transformerv3.py new file mode 100644 index 0000000000..68dc316b68 --- /dev/null +++ b/ppdet/modeling/transformers/rtdetr_transformerv3.py @@ -0,0 +1,654 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Borrowed from RT-DETRv3 (https://github.com/clxia12/RT-DETRv3) +# +# Modified from Deformable-DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Modified from detrex (https://github.com/IDEA-Research/detrex) +# Copyright 2022 The IDEA Authors. All rights reserved. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.regularizer import L2Decay + +from ppdet.core.workspace import register +from ..layers import MultiHeadAttention +from ..heads.detr_head import MLP +from .deformable_transformer import MSDeformableAttention +from ..initializer import (linear_init_, constant_, xavier_uniform_, normal_, + bias_init_with_prob) +from .utils import (_get_clones, get_sine_pos_embed, + get_contrastive_denoising_training_group, inverse_sigmoid) + +__all__ = ['RTDETRTransformerv3'] + + +class PPMSDeformableAttention(MSDeformableAttention): + def forward(self, + query, + reference_points, + value, + value_spatial_shapes, + value_level_start_index, + value_mask=None): + """ + Args: + query (Tensor): [bs, query_length, C] + reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area + value (Tensor): [bs, value_length, C] + value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + value_level_start_index (List): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...] 
+ value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, Len_q = query.shape[:2] + Len_v = value.shape[1] + + value = self.value_proj(value) + if value_mask is not None: + value_mask = value_mask.astype(value.dtype).unsqueeze(-1) + value *= value_mask + value = value.reshape([bs, Len_v, self.num_heads, self.head_dim]) + + sampling_offsets = self.sampling_offsets(query).reshape( + [bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2]) + attention_weights = self.attention_weights(query).reshape( + [bs, Len_q, self.num_heads, self.num_levels * self.num_points]) + attention_weights = F.softmax(attention_weights).reshape( + [bs, Len_q, self.num_heads, self.num_levels, self.num_points]) + + if reference_points.shape[-1] == 2: + offset_normalizer = paddle.to_tensor(value_spatial_shapes) + offset_normalizer = offset_normalizer.flip([1]).reshape( + [1, 1, 1, self.num_levels, 1, 2]) + sampling_locations = reference_points.reshape([ + bs, Len_q, 1, self.num_levels, 1, 2 + ]) + sampling_offsets / offset_normalizer + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + sampling_offsets / + self.num_points * reference_points[:, :, None, :, None, 2:] * + 0.5) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.". + format(reference_points.shape[-1])) + + if not isinstance(query, paddle.Tensor): + from ppdet.modeling.transformers.utils import deformable_attention_core_func + output = deformable_attention_core_func( + value, value_spatial_shapes, value_level_start_index, + sampling_locations, attention_weights) + else: + value_spatial_shapes = paddle.to_tensor(value_spatial_shapes) + value_level_start_index = paddle.to_tensor(value_level_start_index) + output = self.ms_deformable_attn_core( + value, value_spatial_shapes, value_level_start_index, + sampling_locations, attention_weights) + output = self.output_proj(output) + + return output + + +class TransformerDecoderLayer(nn.Layer): + def __init__(self, + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0., + activation="relu", + n_levels=4, + n_points=4, + weight_attr=None, + bias_attr=None): + super(TransformerDecoderLayer, self).__init__() + + # self attention + self.self_attn = MultiHeadAttention(d_model, n_head, dropout=dropout) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + # cross attention + self.cross_attn = PPMSDeformableAttention(d_model, n_head, n_levels, + n_points, 1.0) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr, + bias_attr) + self.activation = getattr(F, activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr, + bias_attr) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + xavier_uniform_(self.linear1.weight) + xavier_uniform_(self.linear2.weight) + + def 
with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + return self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + + def forward(self, + tgt, + reference_points, + memory, + memory_spatial_shapes, + memory_level_start_index, + attn_mask=None, + memory_mask=None, + query_pos_embed=None): + # self attention + q = k = self.with_pos_embed(tgt, query_pos_embed) + if attn_mask is not None: + attn_mask = paddle.where( + attn_mask.astype('bool'), + paddle.zeros(attn_mask.shape, tgt.dtype), + paddle.full(attn_mask.shape, float("-inf"), tgt.dtype)) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # cross attention + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, query_pos_embed), reference_points, memory, + memory_spatial_shapes, memory_level_start_index, memory_mask) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # ffn + tgt2 = self.forward_ffn(tgt) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + + return tgt + + +class TransformerDecoder(nn.Layer): + def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx + + def forward(self, + tgt, + ref_points_unact, + memory, + memory_spatial_shapes, + memory_level_start_index, + bbox_head, + score_head, + query_pos_head, + attn_mask=None, + memory_mask=None, + query_pos_head_inv_sig=False): + output = tgt + dec_out_bboxes = [] + dec_out_logits = [] + ref_points_detach = F.sigmoid(ref_points_unact) + for i, layer in enumerate(self.layers): + ref_points_input = ref_points_detach.unsqueeze(2) + if not query_pos_head_inv_sig: + query_pos_embed = query_pos_head(ref_points_detach) + else: + query_pos_embed = query_pos_head( + inverse_sigmoid(ref_points_detach)) + + output = layer(output, ref_points_input, memory, + memory_spatial_shapes, memory_level_start_index, + attn_mask, memory_mask, query_pos_embed) + + inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid( + ref_points_detach)) + + if self.training: + dec_out_logits.append(score_head[i](output)) + if i == 0: + dec_out_bboxes.append(inter_ref_bbox) + else: + dec_out_bboxes.append( + F.sigmoid(bbox_head[i](output) + inverse_sigmoid( + ref_points))) + elif i == self.eval_idx: + dec_out_logits.append(score_head[i](output)) + dec_out_bboxes.append(inter_ref_bbox) + break + + ref_points = inter_ref_bbox + ref_points_detach = inter_ref_bbox.detach( + ) if self.training else inter_ref_bbox + + return paddle.stack(dec_out_bboxes), paddle.stack(dec_out_logits) + + +@register +class RTDETRTransformerv3(nn.Layer): + __shared__ = ['num_classes', 'hidden_dim', 'eval_size', + 'o2m_branch', 'num_queries_o2m'] + + def __init__(self, + num_classes=80, + hidden_dim=256, + num_queries=300, + position_embed_type='sine', + backbone_feat_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + num_levels=3, + num_decoder_points=4, + nhead=8, + num_decoder_layers=6, + dim_feedforward=1024, + dropout=0., + activation="relu", + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learnt_init_query=True, + query_pos_head_inv_sig=False, + eval_size=None, + eval_idx=-1, + num_noises=0, + num_noise_queries=[], + num_noise_denoising=100, + o2m_branch=False, + num_queries_o2m=450, + 
eps=1e-2): + super(RTDETRTransformerv3, self).__init__() + assert position_embed_type in ['sine', 'learned'], \ + f'ValueError: position_embed_type not supported {position_embed_type}!' + assert len(backbone_feat_channels) <= num_levels + assert len(feat_strides) == len(backbone_feat_channels) + assert len(num_noise_queries) == num_noises + for _ in range(num_levels - len(feat_strides)): + feat_strides.append(feat_strides[-1] * 2) + + self.hidden_dim = hidden_dim + self.nhead = nhead + self.feat_strides = feat_strides + self.num_levels = num_levels + self.num_classes = num_classes + self.num_queries = [num_queries] + self.eps = eps + self.num_decoder_layers = num_decoder_layers + self.eval_size = eval_size + + self.num_noises = num_noises + self.num_noise_denoising = num_noise_denoising + self.num_groups = 1 + if num_noises > 0: + self.num_queries.extend(num_noise_queries) + self.num_groups += num_noises + + self.o2m_branch = o2m_branch + self.num_queries_o2m = num_queries_o2m + if o2m_branch: + self.num_queries.append(num_queries_o2m) + self.num_groups += 1 + + # backbone feature projection + self._build_input_proj_layer(backbone_feat_channels) + + # Transformer module + decoder_layer = TransformerDecoderLayer( + hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels, + num_decoder_points) + self.decoder = TransformerDecoder(hidden_dim, decoder_layer, + num_decoder_layers, eval_idx) + + # denoising part + self.denoising_class_embed = nn.Embedding( + num_classes, + hidden_dim, + weight_attr=ParamAttr(initializer=nn.initializer.Normal())) + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + + # decoder embedding + self.learnt_init_query = learnt_init_query + if learnt_init_query: + self.tgt_embed = nn.Embedding(num_queries, hidden_dim) + self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2) + self.query_pos_head_inv_sig = query_pos_head_inv_sig + + # encoder head + self.enc_output = nn.LayerList([ + nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm( + hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)))) + for _ in range(self.num_groups) + ]) + self.enc_score_head = nn.LayerList([ + nn.Linear(hidden_dim, num_classes) + for _ in range(self.num_groups) + ]) + self.enc_bbox_head = nn.LayerList([ + MLP(hidden_dim, hidden_dim, 4, num_layers=3) + for _ in range(self.num_groups) + ]) + + self.map_memory = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm( + hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + ) + + # decoder head + self.dec_score_head = nn.LayerList([ + nn.Linear(hidden_dim, num_classes) + for _ in range(num_decoder_layers) + ]) + self.dec_bbox_head = nn.LayerList([ + MLP(hidden_dim, hidden_dim, 4, num_layers=3) + for _ in range(num_decoder_layers) + ]) + + self._reset_parameters() + + def _reset_parameters(self): + # class and bbox head init + bias_cls = bias_init_with_prob(0.01) + for enc_score_head in self.enc_score_head: + linear_init_(enc_score_head) + constant_(enc_score_head.bias, bias_cls) + for enc_bbox_head in self.enc_bbox_head: + constant_(enc_bbox_head.layers[-1].weight) + constant_(enc_bbox_head.layers[-1].bias) + for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): + linear_init_(cls_) + constant_(cls_.bias, bias_cls) + constant_(reg_.layers[-1].weight) + constant_(reg_.layers[-1].bias) + + for 
enc_output in self.enc_output: + linear_init_(enc_output[0]) + xavier_uniform_(enc_output[0].weight) + linear_init_(self.map_memory[0]) + xavier_uniform_(self.map_memory[0].weight) + + if self.learnt_init_query: + xavier_uniform_(self.tgt_embed.weight) + xavier_uniform_(self.query_pos_head.layers[0].weight) + xavier_uniform_(self.query_pos_head.layers[1].weight) + for l in self.input_proj: + xavier_uniform_(l[0].weight) + + # init encoder output anchors and valid_mask + if self.eval_size: + self.anchors, self.valid_mask = self._generate_anchors() + + @classmethod + def from_config(cls, cfg, input_shape): + return {'backbone_feat_channels': [i.channels for i in input_shape]} + + def _build_input_proj_layer(self, backbone_feat_channels): + self.input_proj = nn.LayerList() + for in_channels in backbone_feat_channels: + self.input_proj.append( + nn.Sequential( + ('conv', nn.Conv2D( + in_channels, + self.hidden_dim, + kernel_size=1, + bias_attr=False)), ('norm', nn.BatchNorm2D( + self.hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)))))) + in_channels = backbone_feat_channels[-1] + for _ in range(self.num_levels - len(backbone_feat_channels)): + self.input_proj.append( + nn.Sequential( + ('conv', nn.Conv2D( + in_channels, + self.hidden_dim, + kernel_size=3, + stride=2, + padding=1, + bias_attr=False)), ('norm', nn.BatchNorm2D( + self.hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)))))) + in_channels = self.hidden_dim + + def _get_encoder_input(self, feats): + # get projection features + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + if self.num_levels > len(proj_feats): + len_srcs = len(proj_feats) + for i in range(len_srcs, self.num_levels): + if i == len_srcs: + proj_feats.append(self.input_proj[i](feats[-1])) + else: + proj_feats.append(self.input_proj[i](proj_feats[-1])) + + # get encoder inputs + feat_flatten = [] + spatial_shapes = [] + level_start_index = [0, ] + for i, feat in enumerate(proj_feats): + _, _, h, w = feat.shape + # [b, c, h, w] -> [b, h*w, c] + feat_flatten.append(feat.flatten(2).transpose([0, 2, 1])) + # [num_levels, 2] + spatial_shapes.append([h, w]) + # [l], start index of each level + level_start_index.append(h * w + level_start_index[-1]) + + # [b, l, c] + feat_flatten = paddle.concat(feat_flatten, 1) + level_start_index.pop() + return (feat_flatten, spatial_shapes, level_start_index) + + def forward(self, feats, pad_mask=None, gt_meta=None, is_teacher=False): + # input projection and embedding + (memory, spatial_shapes, + level_start_index) = self._get_encoder_input(feats) + + # prepare denoising training + if self.training: + denoising_classes, denoising_bbox_unacts, attn_masks, dn_metas = [], [], [], [] + for g_id in range(self.num_noises + 1): + if g_id == 0: + num_denoising = self.num_denoising + else: + num_denoising = self.num_noise_denoising + denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \ + get_contrastive_denoising_training_group(gt_meta, + self.num_classes, + self.num_queries[g_id], + self.denoising_class_embed.weight, + num_denoising, + self.label_noise_ratio, + self.box_noise_scale) + denoising_classes.append(denoising_class) + denoising_bbox_unacts.append(denoising_bbox_unact) + attn_masks.append(attn_mask) + dn_metas.append(dn_meta) + else: + denoising_classes, denoising_bbox_unacts, attn_masks, dn_metas = None, None, None, None + + target, init_ref_points_unact, enc_topk_bboxes, 
enc_topk_logits = \ + self._get_decoder_input( + memory, spatial_shapes, denoising_classes, denoising_bbox_unacts, is_teacher) + + # multi group noise attention + if self.training: + new_size = target.shape[1] + new_attn_mask = paddle.ones([new_size, new_size]) < 0 + begin, end = 0, 0 + mask = None + for g_id in range(self.num_groups): + new_mask = paddle.rand([self.num_queries[g_id], self.num_queries[g_id]]) + if self.o2m_branch and g_id == self.num_groups - 1: + end = end + self.num_queries_o2m + new_mask = new_mask >= 0.0 + new_attn_mask[begin: end, begin: end] = new_mask + else: + end = end + attn_masks[g_id].shape[1] + dn_size, q_size = dn_metas[g_id]['dn_num_split'] + if g_id > 0: + new_mask = new_mask > 0.1 + else: + new_mask = new_mask >= 0.0 + attn_masks[g_id][dn_size: dn_size + q_size, dn_size: dn_size + q_size] = new_mask + new_attn_mask[begin: end, begin: end] = attn_masks[g_id] + begin = end + attn_masks = new_attn_mask + + # decoder + out_bboxes, out_logits = self.decoder( + target, + init_ref_points_unact, + memory, + spatial_shapes, + level_start_index, + self.dec_bbox_head, + self.dec_score_head, + self.query_pos_head, + attn_mask=attn_masks, + memory_mask=None, + query_pos_head_inv_sig=self.query_pos_head_inv_sig) + return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, + dn_metas) + + def _generate_anchors(self, + spatial_shapes=None, + grid_size=0.05, + dtype="float32"): + if spatial_shapes is None: + spatial_shapes = [ + [int(self.eval_size[0] / s), int(self.eval_size[1] / s)] + for s in self.feat_strides + ] + anchors = [] + for lvl, (h, w) in enumerate(spatial_shapes): + grid_y, grid_x = paddle.meshgrid( + paddle.arange( + end=h, dtype=dtype), + paddle.arange( + end=w, dtype=dtype)) + grid_xy = paddle.stack([grid_x, grid_y], -1) + + valid_WH = paddle.to_tensor([h, w]).astype(dtype) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH + wh = paddle.ones_like(grid_xy) * grid_size * (2.0**lvl) + anchors.append( + paddle.concat([grid_xy, wh], -1).reshape([-1, h * w, 4])) + + anchors = paddle.concat(anchors, 1) + valid_mask = ((anchors > self.eps) * + (anchors < 1 - self.eps)).all(-1, keepdim=True) + anchors = paddle.log(anchors / (1 - anchors)) + anchors = paddle.where(valid_mask, anchors, + paddle.to_tensor(float("inf"))) + return anchors, valid_mask + + def _get_decoder_input(self, + memory, + spatial_shapes, + denoising_classes=None, + denoising_bbox_unacts=None, + is_teacher=False): + bs, _, _ = memory.shape + # prepare input for decoder + if self.training or self.eval_size is None or is_teacher: + anchors, valid_mask = self._generate_anchors(spatial_shapes) + else: + anchors, valid_mask = self.anchors, self.valid_mask + memory = paddle.where(valid_mask, memory, paddle.to_tensor(0.)) + map_memory = self.map_memory(memory.detach()) + targets, reference_points_unacts, enc_topk_bboxes, enc_topk_logits = [], [], [], [] + + for g_id in range(self.num_groups): + output_memory = self.enc_output[g_id](memory) + enc_outputs_class = self.enc_score_head[g_id](output_memory) + enc_outputs_coord_unact = self.enc_bbox_head[g_id](output_memory) + anchors + + _, topk_ind = paddle.topk( + enc_outputs_class.max(-1), self.num_queries[g_id], axis=1) + # extract region proposal boxes + batch_ind = paddle.arange(end=bs, dtype=topk_ind.dtype) + batch_ind = batch_ind.unsqueeze(-1).tile([1, self.num_queries[g_id]]) + topk_ind = paddle.stack([batch_ind, topk_ind], axis=-1) + + reference_points_unact = paddle.gather_nd(enc_outputs_coord_unact, topk_ind) # unsigmoided. 
+ enc_topk_bbox = F.sigmoid(reference_points_unact) + enc_topk_logit = paddle.gather_nd(enc_outputs_class, topk_ind) + + if denoising_bbox_unacts is not None and not (self.o2m_branch and g_id == self.num_groups - 1): + reference_points_unact = paddle.concat( + [denoising_bbox_unacts[g_id], reference_points_unact], 1) + if self.training: + reference_points_unact = reference_points_unact.detach() + + # extract region features + if self.learnt_init_query: + target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1]) + else: + if g_id == 0: + target = paddle.gather_nd(output_memory, topk_ind) + if self.training: + target = target.detach() + else: + target = paddle.gather_nd(map_memory, topk_ind) + if denoising_classes is not None and not (self.o2m_branch and g_id == self.num_groups - 1): + target = paddle.concat([denoising_classes[g_id], target], 1) + + if not self.training: + return target, reference_points_unact, enc_topk_bbox, enc_topk_logit + + targets.append(target) + reference_points_unacts.append(reference_points_unact) + enc_topk_bboxes.append(enc_topk_bbox) + enc_topk_logits.append(enc_topk_logit) + + targets = paddle.concat(targets, 1) + reference_points_unacts = paddle.concat(reference_points_unacts, 1) + enc_topk_bboxes = paddle.concat(enc_topk_bboxes, 1) + enc_topk_logits = paddle.concat(enc_topk_logits, 1) + return targets, reference_points_unacts, enc_topk_bboxes, enc_topk_logits
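
For reviewers trying out the patch: the new architecture is wired into PaddleDetection's registry, so it can be built straight from the added configs. A minimal sketch, assuming a PaddleDetection checkout with this patch applied (the config path is the one added above; the script name is illustrative, not part of the patch):

# build_rtdetrv3.py -- illustrative only
from ppdet.core.workspace import load_config, create

cfg = load_config('configs/rtdetrv3/rtdetrv3_r50vd_6x_coco.yml')
model = create(cfg.architecture)   # resolves 'RTDETRV3' registered above
print(type(model).__name__)        # -> RTDETRV3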
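DINOv3Head.forward treats the decoder's query axis as a concatenation [dn_0 + q_0, dn_1 + q_1, ..., o2m] and recovers the pieces with paddle.split. A toy reproduction of that slicing arithmetic, with made-up sizes:

# toy_split.py -- hypothetical sizes, mirrors the paddle.split calls in DINOv3Head
import paddle

dn_meta = [{'dn_num_split': [20, 300]}, {'dn_num_split': [20, 300]}]
num_queries_o2m = 450
total = sum(sum(m['dn_num_split']) for m in dn_meta) + num_queries_o2m
dec_out = paddle.rand([3, 2, total, 256])          # [layers, bs, queries, dim]

# 1) peel the one-to-many branch off the tail of the query axis
dec_o2o, dec_o2m = paddle.split(
    dec_out, [total - num_queries_o2m, num_queries_o2m], axis=2)

# 2) split the rest into per-group chunks (denoising + matching queries)
split_dec_num = [sum(m['dn_num_split']) for m in dn_meta]
groups = paddle.split(dec_o2o, split_dec_num, axis=2)

# 3) within each group, separate denoising queries from matching queries
for g, m in zip(groups, dn_meta):
    dn_part, q_part = paddle.split(g, m['dn_num_split'], axis=2)
    # dn_part: [3, 2, 20, 256], q_part: [3, 2, 300, 256]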
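The training-time self-attention mask built in RTDETRTransformerv3.forward is block-diagonal over the query groups: each query sees only its own group, the o2m block is fully visible, and noise groups randomly drop roughly 10% of their within-group links. A standalone sketch of that construction (sizes are illustrative):

# toy_group_mask.py -- illustrative sizes; True = attend, False = blocked
# (TransformerDecoderLayer converts False entries to -inf before softmax)
import paddle

num_queries = [6, 4, 5]                    # o2o group, one noise group, o2m branch
total = sum(num_queries)
mask = paddle.ones([total, total]) < 0     # start all-blocked across groups

begin = 0
for g_id, n in enumerate(num_queries):
    block = paddle.rand([n, n])
    if 0 < g_id < len(num_queries) - 1:
        block = block > 0.1                # noise group: drop ~10% of links
    else:
        block = block >= 0.0               # o2o and o2m: fully visible
    mask[begin:begin + n, begin:begin + n] = block
    begin += n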
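The one-to-many supervision in DINOv3Loss is implemented purely on the label side: ground truths are tiled o2m times before Hungarian matching, so up to o2m queries can be assigned to each object. The tiling in isolation:

# toy_o2m_targets.py -- mirrors the `o2m != 1` branch of DINOv3Loss.forward
import paddle

o2m = 4
gt_bbox = [paddle.rand([3, 4])]                      # one image, 3 boxes (cxcywh)
gt_class = [paddle.randint(0, 80, [3, 1])]

gt_bbox_o2m = [b.tile([o2m, 1]) for b in gt_bbox]    # -> shape [12, 4]
gt_class_o2m = [c.tile([o2m, 1]) for c in gt_class]  # -> shape [12, 1]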
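Finally, _generate_anchors stores proposals in logit space via log(p / (1 - p)), so the encoder/decoder bbox heads can add their regression deltas before a single final sigmoid. A quick round-trip check of that convention:

# toy_logit_anchors.py -- round-trip of the pre-sigmoid anchor convention
import paddle
import paddle.nn.functional as F

anchors = paddle.to_tensor([[0.25, 0.50, 0.10, 0.20]])    # cxcywh in (0, 1)
unact = paddle.log(anchors / (1 - anchors))               # inverse sigmoid
delta = paddle.zeros_like(unact)                          # stand-in head output
print(paddle.allclose(F.sigmoid(unact + delta), anchors))  # True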