Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 93 additions & 65 deletions paddlex/inference/models/object_detection/modeling/rt_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,77 +181,107 @@ def __call__(self, head_out, im_shape, scale_factor, pad_shape):
class RTDETRConfig(PretrainedConfig):
    """Configuration for the RT-DETR object-detection model.

    Collects the hyper-parameters for the four components visible in this
    module: the HGNetV2 backbone (taken from ``backbone_config``), the
    HybridEncoder, the RTDetrTransformer decoder, and the DINO loss /
    Hungarian matcher.  The attribute names (``el_*``, ``tf_*``, ...) keep the
    legacy PaddleDetection naming so the downstream model code can consume
    them unchanged.
    """

    def __init__(
        self,
        initializer_range=0.01,
        initializer_bias_prior_prob=None,
        layer_norm_eps=1e-5,
        batch_norm_eps=1e-5,
        # backbone
        backbone_config=None,
        freeze_backbone_batch_norms=True,
        # encoder HybridEncoder
        encoder_hidden_dim=256,
        encoder_in_channels=[512, 1024, 2048],
        feat_strides=[8, 16, 32],
        encoder_layers=1,
        encoder_ffn_dim=1024,
        encoder_attention_heads=8,
        dropout=0.0,
        activation_dropout=0.0,
        encode_proj_layers=[2],
        positional_encoding_temperature=10000,
        encoder_activation_function="gelu",
        activation_function="silu",
        eval_size=None,
        normalize_before=False,
        hidden_expansion=1.0,
        # decoder RTDetrTransformer
        d_model=256,
        num_queries=300,
        decoder_in_channels=[256, 256, 256],
        decoder_ffn_dim=1024,
        num_feature_levels=3,
        decoder_n_points=4,
        decoder_layers=6,
        decoder_attention_heads=8,
        decoder_activation_function="relu",
        attention_dropout=0.0,
        num_denoising=100,
        label_noise_ratio=0.5,
        box_noise_scale=1.0,
        learn_initial_query=False,
        anchor_image_size=None,
        disable_custom_kernels=True,
        with_box_refine=True,
        is_encoder_decoder=True,
        # Loss
        matcher_alpha=0.25,
        matcher_gamma=2.0,
        matcher_class_cost=2.0,
        matcher_bbox_cost=5.0,
        matcher_giou_cost=2.0,
        use_focal_loss=True,
        auxiliary_loss=True,
        focal_loss_alpha=0.75,
        focal_loss_gamma=2.0,
        weight_loss_vfl=1.0,
        weight_loss_bbox=5.0,
        weight_loss_giou=2.0,
        eos_coefficient=1e-4,
        **kwargs,
    ):
        # A backbone configuration is mandatory: its fields are read below.
        # Fail with a clear message instead of an opaque TypeError on the
        # ``None`` default.
        if backbone_config is None:
            raise RuntimeError("`backbone_config` must be provided.")
        if backbone_config["model_type"] != "hgnet_v2":
            # Use single quotes inside the f-string expression: nesting the
            # same quote character is a SyntaxError on Python < 3.12.
            raise RuntimeError(
                f"There is no dynamic graph implementation for backbone "
                f"{backbone_config['model_type']!r}."
            )
        # Backbone (HGNetV2) settings, copied through from backbone_config.
        self.arch = backbone_config["arch"]
        self.freeze_stem_only = backbone_config["freeze_stem_only"]
        self.freeze_at = backbone_config["freeze_at"]
        self.freeze_norm = backbone_config["freeze_norm"]
        self.lr_mult_list = backbone_config["lr_mult_list"]
        self.return_idx = backbone_config["return_idx"]
        # HybridEncoder settings (legacy ``el_*`` names = encoder layer).
        self.hidden_dim = encoder_hidden_dim
        self.use_encoder_idx = encode_proj_layers
        self.num_encoder_layers = encoder_layers
        self.el_d_model = d_model
        self.el_nhead = encoder_attention_heads
        self.el_dim_feedforward = encoder_ffn_dim
        self.el_dropout = dropout
        self.el_activation = encoder_activation_function
        self.expansion = hidden_expansion
        # RTDetrTransformer decoder settings (legacy ``tf_*`` names).
        self.tf_num_queries = num_queries
        self.tf_feat_strides = feat_strides
        self.tf_num_levels = num_feature_levels
        self.tf_nhead = decoder_attention_heads
        self.tf_num_decoder_layers = decoder_layers
        self.tf_backbone_feat_channels = decoder_in_channels
        self.tf_dim_feedforward = decoder_ffn_dim
        self.tf_dropout = attention_dropout
        self.tf_activation = decoder_activation_function
        self.tf_num_denoising = num_denoising
        self.tf_label_noise_ratio = label_noise_ratio
        self.tf_box_noise_scale = box_noise_scale
        self.tf_learnt_init_query = learn_initial_query
        # Loss weights / matcher costs, packed into the dict layout expected
        # by DINOLoss and HungarianMatcher.
        self.loss_coeff = {
            "class": weight_loss_vfl,
            "bbox": weight_loss_bbox,
            "giou": weight_loss_giou,
        }
        self.aux_loss = auxiliary_loss
        self.matcher_coeff = {
            "class": matcher_class_cost,
            "bbox": matcher_bbox_cost,
            "giou": matcher_giou_cost,
        }
        self.use_focal_loss = use_focal_loss
        self.tensor_parallel_degree = 1
        # NOTE(review): remaining keyword arguments are not consumed here;
        # presumably ``super().__init__(**kwargs)`` is called outside the
        # visible hunk — confirm against the full file.

Expand Down Expand Up @@ -286,7 +316,6 @@ def __init__(self, config: RTDETRConfig):
)
self.transformer = RTDETRTransformer(
num_queries=self.config.tf_num_queries,
position_embed_type=self.config.tf_position_embed_type,
feat_strides=self.config.tf_feat_strides,
backbone_feat_channels=self.config.tf_backbone_feat_channels,
num_levels=self.config.tf_num_levels,
Expand All @@ -304,14 +333,13 @@ def __init__(self, config: RTDETRConfig):
loss=DINOLoss(
loss_coeff=self.config.loss_coeff,
aux_loss=self.config.aux_loss,
use_vfl=self.config.use_vfl,
matcher=HungarianMatcher(
matcher_coeff=self.config.matcher_coeff,
),
)
)
self.post_process = DETRPostProcess(
num_top_queries=self.config.num_top_queries,
num_top_queries=self.config.tf_num_queries,
use_focal_loss=self.config.use_focal_loss,
)

Expand Down