Skip to content

Commit 2e70667

Browse files
feat: Fir 16953 bytea support (#235)
1 parent a93a400 commit 2e70667

File tree

8 files changed

+124
-10
lines changed

8 files changed

+124
-10
lines changed

src/firebolt/async_db/_types.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ def parse_datetime(datetime_string: str) -> datetime:
4949
from firebolt.utils.util import cached_property
5050

5151
_NoneType = type(None)
52-
_col_types = (int, float, str, datetime, date, bool, list, Decimal, _NoneType)
52+
_col_types = (int, float, str, datetime, date, bool, list, Decimal, _NoneType, bytes)
5353
# duplicating this since 3.7 can't unpack Union
54-
ColType = Union[int, float, str, datetime, date, bool, list, Decimal, _NoneType]
54+
ColType = Union[int, float, str, datetime, date, bool, list, Decimal, _NoneType, bytes]
5555
RawColType = Union[int, float, str, bool, list, _NoneType]
56-
ParameterType = Union[int, float, str, datetime, date, bool, Decimal, Sequence]
56+
ParameterType = Union[int, float, str, datetime, date, bool, Decimal, Sequence, bytes]
5757

5858
# These definitions are required by PEP-249
5959
Date = date
@@ -78,12 +78,13 @@ def TimeFromTicks(t: int) -> None:
7878
TimestampFromTicks = datetime.fromtimestamp
7979

8080

81-
def Binary(value: str) -> str:
82-
"""Convert string to binary for Firebolt DB does nothing."""
83-
return value
81+
def Binary(value: str) -> bytes:
82+
"""Encode a string into UTF-8."""
83+
return value.encode("utf-8")
8484

8585

86-
STRING = BINARY = str
86+
STRING = str
87+
BINARY = bytes
8788
NUMBER = int
8889
DATETIME = datetime
8990
ROWID = int
@@ -169,6 +170,8 @@ class _InternalType(Enum):
169170

170171
Boolean = "boolean"
171172

173+
Bytea = "bytea"
174+
172175
Nothing = "Nothing"
173176

174177
@cached_property
@@ -188,6 +191,7 @@ def python_type(self) -> type:
188191
_InternalType.TimestampNtz: datetime,
189192
_InternalType.TimestampTz: datetime,
190193
_InternalType.Boolean: bool,
194+
_InternalType.Bytea: bytes,
191195
# For simplicity, this could happen only during 'select null' query
192196
_InternalType.Nothing: str,
193197
}
@@ -221,6 +225,18 @@ def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL]: # noqa: C901
221225
return str
222226

223227

228+
BYTEA_PREFIX = "\\x"
229+
230+
231+
def _parse_bytea(str_value: str) -> bytes:
232+
if (
233+
len(str_value) < len(BYTEA_PREFIX)
234+
or str_value[: len(BYTEA_PREFIX)] != BYTEA_PREFIX
235+
):
236+
raise ValueError(f"Invalid bytea value format: {BYTEA_PREFIX} prefix expected")
237+
return bytes.fromhex(str_value[len(BYTEA_PREFIX) :])
238+
239+
224240
def parse_value(
225241
value: RawColType,
226242
ctype: Union[type, ARRAY, DECIMAL],
@@ -244,6 +260,10 @@ def parse_value(
244260
if not isinstance(value, (bool, int)):
245261
raise DataError(f"Invalid boolean value {value}: bool or int expected")
246262
return bool(value)
263+
if ctype is bytes:
264+
if not isinstance(value, str):
265+
raise DataError(f"Invalid bytea value {value}: str expected")
266+
return _parse_bytea(value)
247267
if isinstance(ctype, DECIMAL):
248268
assert isinstance(value, (str, int))
249269
return Decimal(value)
@@ -274,6 +294,9 @@ def format_value(value: ParameterType) -> str:
274294
return f"'{value.strftime('%Y-%m-%d %H:%M:%S')}'"
275295
elif isinstance(value, date):
276296
return f"'{value.isoformat()}'"
297+
elif isinstance(value, bytes):
298+
# Encode each byte into hex
299+
return "'" + "".join(f"\\x{b:02x}" for b in value) + "'"
277300
if value is None:
278301
return "NULL"
279302
elif isinstance(value, Sequence):

tests/integration/dbapi/async/test_queries_async.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44

55
from pytest import mark, raises
66

7-
from firebolt.async_db import Connection, Cursor, DataError, OperationalError
7+
from firebolt.async_db import (
8+
Binary,
9+
Connection,
10+
Cursor,
11+
DataError,
12+
OperationalError,
13+
)
814
from firebolt.async_db._types import ColType, Column
915
from firebolt.async_db.cursor import QueryStatus
1016

@@ -487,3 +493,27 @@ async def test_server_side_async_execution_get_status(
487493
# assert (
488494
# type(status) is QueryStatus,
489495
# ), "get_status() did not return a QueryStatus object."
496+
497+
498+
async def test_bytea_roundtrip(
499+
connection: Connection,
500+
) -> None:
501+
"""Inserted and than selected bytea value doesn't get corrupted."""
502+
with connection.cursor() as c:
503+
await c.execute("DROP TABLE IF EXISTS test_bytea_roundtrip")
504+
await c.execute(
505+
"CREATE FACT TABLE test_bytea_roundtrip(id int, b bytea) primary index id"
506+
)
507+
508+
data = "bytea_123\n\tヽ༼ຈل͜ຈ༽ノ"
509+
510+
await c.execute(
511+
"INSERT INTO test_bytea_roundtrip VALUES (1, ?)", (Binary(data),)
512+
)
513+
await c.execute("SELECT b FROM test_bytea_roundtrip")
514+
515+
bytes_data = (await c.fetchone())[0]
516+
517+
assert (
518+
bytes_data.decode("utf-8") == data
519+
), "Invalid bytea data returned after roundtrip"

tests/integration/dbapi/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def all_types_query() -> str:
7575
'true as "boolean", '
7676
"[1,2,3,4] as \"array\", cast('1231232.123459999990457054844258706536' as "
7777
'decimal(38,30)) as "decimal", '
78-
'cast(null as int) as "nullable"'
78+
'cast(null as int) as "nullable", '
79+
"'abc123'::bytea as \"bytea\""
7980
)
8081

8182

@@ -104,6 +105,7 @@ def all_types_query_description() -> List[Column]:
104105
Column("array", ARRAY(int), None, None, None, None, None),
105106
Column("decimal", DECIMAL(38, 30), None, None, None, None, None),
106107
Column("nullable", int, None, None, None, None, None),
108+
Column("bytea", bytes, None, None, None, None, None),
107109
]
108110

109111

@@ -142,6 +144,7 @@ def all_types_query_response(timezone_offset_seconds: int) -> List[ColType]:
142144
[1, 2, 3, 4],
143145
Decimal("1231232.123459999990457054844258706536"),
144146
None,
147+
b"abc123",
145148
]
146149
]
147150

tests/integration/dbapi/sync/test_queries.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from firebolt.async_db.cursor import QueryStatus
1010
from firebolt.client.auth import Auth
1111
from firebolt.db import (
12+
Binary,
1213
Connection,
1314
Cursor,
1415
DataError,
@@ -519,3 +520,25 @@ def run_query():
519520

520521
connection.close()
521522
assert not exceptions
523+
524+
525+
def test_bytea_roundtrip(
526+
connection: Connection,
527+
) -> None:
528+
"""Inserted and than selected bytea value doesn't get corrupted."""
529+
with connection.cursor() as c:
530+
c.execute("DROP TABLE IF EXISTS test_bytea_roundtrip")
531+
c.execute(
532+
"CREATE FACT TABLE test_bytea_roundtrip(id int, b bytea) primary index id"
533+
)
534+
535+
data = "bytea_123\n\tヽ༼ຈل͜ຈ༽ノ"
536+
537+
c.execute("INSERT INTO test_bytea_roundtrip VALUES (1, ?)", (Binary(data),))
538+
c.execute("SELECT b FROM test_bytea_roundtrip")
539+
540+
bytes_data = (c.fetchone())[0]
541+
542+
assert (
543+
bytes_data.decode("utf-8") == data
544+
), "Invalid bytea data returned after roundtrip"

tests/unit/async_db/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def types_map() -> Dict[str, type]:
5151
"Decimal(38)": str,
5252
"boolean": bool,
5353
"SomeRandomNotExistingType": str,
54+
"bytea": bytes,
5455
}
5556
array_types = {f"array({k})": ARRAY(v) for k, v in base_types.items()}
5657
nullable_arrays = {f"{k} null": v for k, v in array_types.items()}

tests/unit/async_db/test_typing_format.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
from sqlparse import parse
77
from sqlparse.sql import Statement
88

9-
from firebolt.async_db import DataError, InterfaceError, NotSupportedError
9+
from firebolt.async_db import (
10+
Binary,
11+
DataError,
12+
InterfaceError,
13+
NotSupportedError,
14+
)
1015
from firebolt.async_db._types import (
1116
SetParameter,
1217
format_statement,
@@ -44,6 +49,8 @@
4449
(("a", "b", "c"), "['a', 'b', 'c']"),
4550
# None
4651
(None, "NULL"),
52+
# Bytea
53+
(b"abc", "'\\x61\\x62\\x63'"),
4754
],
4855
)
4956
def test_format_value(value: str, result: str) -> None:
@@ -188,3 +195,7 @@ def test_statement_to_set(statement: Statement, result: Optional[SetParameter])
188195
def test_statement_to_set_errors(statement: Statement, error: Exception) -> None:
189196
with raises(error):
190197
statement_to_set(statement)
198+
199+
200+
def test_binary() -> None:
201+
assert Binary("abc") == b"abc"

tests/unit/async_db/test_typing_parse.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,22 @@ def test_parse_value_bool() -> None:
248248

249249
with raises(DataError):
250250
parse_value("true", bool)
251+
252+
253+
def test_parse_value_bytes() -> None:
254+
"""parse_value parses all int values correctly."""
255+
assert (
256+
parse_value("\\x616263", bytes) == b"abc"
257+
), "Error parsing bytes: provided str"
258+
assert parse_value(None, bytes) is None, "Error parsing bytes: provided None"
259+
260+
with raises(ValueError):
261+
parse_value("\\xabc", bytes)
262+
263+
# Missing prefix
264+
with raises(ValueError):
265+
parse_value("616263", bytes)
266+
267+
for val in (1, True, Exception()):
268+
with raises(DataError):
269+
parse_value(val, bytes)

tests/unit/db_conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def query_description() -> List[Column]:
3232
Column("bool", "boolean", None, None, None, None, None),
3333
Column("array", "array(int)", None, None, None, None, None),
3434
Column("decimal", "Decimal(12, 34)", None, None, None, None, None),
35+
Column("bytea", "bytea", None, None, None, None, None),
3536
]
3637

3738

@@ -54,6 +55,7 @@ def python_query_description() -> List[Column]:
5455
Column("bool", bool, None, None, None, None, None),
5556
Column("array", ARRAY(int), None, None, None, None, None),
5657
Column("decimal", DECIMAL(12, 34), None, None, None, None, None),
58+
Column("bytea", bytes, None, None, None, None, None),
5759
]
5860

5961

@@ -77,6 +79,7 @@ def query_data() -> List[List[ColType]]:
7779
1,
7880
[1, 2, 3, 4],
7981
"123456789.123456789123456789123456789",
82+
"\\x616263",
8083
]
8184
for i in range(QUERY_ROW_COUNT)
8285
]
@@ -102,6 +105,7 @@ def python_query_data() -> List[List[ColType]]:
102105
1,
103106
[1, 2, 3, 4],
104107
Decimal("123456789.123456789123456789123456789"),
108+
b"abc",
105109
]
106110
for i in range(QUERY_ROW_COUNT)
107111
]

0 commit comments

Comments
 (0)