Skip to content

Commit 9a59b46

Browse files
Add BagWriter for creating bag files (fixes #163) (#167)
* remove unnecessary import * added coding header * added BagWriter class * added filename argument * added write method * write version * added _write_header_record * added missing imports * outlined _write_header * implemented _write_header * implemented header and size for chunk records * truncate file before write * add call to _write_chunk_data * outlined _write_index * fix pycodestyle violations * outlined more of _write_chunk * maintain chunk lisT * outlined _write_connection_record * built the connection header * allow opcodes to be optional * added stub _get_connection * implemented _get_connection * added message definition to ConnectionInfo * tweak: use msg_format * tweak: obtain definition directly * added md5sum to ConnectionInfo * added header construction to _write_chunk_info_record * added basic _write_message * maintain index * implemented time tracking * added interface for _write_connection_index * implemented _write_connection_index * wrote connection count to chunk info record * report num connections * added chunk connections * added close method to BagWriter * added test_write
1 parent b906175 commit 9a59b46

File tree

4 files changed

+334
-8
lines changed

4 files changed

+334
-8
lines changed

src/roswire/bag/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
__all__ = ('BagReader',)
1+
__all__ = ('BagReader', 'BagWriter')
22

33
from .reader import BagReader
4+
from .writer import BagWriter

src/roswire/bag/reader.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# -*- coding: utf-8 -*-
12
__all__ = ('BagReader',)
23

34
from typing import (Dict, Sequence, Union, Optional, Tuple, List, Type,
@@ -9,8 +10,6 @@
910
import logging
1011
import heapq
1112

12-
import attr
13-
1413
from .core import *
1514
from ..definitions.base import Time
1615
from ..definitions.msg import Message

src/roswire/bag/writer.py

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
This module provides a bag file writer that writes the contents of a bag to a
4+
binary file on disk.
5+
6+
See
7+
---
8+
https://github.com/ros/ros_comm/blob/melodic-devel/tools/rosbag/src/rosbag/bag.py
9+
"""
10+
__all__ = ('BagWriter',)
11+
12+
from typing import BinaryIO, Iterable, Dict, Type, Tuple
13+
14+
from .core import (BagMessage, OpCode, Compression, ConnectionInfo, Chunk,
15+
Index, IndexEntry, ChunkConnection)
16+
from ..definitions import Message
17+
from ..definitions.encode import *
18+
19+
BIN_CHUNK_INFO_VERSION = encode_uint32(1)
20+
BIN_INDEX_VERSION = encode_uint32(1)
21+
22+
23+
class BagWriter:
24+
"""
25+
Provides an interface for writing messages to a bag file on disk.
26+
27+
Attributes
28+
----------
29+
filename: str
30+
The name of the file to which the bag will be written.
31+
"""
32+
def __init__(self, fn: str) -> None:
33+
self.__fn = fn
34+
self.__fp: BinaryIO = open(fn, 'wb')
35+
self.__connections: Dict[str, ConnectionInfo] = {}
36+
self.__chunks: List[Chunk] = []
37+
self.__pos_header = 0
38+
self.__pos_chunks = 0
39+
self.__pos_index = 0
40+
41+
@property
42+
def filename(self) -> str:
43+
return self.__fn
44+
45+
def _write_header(self,
46+
code: Optional[OpCode],
47+
fields: Dict[str, bytes]
48+
) -> None:
49+
# write a length of zero for now and correct that length once the
50+
# fields have been written
51+
pos_size = self.__fp.tell()
52+
write_uint32(0, self.__fp)
53+
pos_content = self.__fp.tell()
54+
55+
if code:
56+
fields['op'] = code.value
57+
for name, bin_value in fields.items():
58+
bin_name = f'{name}='.encode('utf-8')
59+
size_field = len(bin_name) + len(bin_value)
60+
write_uint32(size_field, self.__fp)
61+
self.__fp.write(bin_name)
62+
self.__fp.write(bin_value)
63+
64+
# correct the length
65+
pos_end = self.__fp.tell()
66+
size_total = pos_end - pos_content
67+
self.__fp.seek(pos_size)
68+
write_uint32(size_total, self.__fp)
69+
self.__fp.seek(pos_end)
70+
71+
def _write_header_record(self) -> None:
72+
self.__fp.seek(self.__pos_header)
73+
self._write_header(OpCode.HEADER, {
74+
'index_pos': encode_uint64(self.__pos_index),
75+
'conn_count': encode_uint32(len(self.__connections)),
76+
'chunk_count': encode_uint32(len(self.__chunks))})
77+
78+
# ensure the bag header record is 4096 characters long by padding it
79+
# with ASCII space characters (0x20) where necessary.
80+
pos_current = self.__fp.tell()
81+
size = 4096
82+
size_header = pos_current - self.__pos_header
83+
size_padding = size - size_header - 4
84+
85+
write_uint32(size_padding, self.__fp)
86+
padding = b'\x20' * size_padding
87+
self.__fp.write(padding)
88+
89+
def _write_message(self,
90+
offset: int,
91+
index: Index,
92+
message: BagMessage,
93+
) -> None:
94+
typ = message.message.__class__
95+
connection = self._get_connection(message.topic, typ)
96+
97+
pos_header = self.__fp.tell()
98+
self._write_header(OpCode.MESSAGE_DATA, {
99+
'conn': encode_uint32(connection.conn),
100+
'time': encode_time(message.time)})
101+
102+
bin_data = message.message.encode()
103+
size_data = len(bin_data)
104+
write_uint32(size_data, self.__fp)
105+
self.__fp.write(bin_data)
106+
107+
# update index
108+
index_entry = IndexEntry(time=message.time,
109+
pos=pos_header,
110+
offset=offset)
111+
if connection.conn not in index:
112+
index[connection.conn] = []
113+
index[connection.conn].append(index_entry)
114+
115+
def _write_chunk_data(self, messages: Iterable[BagMessage]) -> Index:
116+
index: Index = {}
117+
pos_start = self.__fp.tell()
118+
pos_end = pos_start
119+
offset = 0
120+
for m in messages:
121+
self._write_message(offset, index, m)
122+
pos_end = self.__fp.tell()
123+
size_record = pos_end - pos_start
124+
offset += size_record
125+
pos_start = pos_end
126+
return index
127+
128+
def _write_chunk_record(self,
129+
compression: Compression,
130+
messages: Iterable[BagMessage]
131+
) -> Tuple[Chunk, Index]:
132+
bin_compression = compression.value.encode('utf-8')
133+
134+
# for now, we write a bogus header and size field
135+
# once we've finished writing the data, we'll correct them
136+
pos_header = self.__fp.tell()
137+
self._write_header(OpCode.CHUNK, {
138+
'compression': bin_compression,
139+
'size': encode_uint32(0)})
140+
write_uint32(0, self.__fp)
141+
pos_data = self.__fp.tell()
142+
143+
# write chunk contents
144+
index = self._write_chunk_data(messages)
145+
146+
# determine time of earliest and latest message in the bag
147+
time_start = time_end = Time(0, 0)
148+
for time in (e.time for ci in index.values() for e in ci):
149+
time_start = min(time, time_start)
150+
time_end = max(time, time_end)
151+
152+
# compute chunk size
153+
pos_end = self.__fp.tell()
154+
size_compressed = pos_end - pos_data
155+
size_uncompressed = size_compressed
156+
157+
# update header and size
158+
self.__fp.seek(pos_header)
159+
self._write_header(OpCode.CHUNK, {
160+
'compression': bin_compression,
161+
'size': encode_uint32(size_uncompressed)})
162+
write_uint32(size_compressed, self.__fp)
163+
self.__fp.seek(pos_end)
164+
165+
# build a description of the chunk
166+
conns = [ChunkConnection(conn, len(entries))
167+
for conn, entries in index.items()]
168+
chunk = Chunk(pos_record=pos_header, # type: ignore
169+
pos_data=pos_data,
170+
time_start=time_start,
171+
time_end=time_end,
172+
connections=conns,
173+
compression=compression,
174+
size_compressed=size_compressed,
175+
size_uncompressed=size_uncompressed)
176+
return chunk, index
177+
178+
def _write_connection_index(self,
179+
conn: int,
180+
entries: List[IndexEntry]
181+
) -> None:
182+
num_entries = len(entries)
183+
size_data = num_entries * 12
184+
self._write_header(OpCode.INDEX_DATA, {
185+
'ver': BIN_INDEX_VERSION,
186+
'conn': encode_uint32(conn),
187+
'count': encode_uint32(num_entries)})
188+
write_uint32(size_data, self.__fp)
189+
for entry in entries:
190+
write_time(entry.time, self.__fp)
191+
write_uint32(entry.offset, self.__fp)
192+
193+
def _write_chunk(self,
194+
compression: Compression,
195+
messages: Iterable[BagMessage]
196+
) -> None:
197+
# TODO for now, we only support uncompressed writing
198+
assert compression == Compression.NONE
199+
chunk, index = self._write_chunk_record(compression, messages)
200+
self.__chunks.append(chunk)
201+
for conn, entries in index.items():
202+
self._write_connection_index(conn, entries)
203+
204+
def _write_connection_record(self, conn: ConnectionInfo) -> None:
205+
self._write_header(OpCode.CONNECTION_INFO, {
206+
'conn': encode_uint32(conn.conn),
207+
'topic': conn.topic.encode('utf-8')})
208+
pos_size = self.__fp.tell()
209+
write_uint32(0, self.__fp)
210+
211+
# write the connection header
212+
header_conn: Dict[str, bytes] = {}
213+
header_conn['topic'] = conn.topic_original.encode('utf-8')
214+
header_conn['type'] = conn.typ.encode('utf-8')
215+
header_conn['md5sum'] = conn.md5sum.encode('utf-8')
216+
header_conn['message_definition'] = \
217+
conn.message_definition.encode('utf-8')
218+
if conn.callerid is not None:
219+
header_conn['callerid'] = conn.callerid.encode('utf-8')
220+
if conn.latching is not None:
221+
header_conn['latching'] = conn.latching.encode('utf-8')
222+
self._write_header(None, header_conn)
223+
224+
# update the record size
225+
pos_end = self.__fp.tell()
226+
size_data = pos_end - pos_size
227+
self.__fp.seek(pos_size)
228+
write_uint32(size_data, self.__fp)
229+
self.__fp.seek(pos_end)
230+
231+
def _get_connection(self,
232+
topic: str,
233+
typ: Type[Message]
234+
) -> ConnectionInfo:
235+
# if there isn't a connection for the topic, create one and write a
236+
# connection record.
237+
if topic not in self.__connections:
238+
msg_format = typ.format
239+
conn = len(self.__connections)
240+
info = ConnectionInfo(conn=conn,
241+
topic=topic,
242+
topic_original=topic,
243+
typ=msg_format.fullname,
244+
md5sum=typ.md5sum(),
245+
message_definition=msg_format.definition,
246+
callerid=None,
247+
latching=None)
248+
self.__connections[topic] = info
249+
self._write_connection_record(info)
250+
251+
return self.__connections[topic]
252+
253+
def _write_chunk_info_record(self, chunk: Chunk) -> None:
254+
num_connections = len(chunk.connections)
255+
pos_header = self.__fp.tell()
256+
self._write_header(OpCode.CHUNK_INFO, {
257+
'ver': BIN_CHUNK_INFO_VERSION,
258+
'chunk_pos': encode_uint64(chunk.pos_record),
259+
'start_time': encode_time(chunk.time_start),
260+
'end_time': encode_time(chunk.time_end),
261+
'count': encode_uint32(num_connections)})
262+
263+
size_data = num_connections * 8
264+
write_uint32(size_data, self.__fp)
265+
for connection in chunk.connections:
266+
write_uint32(connection.uid, self.__fp)
267+
write_uint32(connection.count, self.__fp)
268+
269+
def _write_index(self) -> None:
270+
for connection in self.__connections.values():
271+
self._write_connection_record(connection)
272+
for chunk in self.__chunks:
273+
self._write_chunk_info_record(chunk)
274+
275+
def write(self, messages: Iterable[BagMessage]) -> None:
276+
"""
277+
Writes a sequence of messages to the bag.
278+
Any existing bag file contents will be overwritten.
279+
"""
280+
self.__fp.truncate(0)
281+
self.__fp.write('#ROSBAG V2.0\n'.encode('utf-8'))
282+
283+
# create a placeholder header for now
284+
self.__pos_header = self.__fp.tell()
285+
self._write_header_record()
286+
287+
# for now, we write to a single, uncompressed chunk
288+
# each chunk record is followed by a sequence of IndexData record
289+
# - each connection in the chunk is represented by an IndexData record
290+
self.__pos_chunks = self.__fp.tell()
291+
self._write_chunk(Compression.NONE, messages)
292+
293+
# write index
294+
self.__pos_index = self.__fp.tell()
295+
self._write_index()
296+
297+
# fix the header
298+
self._write_header_record()
299+
300+
def close(self) -> None:
301+
self.__fp.close()

test/test_bag.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import os
22
import logging
3+
import tempfile
34

45
import yaml
56
import pytest
67

7-
from roswire.bag import BagReader
8+
from roswire.bag import BagReader, BagWriter
89
from roswire.definitions import TypeDatabase, FormatDatabase
910

1011
from test_basic import build_ardu
@@ -19,11 +20,8 @@ def load_mavros_type_db() -> TypeDatabase:
1920
return TypeDatabase.build(db_format)
2021

2122

22-
def test_from_file():
23-
db_type = load_mavros_type_db()
23+
def check_example_bag(db_type: TypeDatabase, bag: BagReader) -> None:
2424
typ_mavlink = db_type['mavros_msgs/Mavlink']
25-
fn_bag = os.path.join(DIR_TEST, 'example.bag')
26-
bag = BagReader(fn_bag, db_type)
2725
assert bag.header.index_pos == 189991
2826
assert bag.header.conn_count == 7
2927
assert bag.header.chunk_count == 1
@@ -34,3 +32,30 @@ def test_from_file():
3432
msgs = list(bag.read_messages(['/mavlink/from']))
3533
assert all(m.topic == '/mavlink/from' for m in msgs)
3634
assert all(isinstance(m.message, typ_mavlink) for m in msgs)
35+
36+
37+
def test_from_file():
38+
db_type = load_mavros_type_db()
39+
fn_bag = os.path.join(DIR_TEST, 'example.bag')
40+
bag = BagReader(fn_bag, db_type)
41+
check_example_bag(db_type, bag)
42+
43+
44+
def test_write():
45+
db_type = load_mavros_type_db()
46+
47+
# load the messages from the example bag
48+
reader = BagReader(os.path.join(DIR_TEST, 'example.bag'), db_type)
49+
messages = list(reader)
50+
51+
fn_bag = tempfile.mkstemp(suffix='.bag')[1]
52+
try:
53+
bag = BagWriter(fn_bag)
54+
bag.write(messages)
55+
bag.close()
56+
57+
# try to read the bag again
58+
reader = BagReader(os.path.join(DIR_TEST, 'example.bag'), db_type)
59+
check_example_bag(db_type, reader)
60+
finally:
61+
os.remove(fn_bag)

0 commit comments

Comments
 (0)