Skip to content

Commit 432ed58

Browse files
authored
Gupta/fix serialized dataloaders (#445)
* add reset option to mcap dataloader * add reset option to rosbag dataloader * reset timestamps as well
1 parent 6ab7170 commit 432ed58

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

python/kiss_icp/datasets/mcap.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,18 @@ def __init__(self, data_dir: str, topic: str, *_, **__):
4242
# we expect `data_dir` param to be a path to the .mcap file, so rename for clarity
4343
assert os.path.isfile(data_dir), "mcap dataloader expects an existing MCAP file"
4444
self.sequence_id = os.path.basename(data_dir).split(".")[0]
45-
mcap_file = str(data_dir)
45+
self.mcap_file = str(data_dir)
4646

47-
self.bag = make_reader(open(mcap_file, "rb"))
47+
self.make_reader = make_reader
48+
self.read_ros2_messages = read_ros2_messages
49+
self.read_point_cloud = read_point_cloud
50+
51+
self.bag = self.make_reader(open(self.mcap_file, "rb"))
4852
self.summary = self.bag.get_summary()
4953
self.topic = self.check_topic(topic)
5054
self.n_scans = self._get_n_scans()
51-
self.msgs = read_ros2_messages(mcap_file, topics=[self.topic])
55+
self.msgs = self.read_ros2_messages(self.mcap_file, topics=[self.topic])
5256
self.timestamps = []
53-
self.read_point_cloud = read_point_cloud
5457
self.use_global_visualizer = True
5558

5659
def __del__(self):
@@ -72,6 +75,11 @@ def _get_n_scans(self) -> int:
7275
if self.summary.channels[id].topic == self.topic
7376
)
7477

78+
def reset(self):
79+
self.timestamps = []
80+
self.bag = self.make_reader(open(self.mcap_file, "rb"))
81+
self.msgs = self.read_ros2_messages(self.mcap_file, topics=[self.topic])
82+
7583
@staticmethod
7684
def stamp_to_sec(stamp):
7785
return stamp.sec + float(stamp.nanosec) / 1e9

python/kiss_icp/datasets/rosbag.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def __init__(self, data_dir: Path, topic: str, *_, **__):
6868
self.n_scans = self.bag.topics[self.topic].msgcount
6969

7070
# limit connections to selected topic
71-
connections = [x for x in self.bag.connections if x.topic == self.topic]
72-
self.msgs = self.bag.messages(connections=connections)
71+
self.connections = [x for x in self.bag.connections if x.topic == self.topic]
72+
self.msgs = self.bag.messages(connections=self.connections)
7373
self.timestamps = []
7474

7575
# Visualization Options
@@ -88,6 +88,12 @@ def __getitem__(self, idx):
8888
msg = self.bag.deserialize(rawdata, connection.msgtype)
8989
return self.read_point_cloud(msg)
9090

91+
def reset(self):
92+
self.timestamps = []
93+
self.bag.close()
94+
self.bag.open()
95+
self.msgs = self.bag.messages(connections=self.connections)
96+
9197
@staticmethod
9298
def to_sec(nsec: int):
9399
return float(nsec) / 1e9

0 commit comments

Comments
 (0)