Skip to content

Commit 3283925

Browse files
authored
Merge pull request #129 from mariuz/master
Add wire compression support for protocol version over 13
2 parents d68e159 + da01672 commit 3283925

File tree

6 files changed

+168
-14
lines changed

6 files changed

+168
-14
lines changed

firebirdsql/aio/fbcore.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,10 @@ async def _async_parse_connect_response(self):
633633
self.accept_type = bytes_to_bint(b[8:])
634634
self.lazy_response_count = 0
635635

636+
if self.accept_type & pflag_compress:
637+
self.sock.enable_compression()
638+
self.accept_type &= ptype_MASK
639+
636640
if op_code == self.op_cond_accept or op_code == self.op_accept_data:
637641
ln = bytes_to_bint(await self._async_recv_channel(4))
638642
data = await self._async_recv_channel(ln, word_alignment=True)
@@ -918,7 +922,7 @@ async def _initialize(self):
918922

919923
self.sock = AsyncSocketStream(self.hostname, self.port, self.loop, self.timeout, self.cloexec)
920924

921-
self._op_connect(self.auth_plugin_name, self.wire_crypt)
925+
self._op_connect(self.auth_plugin_name, self.wire_crypt, self.wire_compress)
922926
try:
923927
await self._async_parse_connect_response()
924928
except OperationalError as e:

firebirdsql/aio/stream.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
# Python DB-API 2.0 module for Firebird.
2727
##############################################################################
2828
import asyncio
29+
import zlib
2930

3031
from firebirdsql.stream import SocketStream
3132
from firebirdsql.utils import bytes_to_bint
@@ -48,12 +49,22 @@ async def _await_pending_send(self):
4849
async def async_recv(self, nbytes):
4950
await self._await_pending_send()
5051

51-
if len(self._buf) < nbytes:
52-
read_size = max(8192, nbytes - len(self._buf))
53-
chunk = await self.loop.sock_recv(self._sock, read_size)
54-
if self.read_translator:
55-
chunk = self.read_translator.decrypt(chunk)
56-
self._buf += chunk
52+
if self._decompressor:
53+
while len(self._buf) < nbytes:
54+
read_size = max(8192, nbytes - len(self._buf))
55+
chunk = await self.loop.sock_recv(self._sock, read_size)
56+
if not chunk:
57+
break
58+
if self.read_translator:
59+
chunk = self.read_translator.decrypt(chunk)
60+
self._buf += self._decompressor.decompress(chunk)
61+
else:
62+
if len(self._buf) < nbytes:
63+
read_size = max(8192, nbytes - len(self._buf))
64+
chunk = await self.loop.sock_recv(self._sock, read_size)
65+
if self.read_translator:
66+
chunk = self.read_translator.decrypt(chunk)
67+
self._buf += chunk
5768

5869
ret = self._buf[:nbytes]
5970
self._buf = self._buf[nbytes:]
@@ -62,6 +73,8 @@ async def async_recv(self, nbytes):
6273
def send(self, b):
6374
if not self.loop.is_running():
6475
return super().send(b)
76+
if self._compressor:
77+
b = self._compressor.compress(b) + self._compressor.flush(zlib.Z_SYNC_FLUSH)
6578
if self.write_translator:
6679
b = self.write_translator.encrypt(b)
6780

firebirdsql/fbcore.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,10 @@ def _parse_connect_response(self):
676676
self.accept_type = bytes_to_bint(b[8:])
677677
self.lazy_response_count = 0
678678

679+
if self.accept_type & pflag_compress:
680+
self.sock.enable_compression()
681+
self.accept_type &= ptype_MASK
682+
679683
if op_code == self.op_cond_accept or op_code == self.op_accept_data:
680684
ln = bytes_to_bint(self._recv_channel(4))
681685
data = self._recv_channel(ln, word_alignment=True)
@@ -950,7 +954,7 @@ def __init__(
950954
page_size=4096, is_services=False, cloexec=False,
951955
timeout=None, isolation_level=None,
952956
auth_plugin_name=None, wire_crypt=True, create_new=False,
953-
timezone=None
957+
timezone=None, wire_compress=False
954958
):
955959
DEBUG_OUTPUT("Connection::__init__()", id(self))
956960
self.accept_plugin_name = ''
@@ -967,6 +971,7 @@ def __init__(
967971
self.timeout = float(timeout) if timeout is not None else None
968972
self.auth_plugin_name = auth_plugin_name
969973
self.wire_crypt = wire_crypt
974+
self.wire_compress = wire_compress
970975
self.create_new = create_new
971976
self.page_size = page_size
972977
self.is_services = is_services
@@ -986,7 +991,7 @@ def _initialize(self):
986991

987992
self.sock = SocketStream(self.hostname, self.port, self.timeout, self.cloexec)
988993

989-
self._op_connect(self.auth_plugin_name, self.wire_crypt)
994+
self._op_connect(self.auth_plugin_name, self.wire_crypt, self.wire_compress)
990995
try:
991996
self._parse_connect_response()
992997
except OperationalError as e:

firebirdsql/stream.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
# Python DB-API 2.0 module for Firebird.
2727
##############################################################################
2828
import socket
29+
import zlib
2930

3031
try:
3132
import fcntl
@@ -50,14 +51,39 @@ def __init__(self, host, port, timeout=None, cloexec=False):
5051
self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
5152
self.read_translator = None
5253
self.write_translator = None
54+
self._compressor = None
55+
self._decompressor = None
56+
self._recv_buf = b''
57+
58+
def enable_compression(self):
59+
"""Enable zlib wire compression for the stream.
60+
Called after the server accepts compression during protocol negotiation.
61+
"""
62+
self._compressor = zlib.compressobj()
63+
self._decompressor = zlib.decompressobj()
64+
self._recv_buf = b''
5365

5466
def recv(self, nbytes):
55-
b = self._sock.recv(nbytes)
56-
if self.read_translator:
57-
b = self.read_translator.decrypt(b)
58-
return b
67+
if self._decompressor:
68+
while len(self._recv_buf) < nbytes:
69+
b = self._sock.recv(max(nbytes, 8192))
70+
if not b:
71+
break
72+
if self.read_translator:
73+
b = self.read_translator.decrypt(b)
74+
self._recv_buf += self._decompressor.decompress(b)
75+
result = self._recv_buf[:nbytes]
76+
self._recv_buf = self._recv_buf[nbytes:]
77+
return result
78+
else:
79+
b = self._sock.recv(nbytes)
80+
if self.read_translator:
81+
b = self.read_translator.decrypt(b)
82+
return b
5983

6084
def send(self, b):
85+
if self._compressor:
86+
b = self._compressor.compress(b) + self._compressor.flush(zlib.Z_SYNC_FLUSH)
6187
if self.write_translator:
6288
b = self.write_translator.encrypt(b)
6389
n = 0
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import unittest
2+
import zlib
3+
from unittest.mock import MagicMock, patch
4+
from firebirdsql.stream import SocketStream
5+
from firebirdsql.consts import pflag_compress, ptype_MASK, ptype_lazy_send
6+
7+
8+
class TestCompressionStream(unittest.TestCase):
9+
"""Test zlib compression/decompression in SocketStream."""
10+
11+
def test_roundtrip_compression(self):
12+
"""Test that data compressed by compressor can be decompressed by decompressor."""
13+
compressor = zlib.compressobj()
14+
decompressor = zlib.decompressobj()
15+
16+
original = b'Hello, Firebird wire protocol compression!'
17+
compressed = compressor.compress(original) + compressor.flush(zlib.Z_SYNC_FLUSH)
18+
decompressed = decompressor.decompress(compressed)
19+
self.assertEqual(decompressed, original)
20+
21+
def test_streaming_compression(self):
22+
"""Test that streaming compression works across multiple messages."""
23+
compressor = zlib.compressobj()
24+
decompressor = zlib.decompressobj()
25+
26+
messages = [b'first message', b'second message', b'third message with more data']
27+
recovered = []
28+
for msg in messages:
29+
compressed = compressor.compress(msg) + compressor.flush(zlib.Z_SYNC_FLUSH)
30+
decompressed = decompressor.decompress(compressed)
31+
recovered.append(decompressed)
32+
33+
self.assertEqual(recovered, messages)
34+
35+
def test_compression_with_encryption(self):
36+
"""Test that compression + encryption layering works correctly.
37+
Send order: compress -> encrypt. Receive order: decrypt -> decompress."""
38+
from firebirdsql.arc4 import ARC4
39+
40+
compressor = zlib.compressobj()
41+
decompressor = zlib.decompressobj()
42+
enc = ARC4.new(b'test_key')
43+
dec = ARC4.new(b'test_key')
44+
45+
original = b'Test data for compress+encrypt round-trip'
46+
# Compress then encrypt
47+
compressed = compressor.compress(original) + compressor.flush(zlib.Z_SYNC_FLUSH)
48+
encrypted = enc.translate(compressed)
49+
# Decrypt then decompress
50+
decrypted = dec.translate(encrypted)
51+
decompressed = decompressor.decompress(decrypted)
52+
self.assertEqual(decompressed, original)
53+
54+
def test_pflag_compress_detection(self):
55+
"""Test that pflag_compress flag is correctly detected and stripped."""
56+
accept_type = ptype_lazy_send | pflag_compress # 5 | 0x100 = 0x105
57+
self.assertTrue(accept_type & pflag_compress)
58+
stripped = accept_type & ptype_MASK
59+
self.assertEqual(stripped, ptype_lazy_send)
60+
61+
def test_pflag_compress_not_set(self):
62+
"""Test that missing pflag_compress is correctly detected."""
63+
accept_type = ptype_lazy_send # 5
64+
self.assertFalse(accept_type & pflag_compress)
65+
66+
def test_enable_compression_sets_compressor(self):
67+
"""Test that enable_compression initializes zlib objects."""
68+
with patch('socket.create_connection') as mock_conn:
69+
mock_sock = MagicMock()
70+
mock_conn.return_value = mock_sock
71+
stream = SocketStream('localhost', 3050)
72+
self.assertIsNone(stream._compressor)
73+
self.assertIsNone(stream._decompressor)
74+
stream.enable_compression()
75+
self.assertIsNotNone(stream._compressor)
76+
self.assertIsNotNone(stream._decompressor)
77+
78+
def test_large_data_compression(self):
79+
"""Test compression with larger data payloads."""
80+
compressor = zlib.compressobj()
81+
decompressor = zlib.decompressobj()
82+
83+
# Simulate a large query result
84+
original = b'A' * 100000
85+
compressed = compressor.compress(original) + compressor.flush(zlib.Z_SYNC_FLUSH)
86+
# Compressed size should be significantly smaller for repetitive data
87+
self.assertLess(len(compressed), len(original))
88+
decompressed = decompressor.decompress(compressed)
89+
self.assertEqual(decompressed, original)
90+
91+
92+
if __name__ == '__main__':
93+
unittest.main()

firebirdsql/wireprotocol.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def pack_cnct_param(k, v):
365365
return r
366366

367367
@wire_operation
368-
def _op_connect(self, auth_plugin_name, wire_crypt):
368+
def _op_connect(self, auth_plugin_name, wire_crypt, wire_compress=False):
369369
protocols = [
370370
# PROTOCOL_VERSION, Arch type (Generic=1), min, max, weight
371371
'0000000a00000001000000000000000500000002', # 10, 1, 0, 5, 2
@@ -377,6 +377,19 @@ def _op_connect(self, auth_plugin_name, wire_crypt):
377377
'ffff80100000000100000000000000050000000e', # 16, 1, 0, 5, 14
378378
'ffff801100000001000000000000000500000010', # 17, 1, 0, 5, 16
379379
]
380+
if wire_compress:
381+
protocols_with_compress = [
382+
# PROTOCOL_VERSION, Arch type (Generic=1), min, max|pflag_compress, weight
383+
'0000000a00000001000000000000000500000002', # 10, 1, 0, 5, 2
384+
'ffff800b00000001000000000000000500000004', # 11, 1, 0, 5, 4
385+
'ffff800c00000001000000000000000500000006', # 12, 1, 0, 5, 6
386+
'ffff800d00000001000000000000010500000008', # 13, 1, 0, 0x105, 8
387+
'ffff800e0000000100000000000001050000000a', # 14, 1, 0, 0x105, 10
388+
'ffff800f0000000100000000000001050000000c', # 15, 1, 0, 0x105, 12
389+
'ffff80100000000100000000000001050000000e', # 16, 1, 0, 0x105, 14
390+
'ffff801100000001000000000000010500000010', # 17, 1, 0, 0x105, 16
391+
]
392+
protocols = protocols_with_compress
380393
p = Packer()
381394
p.pack_int(self.op_connect)
382395
p.pack_int(self.op_attach)

0 commit comments

Comments
 (0)