Skip to content

Commit 1a86944

Browse files
authored
Merge pull request #1616 from borglab/cherrypick-transitivity-fix-dsftrackgenerator
Cherrypick commits for release/4.2, that include transitivity fix for DsfTrackGenerator
2 parents 13c7daf + 2f2d654 commit 1a86944

File tree

2 files changed

+108
-14
lines changed

2 files changed

+108
-14
lines changed

gtsam/sfm/DsfTrackGenerator.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include <algorithm>
2222
#include <iostream>
23+
#include <iomanip>
2324

2425
namespace gtsam {
2526

@@ -38,7 +39,8 @@ static DSFMapIndexPair generateDSF(const MatchIndicesMap& matches) {
3839
// Image pair is (i1,i2).
3940
size_t i1 = pair_indices.first;
4041
size_t i2 = pair_indices.second;
41-
for (size_t k = 0; k < corr_indices.rows(); k++) {
42+
size_t m = static_cast<size_t>(corr_indices.rows());
43+
for (size_t k = 0; k < m; k++) {
4244
// Measurement indices are found in a single matrix row, as (k1,k2).
4345
size_t k1 = corr_indices(k, 0), k2 = corr_indices(k, 1);
4446
// Unique key for DSF is (i,k), representing keypoint index in an image.
@@ -128,7 +130,7 @@ std::vector<SfmTrack2d> tracksFromPairwiseMatches(
128130
}
129131

130132
// TODO(johnwlambert): return the Transitivity failure percentage here.
131-
return tracks2d;
133+
return validTracks;
132134
}
133135

134136
} // namespace gtsfm

python/gtsam/tests/test_DsfTrackGenerator.py

Lines changed: 104 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,65 @@
44
"""
55

66
import unittest
7+
from typing import Dict, Tuple
78

8-
import gtsam
99
import numpy as np
10-
from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2,
11-
SfmMeasurementVector, SfmTrack2d)
1210
from gtsam.gtsfm import Keypoints
1311
from gtsam.utils.test_case import GtsamTestCase
1412

13+
import gtsam
14+
from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2,
15+
SfmMeasurementVector, SfmTrack2d)
16+
1517

1618
class TestDsfTrackGenerator(GtsamTestCase):
1719
"""Tests for DsfTrackGenerator."""
1820

21+
def test_generate_tracks_from_pairwise_matches_nontransitive(
22+
self,
23+
) -> None:
24+
"""Tests DSF for non-transitive matches.
25+
26+
Test will result in no tracks since nontransitive tracks are naively
27+
discarded by DSF.
28+
"""
29+
keypoints = get_dummy_keypoints_list()
30+
nontransitive_matches = get_nontransitive_matches()
31+
32+
# For each image pair (i1,i2), we provide a (K,2) matrix
33+
# of corresponding keypoint indices (k1,k2).
34+
matches = MatchIndicesMap()
35+
for (i1, i2), correspondences in nontransitive_matches.items():
36+
matches[IndexPair(i1, i2)] = correspondences
37+
38+
tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
39+
matches,
40+
keypoints,
41+
verbose=True,
42+
)
43+
self.assertEqual(len(tracks), 0, "Tracks not filtered correctly")
44+
1945
def test_track_generation(self) -> None:
2046
"""Ensures that DSF generates three tracks from measurements
2147
in 3 images (H=200,W=400)."""
2248
kps_i0 = Keypoints(np.array([[10.0, 20], [30, 40]]))
2349
kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]]))
2450
kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]]))
2551

26-
keypoints_list = KeypointsVector()
27-
keypoints_list.append(kps_i0)
28-
keypoints_list.append(kps_i1)
29-
keypoints_list.append(kps_i2)
52+
keypoints = KeypointsVector()
53+
keypoints.append(kps_i0)
54+
keypoints.append(kps_i1)
55+
keypoints.append(kps_i2)
3056

3157
# For each image pair (i1,i2), we provide a (K,2) matrix
3258
# of corresponding image indices (k1,k2).
33-
matches_dict = MatchIndicesMap()
34-
matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
35-
matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
59+
matches = MatchIndicesMap()
60+
matches[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
61+
matches[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
3662

3763
tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
38-
matches_dict,
39-
keypoints_list,
64+
matches,
65+
keypoints,
4066
verbose=False,
4167
)
4268
assert len(tracks) == 3
@@ -93,5 +119,71 @@ def test_sfm_track_2d_constructor(self) -> None:
93119
assert track.numberMeasurements() == 1
94120

95121

122+
def get_dummy_keypoints_list() -> KeypointsVector:
123+
"""Generate a list of dummy keypoints for testing."""
124+
img1_kp_coords = np.array([[1, 1], [2, 2], [3, 3.]])
125+
img2_kp_coords = np.array(
126+
[
127+
[1, 1.],
128+
[2, 2],
129+
[3, 3],
130+
[4, 4],
131+
[5, 5],
132+
[6, 6],
133+
[7, 7],
134+
[8, 8],
135+
]
136+
)
137+
img3_kp_coords = np.array(
138+
[
139+
[1, 1.],
140+
[2, 2],
141+
[3, 3],
142+
[4, 4],
143+
[5, 5],
144+
[6, 6],
145+
[7, 7],
146+
[8, 8],
147+
[9, 9],
148+
[10, 10],
149+
]
150+
)
151+
img4_kp_coords = np.array(
152+
[
153+
[1, 1.],
154+
[2, 2],
155+
[3, 3],
156+
[4, 4],
157+
[5, 5],
158+
]
159+
)
160+
keypoints = KeypointsVector()
161+
keypoints.append(Keypoints(coordinates=img1_kp_coords))
162+
keypoints.append(Keypoints(coordinates=img2_kp_coords))
163+
keypoints.append(Keypoints(coordinates=img3_kp_coords))
164+
keypoints.append(Keypoints(coordinates=img4_kp_coords))
165+
return keypoints
166+
167+
168+
def get_nontransitive_matches() -> Dict[Tuple[int, int], np.ndarray]:
169+
"""Set up correspondences for each (i1,i2) pair that violates transitivity.
170+
171+
(i=0, k=0) (i=0, k=1)
172+
| \\ |
173+
| \\ |
174+
(i=1, k=2)--(i=2,k=3)--(i=3, k=4)
175+
176+
Transitivity is violated due to the match between frames 0 and 3.
177+
"""
178+
nontransitive_matches = {
179+
(0, 1): np.array([[0, 2]]),
180+
(1, 2): np.array([[2, 3]]),
181+
(0, 2): np.array([[0, 3]]),
182+
(0, 3): np.array([[1, 4]]),
183+
(2, 3): np.array([[3, 4]]),
184+
}
185+
return nontransitive_matches
186+
187+
96188
if __name__ == "__main__":
97189
unittest.main()

0 commit comments

Comments
 (0)