Commit 56c4421

[detection] fix attention mask for RT-DETR-based models (#40269)
* Fix get_contrastive_denoising_training_group attention
* Add bool attention_mask conversion
1 parent 5d9a715 commit 56c4421
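
For context on the second bullet: the eager attention in these DETR-family models adds `attention_mask` directly onto the raw attention scores, so a boolean mask (True = position is masked) must first be turned into an additive float mask of 0 / -inf. Below is a minimal standalone sketch of that conversion with toy shapes; the variable names mirror the diff, but this is illustrative code, not the library implementation.

import torch

# Toy scores for 2 heads over 4 query and 4 key positions.
attn_weights = torch.randn(2, 4, 4)

# Boolean mask: True marks key positions that must not be attended to.
bool_mask = torch.tensor([False, False, True, True]).expand(4, 4)

# The conversion the commit adds: same dtype as the scores,
# 0 where attention is allowed, -inf where it is blocked.
additive_mask = torch.zeros_like(bool_mask, dtype=attn_weights.dtype).masked_fill_(
    bool_mask, -torch.inf
)

# The additive mask is simply summed onto the scores before softmax;
# the -inf entries end up with exactly zero probability.
probs = torch.softmax(attn_weights + additive_mask, dim=-1)
assert torch.all(probs[..., 2:] == 0)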

File tree

10 files changed: +56 −12 lines changed

examples/modular-transformers/modeling_test_detr.py

Lines changed: 4 additions & 0 deletions
@@ -561,6 +561,10 @@ def forward(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+                    attention_mask, -torch.inf
+                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

src/transformers/models/conditional_detr/modeling_conditional_detr.py

Lines changed: 8 additions & 0 deletions
@@ -537,6 +537,10 @@ def forward(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+                    attention_mask, -torch.inf
+                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

@@ -654,6 +658,10 @@ def forward(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+                    attention_mask, -torch.inf
+                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

src/transformers/models/d_fine/modeling_d_fine.py

Lines changed: 8 additions & 4 deletions
@@ -299,6 +299,10 @@ def forward(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+                    attention_mask, -torch.inf
+                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

@@ -856,16 +860,16 @@ def get_contrastive_denoising_training_group(
    input_query_class = class_embed(input_query_class)

    target_size = num_denoising_queries + num_queries
-    attn_mask = torch.full([target_size, target_size], False, dtype=torch.bool, device=device)
+    attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
    # match query cannot see the reconstruction
-    attn_mask[num_denoising_queries:, :num_denoising_queries] = True
+    attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf

    # reconstructions cannot see each other
    for i in range(num_groups_denoising_queries):
        idx_block_start = max_gt_num * 2 * i
        idx_block_end = max_gt_num * 2 * (i + 1)
-        attn_mask[idx_block_start:idx_block_end, :idx_block_start] = True
-        attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = True
+        attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
+        attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf

    denoising_meta_values = {
        "dn_positive_idx": denoise_positive_idx,

src/transformers/models/deformable_detr/modeling_deformable_detr.py

Lines changed: 4 additions & 0 deletions
@@ -674,6 +674,10 @@ def forward(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+                    attention_mask, -torch.inf
+                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

src/transformers/models/deprecated/deta/modeling_deta.py

Lines changed: 4 additions & 0 deletions
@@ -789,6 +789,10 @@ def forward(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+                    attention_mask, -torch.inf
+                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

src/transformers/models/detr/modeling_detr.py

Lines changed: 4 additions & 0 deletions
@@ -505,6 +505,10 @@ def forward(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+                    attention_mask, -torch.inf
+                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

src/transformers/models/maskformer/modeling_maskformer.py

Lines changed: 4 additions & 0 deletions
@@ -482,6 +482,10 @@ def forward(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+                    attention_mask, -torch.inf
+                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

src/transformers/models/rt_detr/modeling_rt_detr.py

Lines changed: 8 additions & 4 deletions
@@ -456,16 +456,16 @@ def get_contrastive_denoising_training_group(
    input_query_class = class_embed(input_query_class)

    target_size = num_denoising_queries + num_queries
-    attn_mask = torch.full([target_size, target_size], False, dtype=torch.bool, device=device)
+    attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
    # match query cannot see the reconstruction
-    attn_mask[num_denoising_queries:, :num_denoising_queries] = True
+    attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf

    # reconstructions cannot see each other
    for i in range(num_groups_denoising_queries):
        idx_block_start = max_gt_num * 2 * i
        idx_block_end = max_gt_num * 2 * (i + 1)
-        attn_mask[idx_block_start:idx_block_end, :idx_block_start] = True
-        attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = True
+        attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
+        attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf

    denoising_meta_values = {
        "dn_positive_idx": denoise_positive_idx,

@@ -854,6 +854,10 @@ def forward(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+                    attention_mask, -torch.inf
+                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py

Lines changed: 8 additions & 4 deletions
@@ -305,6 +305,10 @@ def forward(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+                    attention_mask, -torch.inf
+                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

@@ -1188,16 +1192,16 @@ def get_contrastive_denoising_training_group(
    input_query_class = class_embed(input_query_class)

    target_size = num_denoising_queries + num_queries
-    attn_mask = torch.full([target_size, target_size], False, dtype=torch.bool, device=device)
+    attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
    # match query cannot see the reconstruction
-    attn_mask[num_denoising_queries:, :num_denoising_queries] = True
+    attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf

    # reconstructions cannot see each other
    for i in range(num_groups_denoising_queries):
        idx_block_start = max_gt_num * 2 * i
        idx_block_end = max_gt_num * 2 * (i + 1)
-        attn_mask[idx_block_start:idx_block_end, :idx_block_start] = True
-        attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = True
+        attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
+        attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf

    denoising_meta_values = {
        "dn_positive_idx": denoise_positive_idx,

src/transformers/models/table_transformer/modeling_table_transformer.py

Lines changed: 4 additions & 0 deletions
@@ -464,6 +464,10 @@ def forward(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+                    attention_mask, -torch.inf
+                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
