44"""
55
66import unittest
7- from typing import Dict , List , Tuple
7+ from typing import Dict , Tuple
88
99import numpy as np
1010from gtsam .gtsfm import Keypoints
1111from gtsam .utils .test_case import GtsamTestCase
1212
1313import gtsam
14- from gtsam import IndexPair , Point2 , SfmTrack2d
14+ from gtsam import (IndexPair , KeypointsVector , MatchIndicesMap , Point2 ,
15+ SfmMeasurementVector , SfmTrack2d )
1516
1617
1718class TestDsfTrackGenerator (GtsamTestCase ):
@@ -22,20 +23,21 @@ def test_generate_tracks_from_pairwise_matches_nontransitive(
2223 ) -> None :
2324 """Tests DSF for non-transitive matches.
2425
25- Test will result in no tracks since nontransitive tracks are naively discarded by DSF.
26+ Test will result in no tracks since nontransitive tracks are naively
27+ discarded by DSF.
2628 """
27- keypoints_list = get_dummy_keypoints_list ()
28- nontransitive_matches_dict = get_nontransitive_matches () # contains one non-transitive track
29+ keypoints = get_dummy_keypoints_list ()
30+ nontransitive_matches = get_nontransitive_matches ()
2931
3032 # For each image pair (i1,i2), we provide a (K,2) matrix
3133 # of corresponding keypoint indices (k1,k2).
32- matches_dict = {}
33- for (i1 ,i2 ), corr_idxs in nontransitive_matches_dict .items ():
34- matches_dict [IndexPair (i1 , i2 )] = corr_idxs
34+ matches = MatchIndicesMap ()
35+ for (i1 , i2 ), correspondences in nontransitive_matches .items ():
36+ matches [IndexPair (i1 , i2 )] = correspondences
3537
3638 tracks = gtsam .gtsfm .tracksFromPairwiseMatches (
37- matches_dict ,
38- keypoints_list ,
39+ matches ,
40+ keypoints ,
3941 verbose = True ,
4042 )
4143 self .assertEqual (len (tracks ), 0 , "Tracks not filtered correctly" )
@@ -47,20 +49,20 @@ def test_track_generation(self) -> None:
4749 kps_i1 = Keypoints (np .array ([[50.0 , 60 ], [70 , 80 ], [90 , 100 ]]))
4850 kps_i2 = Keypoints (np .array ([[110.0 , 120 ], [130 , 140 ]]))
4951
50- keypoints_list = []
51- keypoints_list .append (kps_i0 )
52- keypoints_list .append (kps_i1 )
53- keypoints_list .append (kps_i2 )
52+ keypoints = KeypointsVector ()
53+ keypoints .append (kps_i0 )
54+ keypoints .append (kps_i1 )
55+ keypoints .append (kps_i2 )
5456
5557 # For each image pair (i1,i2), we provide a (K,2) matrix
56- # of corresponding keypoint indices (k1,k2).
57- matches_dict = {}
58- matches_dict [IndexPair (0 , 1 )] = np .array ([[0 , 0 ], [1 , 1 ]])
59- matches_dict [IndexPair (1 , 2 )] = np .array ([[2 , 0 ], [1 , 1 ]])
58+ # of corresponding image indices (k1,k2).
59+ matches = MatchIndicesMap ()
60+ matches [IndexPair (0 , 1 )] = np .array ([[0 , 0 ], [1 , 1 ]])
61+ matches [IndexPair (1 , 2 )] = np .array ([[2 , 0 ], [1 , 1 ]])
6062
6163 tracks = gtsam .gtsfm .tracksFromPairwiseMatches (
62- matches_dict ,
63- keypoints_list ,
64+ matches ,
65+ keypoints ,
6466 verbose = False ,
6567 )
6668 assert len (tracks ) == 3
@@ -110,17 +112,16 @@ class TestSfmTrack2d(GtsamTestCase):
110112
111113 def test_sfm_track_2d_constructor (self ) -> None :
112114 """Test construction of 2D SfM track."""
113- measurements = []
115+ measurements = SfmMeasurementVector ()
114116 measurements .append ((0 , Point2 (10 , 20 )))
115117 track = SfmTrack2d (measurements = measurements )
116118 track .measurement (0 )
117119 assert track .numberMeasurements () == 1
118120
119121
120- def get_dummy_keypoints_list () -> List [ Keypoints ] :
121- """ """
122+ def get_dummy_keypoints_list () -> KeypointsVector :
123+ """Generate a list of dummy keypoints for testing. """
122124 img1_kp_coords = np .array ([[1 , 1 ], [2 , 2 ], [3 , 3. ]])
123- img1_kp_scale = np .array ([6.0 , 9.0 , 8.5 ])
124125 img2_kp_coords = np .array (
125126 [
126127 [1 , 1. ],
@@ -156,33 +157,32 @@ def get_dummy_keypoints_list() -> List[Keypoints]:
156157 [5 , 5 ],
157158 ]
158159 )
159- keypoints_list = [
160- Keypoints (coordinates = img1_kp_coords ),
161- Keypoints (coordinates = img2_kp_coords ),
162- Keypoints (coordinates = img3_kp_coords ),
163- Keypoints (coordinates = img4_kp_coords ),
164- ]
165- return keypoints_list
160+ keypoints = KeypointsVector ()
161+ keypoints .append (Keypoints (coordinates = img1_kp_coords ))
162+ keypoints .append (Keypoints (coordinates = img2_kp_coords ))
163+ keypoints .append (Keypoints (coordinates = img3_kp_coords ))
164+ keypoints .append (Keypoints (coordinates = img4_kp_coords ))
165+ return keypoints
166166
167167
168168def get_nontransitive_matches () -> Dict [Tuple [int , int ], np .ndarray ]:
169169 """Set up correspondences for each (i1,i2) pair that violates transitivity.
170-
170+
171171 (i=0, k=0) (i=0, k=1)
172172 | \\ |
173173 | \\ |
174174 (i=1, k=2)--(i=2,k=3)--(i=3, k=4)
175175
176- Transitivity is violated due to the match between frames 0 and 3.
176+ Transitivity is violated due to the match between frames 0 and 3.
177177 """
178- nontransitive_matches_dict = {
178+ nontransitive_matches = {
179179 (0 , 1 ): np .array ([[0 , 2 ]]),
180180 (1 , 2 ): np .array ([[2 , 3 ]]),
181181 (0 , 2 ): np .array ([[0 , 3 ]]),
182182 (0 , 3 ): np .array ([[1 , 4 ]]),
183183 (2 , 3 ): np .array ([[3 , 4 ]]),
184184 }
185- return nontransitive_matches_dict
185+ return nontransitive_matches
186186
187187
188188if __name__ == "__main__" :
0 commit comments