Skip to content

Commit 2f2d654

Browse files
committed
Fix 4.3 python style
1 parent 12f919d commit 2f2d654

File tree

1 file changed

+35
-35
lines changed

1 file changed

+35
-35
lines changed

python/gtsam/tests/test_DsfTrackGenerator.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
"""
55

66
import unittest
7-
from typing import Dict, List, Tuple
7+
from typing import Dict, Tuple
88

99
import numpy as np
1010
from gtsam.gtsfm import Keypoints
1111
from gtsam.utils.test_case import GtsamTestCase
1212

1313
import gtsam
14-
from gtsam import IndexPair, Point2, SfmTrack2d
14+
from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2,
15+
SfmMeasurementVector, SfmTrack2d)
1516

1617

1718
class TestDsfTrackGenerator(GtsamTestCase):
@@ -22,20 +23,21 @@ def test_generate_tracks_from_pairwise_matches_nontransitive(
2223
) -> None:
2324
"""Tests DSF for non-transitive matches.
2425
25-
Test will result in no tracks since nontransitive tracks are naively discarded by DSF.
26+
Test will result in no tracks since nontransitive tracks are naively
27+
discarded by DSF.
2628
"""
27-
keypoints_list = get_dummy_keypoints_list()
28-
nontransitive_matches_dict = get_nontransitive_matches() # contains one non-transitive track
29+
keypoints = get_dummy_keypoints_list()
30+
nontransitive_matches = get_nontransitive_matches()
2931

3032
# For each image pair (i1,i2), we provide a (K,2) matrix
3133
# of corresponding keypoint indices (k1,k2).
32-
matches_dict = {}
33-
for (i1,i2), corr_idxs in nontransitive_matches_dict.items():
34-
matches_dict[IndexPair(i1, i2)] = corr_idxs
34+
matches = MatchIndicesMap()
35+
for (i1, i2), correspondences in nontransitive_matches.items():
36+
matches[IndexPair(i1, i2)] = correspondences
3537

3638
tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
37-
matches_dict,
38-
keypoints_list,
39+
matches,
40+
keypoints,
3941
verbose=True,
4042
)
4143
self.assertEqual(len(tracks), 0, "Tracks not filtered correctly")
@@ -47,20 +49,20 @@ def test_track_generation(self) -> None:
4749
kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]]))
4850
kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]]))
4951

50-
keypoints_list = []
51-
keypoints_list.append(kps_i0)
52-
keypoints_list.append(kps_i1)
53-
keypoints_list.append(kps_i2)
52+
keypoints = KeypointsVector()
53+
keypoints.append(kps_i0)
54+
keypoints.append(kps_i1)
55+
keypoints.append(kps_i2)
5456

5557
# For each image pair (i1,i2), we provide a (K,2) matrix
56-
# of corresponding keypoint indices (k1,k2).
57-
matches_dict = {}
58-
matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
59-
matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
58+
# of corresponding image indices (k1,k2).
59+
matches = MatchIndicesMap()
60+
matches[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
61+
matches[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
6062

6163
tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
62-
matches_dict,
63-
keypoints_list,
64+
matches,
65+
keypoints,
6466
verbose=False,
6567
)
6668
assert len(tracks) == 3
@@ -110,17 +112,16 @@ class TestSfmTrack2d(GtsamTestCase):
110112

111113
def test_sfm_track_2d_constructor(self) -> None:
112114
"""Test construction of 2D SfM track."""
113-
measurements = []
115+
measurements = SfmMeasurementVector()
114116
measurements.append((0, Point2(10, 20)))
115117
track = SfmTrack2d(measurements=measurements)
116118
track.measurement(0)
117119
assert track.numberMeasurements() == 1
118120

119121

120-
def get_dummy_keypoints_list() -> List[Keypoints]:
121-
""" """
122+
def get_dummy_keypoints_list() -> KeypointsVector:
123+
"""Generate a list of dummy keypoints for testing."""
122124
img1_kp_coords = np.array([[1, 1], [2, 2], [3, 3.]])
123-
img1_kp_scale = np.array([6.0, 9.0, 8.5])
124125
img2_kp_coords = np.array(
125126
[
126127
[1, 1.],
@@ -156,33 +157,32 @@ def get_dummy_keypoints_list() -> List[Keypoints]:
156157
[5, 5],
157158
]
158159
)
159-
keypoints_list = [
160-
Keypoints(coordinates=img1_kp_coords),
161-
Keypoints(coordinates=img2_kp_coords),
162-
Keypoints(coordinates=img3_kp_coords),
163-
Keypoints(coordinates=img4_kp_coords),
164-
]
165-
return keypoints_list
160+
keypoints = KeypointsVector()
161+
keypoints.append(Keypoints(coordinates=img1_kp_coords))
162+
keypoints.append(Keypoints(coordinates=img2_kp_coords))
163+
keypoints.append(Keypoints(coordinates=img3_kp_coords))
164+
keypoints.append(Keypoints(coordinates=img4_kp_coords))
165+
return keypoints
166166

167167

168168
def get_nontransitive_matches() -> Dict[Tuple[int, int], np.ndarray]:
169169
"""Set up correspondences for each (i1,i2) pair that violates transitivity.
170-
170+
171171
(i=0, k=0) (i=0, k=1)
172172
| \\ |
173173
| \\ |
174174
(i=1, k=2)--(i=2,k=3)--(i=3, k=4)
175175
176-
Transitivity is violated due to the match between frames 0 and 3.
176+
Transitivity is violated due to the match between frames 0 and 3.
177177
"""
178-
nontransitive_matches_dict = {
178+
nontransitive_matches = {
179179
(0, 1): np.array([[0, 2]]),
180180
(1, 2): np.array([[2, 3]]),
181181
(0, 2): np.array([[0, 3]]),
182182
(0, 3): np.array([[1, 4]]),
183183
(2, 3): np.array([[3, 4]]),
184184
}
185-
return nontransitive_matches_dict
185+
return nontransitive_matches
186186

187187

188188
if __name__ == "__main__":

0 commit comments

Comments
 (0)