Skip to content

Commit 3a80932

Browse files
authored
Fix dataloaders (#76)
* fix dataloaders to return [frame, timestamp] always
* remove unnecessary for loop
* fix formatting
* bump kiss-icp required version to allow proper deskewing
1 parent b92fefd commit 3a80932

File tree

13 files changed

+144
-112
lines changed

13 files changed

+144
-112
lines changed

cpp/map_closures/MapClosures.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,6 @@ void MapClosures::MatchAndAddToDatabase(const int id,
8484
self_matches.reserve(orb_keypoints.size());
8585
self_matcher.knnMatch(orb_descriptors, orb_descriptors, self_matches, 2);
8686

87-
std::for_each(orb_keypoints.begin(), orb_keypoints.end(), [&](cv::KeyPoint &keypoint) {
88-
keypoint.pt.x = keypoint.pt.x + static_cast<float>(density_map.lower_bound.y());
89-
keypoint.pt.y = keypoint.pt.y + static_cast<float>(density_map.lower_bound.x());
90-
});
9187
density_maps_.emplace(id, std::move(density_map));
9288
ground_alignments_.emplace(id, T_ground);
9389

@@ -96,7 +92,9 @@ void MapClosures::MatchAndAddToDatabase(const int id,
9692
std::for_each(self_matches.cbegin(), self_matches.cend(), [&](const auto &self_match) {
9793
if (self_match[1].distance > self_similarity_threshold) {
9894
const auto index_descriptor = self_match[0].queryIdx;
99-
const auto &keypoint = orb_keypoints[index_descriptor];
95+
cv::KeyPoint keypoint = orb_keypoints[index_descriptor];
96+
keypoint.pt.x = keypoint.pt.x + static_cast<float>(density_map.lower_bound.y());
97+
keypoint.pt.y = keypoint.pt.y + static_cast<float>(density_map.lower_bound.x());
10098
hbst_matchable.emplace_back(
10199
new Matchable(keypoint, orb_descriptors.row(index_descriptor), id));
102100
}

python/map_closures/datasets/apollo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __len__(self):
4949
return len(self.scan_files)
5050

5151
def __getitem__(self, idx):
52-
return self.get_scan(self.scan_files[idx])
52+
return self.get_scan(self.scan_files[idx]), np.array([])
5353

5454
def get_scan(self, scan_file: str):
5555
points = np.asarray(self.o3d.io.read_point_cloud(scan_file).points, dtype=np.float64)

python/map_closures/datasets/generic.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def __getitem__(self, idx):
6161
return self.read_point_cloud(self.scan_files[idx])
6262

6363
def read_point_cloud(self, file_path: str):
64-
points = self._read_point_cloud(file_path)
65-
return points.astype(np.float64)
64+
points, timestamps = self._read_point_cloud(file_path)
65+
return points.astype(np.float64), timestamps.astype(np.float64)
6666

6767
def _get_point_cloud_reader(self):
6868
"""Attempt to guess with try/catch blocks which is the best point cloud reader to use for
@@ -75,34 +75,77 @@ def _get_point_cloud_reader(self):
7575
# This is easy, the old KITTI format
7676
if self.file_extension == "bin":
7777
print("[WARNING] Reading .bin files, the only format supported is the KITTI format")
78-
return lambda file: np.fromfile(file, dtype=np.float32).reshape((-1, 4))[:, :3]
78+
79+
class ReadKITTI:
80+
def __call__(self, file):
81+
return np.fromfile(file, dtype=np.float32).reshape((-1, 4))[:, :3], np.array([])
82+
83+
return ReadKITTI()
7984

8085
first_scan_file = self.scan_files[0]
8186

82-
# first try trimesh
87+
# first try open3d
88+
try:
89+
import open3d as o3d
90+
91+
try_pcd = o3d.t.io.read_point_cloud(first_scan_file)
92+
if try_pcd.is_empty():
93+
# open3d binding does not raise an exception if file is unreadable or extension is not supported
94+
raise Exception("Generic Dataloader| Open3d PointCloud file is empty")
95+
96+
stamps_keys = ["t", "timestamp", "timestamps", "time", "stamps"]
97+
stamp_field = None
98+
for key in stamps_keys:
99+
try:
100+
try_pcd.point[key]
101+
stamp_field = key
102+
print("Generic Dataloader| found timestamps")
103+
break
104+
except:
105+
continue
106+
107+
class ReadOpen3d:
108+
def __init__(self, time_field):
109+
self.time_field = time_field
110+
if self.time_field is None:
111+
self.get_timestamps = lambda _: np.array([])
112+
else:
113+
self.get_timestamps = lambda pcd: pcd.point[self.time_field].numpy().ravel()
114+
115+
def __call__(self, file):
116+
pcd = o3d.t.io.read_point_cloud(file)
117+
points = pcd.point.positions.numpy()
118+
return points, self.get_timestamps(pcd)
119+
120+
return ReadOpen3d(stamp_field)
121+
except:
122+
pass
123+
83124
try:
84125
import trimesh
85126

86127
trimesh.load(first_scan_file)
87-
return lambda file: np.asarray(trimesh.load(file).vertices)
128+
129+
class ReadTriMesh:
130+
def __call__(self, file):
131+
return np.asarray(trimesh.load(file).vertices), np.array([])
132+
133+
return ReadTriMesh()
88134
except:
89135
pass
90136

91-
# then try pyntcloud
92137
try:
93138
from pyntcloud import PyntCloud
94139

95140
PyntCloud.from_file(first_scan_file)
96-
return lambda file: PyntCloud.from_file(file).points[["x", "y", "z"]].to_numpy()
97-
except:
98-
pass
99141

100-
# lastly with open3d
101-
try:
102-
import open3d as o3d
142+
class ReadPynt:
143+
def __call__(self, file):
144+
return PyntCloud.from_file(file).points[["x", "y", "z"]].to_numpy(), np.array(
145+
[]
146+
)
103147

104-
o3d.io.read_point_cloud(first_scan_file)
105-
return lambda file: np.asarray(o3d.io.read_point_cloud(file).points, dtype=np.float64)
148+
return ReadPynt()
106149
except:
107150
print("[ERROR], File format not supported")
108151
sys.exit(1)

python/map_closures/datasets/helipr.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,20 @@ def __init__(self, data_dir: Path, sequence: str, *_, **__):
4848
if self.sequence_id == "Avia":
4949
self.format_string = "fffBBBL"
5050
self.intensity_channel = None
51+
self.time_channel = 6
5152
elif self.sequence_id == "Aeva":
5253
self.format_string = "ffffflBf"
5354
self.format_string_no_intensity = "ffffflB"
5455
self.intensity_channel = 7
56+
self.time_channel = 5
5557
elif self.sequence_id == "Ouster":
5658
self.format_string = "ffffIHHH"
5759
self.intensity_channel = 3
60+
self.time_channel = 4
5861
elif self.sequence_id == "Velodyne":
5962
self.format_string = "ffffHf"
6063
self.intensity_channel = 3
64+
self.time_channel = 5
6165
else:
6266
print("[ERROR] Unsupported LiDAR Type")
6367
sys.exit(1)
@@ -66,7 +70,10 @@ def __len__(self):
6670
return len(self.scan_files)
6771

6872
def __getitem__(self, idx):
69-
return self.read_point_cloud(idx)
73+
data = self.get_data(idx)
74+
points = self.read_point_cloud(data)
75+
timestamps = self.read_timestamps(data)
76+
return points, timestamps
7077

7178
def get_data(self, idx: int):
7279
file_path = self.scan_files[idx]
@@ -89,16 +96,17 @@ def get_data(self, idx: int):
8996
data = np.stack(list_lines)
9097
return data
9198

92-
def read_point_cloud(self, idx: int):
93-
data = self.get_data(idx)
94-
points = data[:, :3]
95-
return points.astype(np.float64)
99+
def read_timestamps(self, data: np.ndarray) -> np.ndarray:
100+
time = data[:, self.time_channel]
101+
return (time - time.min()) / (time.max() - time.min())
102+
103+
def read_point_cloud(self, data: np.ndarray) -> np.ndarray:
104+
return data[:, :3]
96105

97106
def load_poses(self, poses_file):
98107
from pyquaternion import Quaternion
99108

100109
poses = np.loadtxt(poses_file, delimiter=" ")
101-
n = poses.shape[0]
102110

103111
xyz = poses[:, 1:4]
104112
rotations = np.array(

python/map_closures/datasets/kitti.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __len__(self):
5454
return len(self.scan_files)
5555

5656
def scans(self, idx):
57-
return self.read_point_cloud(self.scan_files[idx])
57+
return self.read_point_cloud(self.scan_files[idx]), np.array([])
5858

5959
def apply_calibration(self, poses: np.ndarray) -> np.ndarray:
6060
"""Converts from Velodyne to Camera Frame"""

python/map_closures/datasets/mcap.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,18 @@ def __init__(self, data_dir: str, topic: str, *_, **__):
4040
# we expect `data_dir` param to be a path to the .mcap file, so rename for clarity
4141
assert os.path.isfile(data_dir), "mcap dataloader expects an existing MCAP file"
4242
self.sequence_id = os.path.basename(data_dir).split(".")[0]
43-
mcap_file = str(data_dir)
43+
self.mcap_file = str(data_dir)
4444

45-
self.bag = make_reader(open(mcap_file, "rb"))
45+
self.make_reader = make_reader
46+
self.read_ros2_messages = read_ros2_messages
47+
self.read_point_cloud = read_point_cloud
48+
49+
self.bag = self.make_reader(open(self.mcap_file, "rb"))
4650
self.summary = self.bag.get_summary()
4751
self.topic = self.check_topic(topic)
4852
self.n_scans = self._get_n_scans()
49-
self.msgs = read_ros2_messages(mcap_file, topics=topic)
50-
self.read_point_cloud = read_point_cloud
53+
self.msgs = self.read_ros2_messages(self.mcap_file, topics=[self.topic])
54+
self.timestamps = []
5155
self.use_global_visualizer = True
5256

5357
def __del__(self):
@@ -56,6 +60,7 @@ def __del__(self):
5660

5761
def __getitem__(self, idx):
5862
msg = next(self.msgs).ros_msg
63+
self.timestamps.append(self.stamp_to_sec(msg.header.stamp))
5964
return self.read_point_cloud(msg)
6065

6166
def __len__(self):
@@ -68,6 +73,18 @@ def _get_n_scans(self) -> int:
6873
if self.summary.channels[id].topic == self.topic
6974
)
7075

76+
def reset(self):
77+
self.timestamps = []
78+
self.bag = self.make_reader(open(self.mcap_file, "rb"))
79+
self.msgs = self.read_ros2_messages(self.mcap_file, topics=[self.topic])
80+
81+
@staticmethod
82+
def stamp_to_sec(stamp):
83+
return stamp.sec + float(stamp.nanosec) / 1e9
84+
85+
def get_frames_timestamps(self) -> list:
86+
return self.timestamps
87+
7188
def check_topic(self, topic: str) -> str:
7289
# Extract schema id from the .mcap file that encodes the PointCloud2 msg
7390
schema_id = [

python/map_closures/datasets/mulran.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def read_point_cloud(self, file_path: str):
4848
timestamps = self.get_timestamps()
4949
if points.shape[0] != timestamps.shape[0]:
5050
# MulRan has some broken point clouds, just fall back to no timestamps
51-
return points.astype(np.float64), np.ones(points.shape[0])
51+
return points.astype(np.float64), np.array([])
5252
return points.astype(np.float64), timestamps
5353

5454
@staticmethod

python/map_closures/datasets/ncd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def getitem(self, scan_file: str):
6363
timestamps = self.get_timestamps()
6464
if points.shape[0] != timestamps.shape[0]:
6565
# MulRan has some broken point clouds, just fall back to no timestamps
66-
return points.astype(np.float64), np.ones(points.shape[0])
66+
return points.astype(np.float64), np.array([])
6767
return points.astype(np.float64), timestamps
6868

6969
@staticmethod

python/map_closures/datasets/ouster.py

Lines changed: 15 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -22,30 +22,12 @@
2222
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2323
# SOFTWARE.
2424

25-
import glob
2625
import os
2726
from typing import Optional
2827

2928
import numpy as np
3029

3130

32-
def find_metadata_json(pcap_file: str) -> str:
33-
"""Attempts to resolve the metadata json file for a provided pcap file."""
34-
dir_path, filename = os.path.split(pcap_file)
35-
if not filename:
36-
return ""
37-
if not dir_path:
38-
dir_path = os.getcwd()
39-
json_candidates = sorted(glob.glob(f"{dir_path}/*.json"))
40-
if not json_candidates:
41-
return ""
42-
prefix_sizes = list(
43-
map(lambda p: len(os.path.commonprefix((filename, os.path.basename(p)))), json_candidates)
44-
)
45-
max_elem = max(range(len(prefix_sizes)), key=lambda i: prefix_sizes[i])
46-
return json_candidates[max_elem]
47-
48-
4931
class OusterDataloader:
5032
"""Ouster pcap dataloader"""
5133

@@ -83,64 +65,42 @@ def __init__(
8365
"""
8466

8567
try:
86-
import ouster.pcap as pcap
87-
from ouster import client
68+
from ouster.sdk import client, open_source
8869
except ImportError:
89-
print(
90-
f'[ERROR] ouster-sdk is not installed on your system, run "pip install ouster-sdk"'
91-
)
70+
print(f'ouster-sdk is not installed on your system, run "pip install ouster-sdk"')
9271
exit(1)
9372

94-
# since we import ouster-sdk's client module locally, we keep it locally as well
95-
self._client = client
96-
9773
assert os.path.isfile(data_dir), "Ouster pcap dataloader expects an existing PCAP file"
9874

9975
# we expect `data_dir` param to be a path to the .pcap file, so rename for clarity
10076
pcap_file = data_dir
10177

102-
metadata_json = meta or find_metadata_json(pcap_file)
103-
if not metadata_json:
104-
print("[ERROR] Ouster pcap dataloader can't find metadata json file.")
105-
exit(1)
106-
print("[INFO] Ouster pcap dataloader: using metadata json: ", metadata_json)
78+
print("Indexing Ouster pcap to count the scans number ...")
79+
source = open_source(str(pcap_file), meta=[meta] if meta else [], index=True)
10780

108-
self.data_dir = os.path.dirname(data_dir)
81+
# since we import ouster-sdk's client module locally, we keep reference
82+
# to it locally as well
83+
self._client = client
10984

110-
with open(metadata_json) as json:
111-
self._info_json = json.read()
112-
self._info = client.SensorInfo(self._info_json)
85+
self.data_dir = os.path.dirname(data_dir)
11386

11487
# lookup table for 2D range image projection to a 3D point cloud
115-
self._xyz_lut = client.XYZLut(self._info)
88+
self._xyz_lut = client.XYZLut(source.metadata)
11689

11790
self._pcap_file = str(data_dir)
11891

119-
# read pcap file for the first pass to count scans
120-
print("[INFO] Pre-reading Ouster pcap to count the scans number ...")
121-
self._source = pcap.Pcap(self._pcap_file, self._info)
122-
self._scans_num = sum((1 for _ in client.Scans(self._source)))
123-
print(f"[INFO] Ouster pcap total scans number: {self._scans_num}")
92+
self._scans_num = len(source)
93+
print(f"Ouster pcap total scans number: {self._scans_num}")
12494

12595
# frame timestamps array
12696
self._timestamps = np.linspace(0, self._scans_num, self._scans_num, endpoint=False)
12797

128-
# start Scans iterator for consumption in __getitem__
129-
self._source = pcap.Pcap(self._pcap_file, self._info)
130-
self._scans_iter = iter(client.Scans(self._source))
131-
self._next_idx = 0
98+
self._source = source
13299

133100
def __getitem__(self, idx):
134-
# we assume that users always reads sequentially and do not
135-
# pass idx as for a random access collection
136-
assert self._next_idx == idx, (
137-
"Ouster pcap dataloader supports only sequential reads. "
138-
f"Expected idx: {self._next_idx}, but got {idx}"
139-
)
140-
scan = next(self._scans_iter)
141-
self._next_idx += 1
142-
143-
self._timestamps[self._next_idx - 1] = 1e-9 * scan.timestamp[0]
101+
scan = self._source[idx]
102+
103+
self._timestamps[idx] = 1e-9 * scan.timestamp[0]
144104

145105
timestamps = np.tile(np.linspace(0, 1.0, scan.w, endpoint=False), (scan.h, 1))
146106

0 commit comments

Comments
 (0)