Skip to content

Commit cdeaad9

Browse files
authored
[superglue] Fixed the way batch mask was applied to the scores before match assignment computation (#39968)
fix: mask filling to score was wrong
1 parent 2593932 commit cdeaad9

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

src/transformers/models/superglue/modeling_superglue.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,8 +676,10 @@ def _match_image_pair(
676676

677677
if mask is not None:
678678
mask = mask.reshape(batch_size, 2, num_keypoints)
679-
mask0 = mask[:, 0].unsqueeze(-1).expand(-1, -1, num_keypoints)
680-
scores = scores.masked_fill(mask0 == 0, -1e9)
679+
mask0 = mask[:, 0].unsqueeze(2)
680+
mask1 = mask[:, 1].unsqueeze(1)
681+
mask = torch.logical_and(mask0, mask1)
682+
scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
681683

682684
# Run the optimal transport.
683685
scores = log_optimal_transport(scores, self.bin_score, iterations=self.config.sinkhorn_iterations)

tests/models/superglue/test_modeling_superglue.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,3 +423,5 @@ def test_inference(self):
423423
torch.sum(~torch.isclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-2)) < 4
424424
)
425425
self.assertTrue(torch.sum(predicted_matches_values != expected_matches_values) < 4)
426+
self.assertTrue(torch.all(outputs.matches[0, 1] < torch.sum(outputs.mask[0, 0])))
427+
self.assertTrue(torch.all(outputs.matches[0, 0] < torch.sum(outputs.mask[0, 1])))

0 commit comments

Comments
 (0)