Skip to content

Commit 47cba14

Browse files
feat: FIR-38125 ecosystem support for struct type for python sdk (#406)
1 parent 28435b9 commit 47cba14

File tree

8 files changed

+292
-25
lines changed

8 files changed

+292
-25
lines changed

src/firebolt/async_db/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
NUMBER,
99
ROWID,
1010
STRING,
11+
STRUCT,
1112
Binary,
1213
Date,
1314
DateFromTicks,

src/firebolt/common/_types.py

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from datetime import date, datetime, timezone
66
from decimal import Decimal
77
from enum import Enum
8-
from typing import List, Optional, Sequence, Union
8+
from io import StringIO
9+
from typing import Any, Dict, List, Optional, Sequence, Union
910

1011
from sqlparse import parse as parse_sql # type: ignore
1112
from sqlparse.sql import ( # type: ignore
@@ -62,8 +63,6 @@ def parse_datetime(datetime_string: str) -> datetime:
6263
# These definitions are required by PEP-249
6364
Date = date
6465

65-
_AccountInfo = namedtuple("_AccountInfo", ["id", "version"])
66-
6766

6867
def DateFromTicks(t: int) -> date: # NOSONAR
6968
"""Convert `ticks` to `date` for Firebolt DB."""
@@ -109,16 +108,28 @@ def Binary(value: str) -> bytes: # NOSONAR
109108
)
110109

111110

112-
class ARRAY:
111+
class ExtendedType:
    """Base type for all extended types in Firebolt (array, decimal, struct, etc.).

    Instances hash by their string representation, so subclasses are expected
    to provide a ``__str__`` that reflects everything ``__eq__`` compares.
    """

    __name__ = "ExtendedType"

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Keep subclasses hashable even when they override ``__eq__``.

        Python implicitly sets ``__hash__ = None`` on every class that defines
        ``__eq__`` without also defining ``__hash__``. Without this hook the
        ``__hash__`` declared below would be silently dropped by such
        subclasses, making their instances unusable in sets and as dict keys.
        """
        super().__init_subclass__(**kwargs)
        # The implicit ``None`` lives in the subclass's own namespace; an
        # explicitly defined ``__hash__`` method is left untouched.
        if "__hash__" in cls.__dict__ and cls.__dict__["__hash__"] is None:
            cls.__hash__ = ExtendedType.__hash__  # type: ignore

    @staticmethod
    def is_valid_type(type_: Any) -> bool:
        """Tell whether `type_` may be used as a column (sub)type.

        A type is valid when it is listed in ``_col_types`` or is an instance
        of an ``ExtendedType`` subclass.
        """
        return type_ in _col_types or isinstance(type_, ExtendedType)

    def __hash__(self) -> int:
        """Hash the instance by its string representation."""
        return hash(str(self))
122+
123+
124+
class ARRAY(ExtendedType):
113125
"""Class for holding `array` column type information in Firebolt DB."""
114126

115127
__name__ = "Array"
116128
_prefix = "array("
117129

118-
def __init__(self, subtype: Union[type, ARRAY, DECIMAL]):
119-
assert (subtype in _col_types and subtype is not list) or isinstance(
120-
subtype, (ARRAY, DECIMAL)
121-
), f"Invalid array subtype: {str(subtype)}"
130+
def __init__(self, subtype: Union[type, ExtendedType]):
131+
if not self.is_valid_type(subtype):
132+
raise ValueError(f"Invalid array subtype: {str(subtype)}")
122133
self.subtype = subtype
123134

124135
def __str__(self) -> str:
@@ -130,7 +141,7 @@ def __eq__(self, other: object) -> bool:
130141
return other.subtype == self.subtype
131142

132143

133-
class DECIMAL:
144+
class DECIMAL(ExtendedType):
134145
"""Class for holding `decimal` value information in Firebolt DB."""
135146

136147
__name__ = "Decimal"
@@ -143,15 +154,29 @@ def __init__(self, precision: int, scale: int):
143154
def __str__(self) -> str:
144155
return f"Decimal({self.precision}, {self.scale})"
145156

146-
def __hash__(self) -> int:
147-
return hash(str(self))
148-
149157
def __eq__(self, other: object) -> bool:
150158
if not isinstance(other, DECIMAL):
151159
return NotImplemented
152160
return other.precision == self.precision and other.scale == self.scale
153161

154162

163+
class STRUCT(ExtendedType):
    """Class for holding `struct` column type information in Firebolt DB."""

    __name__ = "Struct"
    _prefix = "struct("

    def __init__(self, fields: Dict[str, Union[type, ExtendedType]]):
        """Store the mapping of field name to field type.

        Raises:
            ValueError: if any field type is rejected by ``is_valid_type``.
        """
        for type_ in fields.values():
            if not self.is_valid_type(type_):
                raise ValueError(f"Invalid struct field type: {str(type_)}")
        self.fields = fields

    def __str__(self) -> str:
        return f"Struct({', '.join(f'{k}: {v}' for k, v in self.fields.items())})"

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, STRUCT) and other.fields == self.fields

    # Defining __eq__ makes Python set __hash__ to None for this class, so the
    # inherited implementation has to be re-bound explicitly to stay hashable.
    # NOTE(review): __eq__ compares the field dicts (insertion order ignored)
    # while the hash derives from str(self) (order-sensitive) - confirm whether
    # structs that differ only in field order should compare equal.
    __hash__ = ExtendedType.__hash__
178+
179+
155180
NULLABLE_SUFFIX = "null"
156181

157182

@@ -206,7 +231,31 @@ def python_type(self) -> type:
206231
return types[self]
207232

208233

209-
def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL]: # noqa: C901
234+
def split_struct_fields(raw_struct: str) -> List[str]:
    """Split raw struct inner fields string into a list of field definitions.

    Only commas outside of any parentheses act as separators, so nested
    type definitions (structs, arrays, decimals) stay in one piece. Every
    returned definition has its surrounding whitespace removed. An empty
    input yields a single empty definition.

    >>> split_struct_fields("field1 int, field2 struct(field1 int, field2 text)")
    ['field1 int', 'field2 struct(field1 int, field2 text)']
    """
    separator = ","
    depth = 0  # current parenthesis nesting level; split only at level 0
    fields: List[str] = []
    start = 0  # index where the field currently being scanned begins
    for pos, ch in enumerate(raw_struct):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        elif ch == separator and depth == 0:
            fields.append(raw_struct[start:pos].strip())
            start = pos + 1

    # Whatever follows the last top-level separator is the final field
    fields.append(raw_struct[start:].strip())
    return fields
256+
257+
258+
def parse_type(raw_type: str) -> Union[type, ExtendedType]: # noqa: C901
210259
"""Parse typename provided by query metadata into Python type."""
211260
if not isinstance(raw_type, str):
212261
raise DataError(f"Invalid typename {str(raw_type)}: str expected")
@@ -218,10 +267,20 @@ def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL]: # noqa: C901
218267
try:
219268
prec_scale = raw_type[len(DECIMAL._prefix) : -1].split(",")
220269
precision, scale = int(prec_scale[0]), int(prec_scale[1])
270+
return DECIMAL(precision, scale)
221271
except (ValueError, IndexError):
222272
pass
223-
else:
224-
return DECIMAL(precision, scale)
273+
# Handle structs
274+
if raw_type.startswith(STRUCT._prefix) and raw_type.endswith(")"):
275+
try:
276+
fields_raw = split_struct_fields(raw_type[len(STRUCT._prefix) : -1])
277+
fields = {}
278+
for f in fields_raw:
279+
name, type_ = f.strip().split(" ", 1)
280+
fields[name.strip()] = parse_type(type_.strip())
281+
return STRUCT(fields)
282+
except ValueError:
283+
pass
225284
# Handle nullable
226285
if raw_type.endswith(NULLABLE_SUFFIX):
227286
return parse_type(raw_type[: -len(NULLABLE_SUFFIX)].strip(" "))
@@ -247,13 +306,13 @@ def _parse_bytea(str_value: str) -> bytes:
247306

248307
def parse_value(
249308
value: RawColType,
250-
ctype: Union[type, ARRAY, DECIMAL],
309+
ctype: Union[type, ExtendedType],
251310
) -> ColType:
252311
"""Provided raw value, and Python type; parses first into Python value."""
253312
if value is None:
254313
return None
255314
if ctype in (int, str, float):
256-
assert isinstance(ctype, type)
315+
assert isinstance(ctype, type) # assertion for mypy
257316
return ctype(value)
258317
if ctype is date:
259318
if not isinstance(value, str):
@@ -273,11 +332,20 @@ def parse_value(
273332
raise DataError(f"Invalid bytea value {value}: str expected")
274333
return _parse_bytea(value)
275334
if isinstance(ctype, DECIMAL):
276-
assert isinstance(value, (str, int))
335+
if not isinstance(value, (str, int)):
336+
raise DataError(f"Invalid decimal value {value}: str or int expected")
277337
return Decimal(value)
278338
if isinstance(ctype, ARRAY):
279-
assert isinstance(value, list)
339+
if not isinstance(value, list):
340+
raise DataError(f"Invalid array value {value}: list expected")
280341
return [parse_value(it, ctype.subtype) for it in value]
342+
if isinstance(ctype, STRUCT):
343+
if not isinstance(value, dict):
344+
raise DataError(f"Invalid struct value {value}: dict expected")
345+
return {
346+
name: parse_value(value.get(name), type_)
347+
for name, type_ in ctype.fields.items()
348+
}
281349
raise DataError(f"Unsupported data type returned: {ctype.__name__}")
282350

283351

src/firebolt/db/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
NUMBER,
77
ROWID,
88
STRING,
9+
STRUCT,
910
Binary,
1011
Date,
1112
DateFromTicks,

tests/integration/dbapi/async/V2/test_queries_async.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,3 +429,29 @@ async def test_select_geography(
429429
select_geography_response,
430430
"Invalid data returned by fetchall",
431431
)
432+
433+
434+
async def test_select_struct(
    connection: Connection,
    setup_struct_query: str,
    cleanup_struct_query: str,
    select_struct_query: str,
    select_struct_description: List[Column],
    select_struct_response: List[ColType],
):
    """Select a struct column and check the reported description and data.

    The setup script runs inside the ``try`` so that the cleanup script is
    executed even when setup or any of the assertions fail.
    """
    with connection.cursor() as cursor:
        try:
            await cursor.execute(setup_struct_query)
            await cursor.execute(select_struct_query)

            description = cursor.description
            assert description == select_struct_description, "Invalid description value"

            rows = await cursor.fetchall()
            assert len(rows) == 1, "Invalid data length"
            assert_deep_eq(
                rows,
                select_struct_response,
                "Invalid data returned by fetchall",
            )
        finally:
            await cursor.execute(cleanup_struct_query)

tests/integration/dbapi/conftest.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pytest import fixture
77

88
from firebolt.async_db.cursor import Column
9-
from firebolt.common._types import ColType
9+
from firebolt.common._types import STRUCT, ColType
1010
from firebolt.db import ARRAY, DECIMAL, Connection
1111

1212
LOGGER = getLogger(__name__)
@@ -209,3 +209,54 @@ def select_geography_description() -> List[Column]:
209209
@fixture
210210
def select_geography_response() -> List[ColType]:
211211
return [["0101000020E6100000FEFFFFFFFFFFEF3F000000000000F03F"]]
212+
213+
214+
@fixture
def setup_struct_query() -> str:
    """SQL script that sets the session flags and creates the struct tables.

    It recreates `test_struct` and `test_struct_helper` and inserts one row
    into each of them.
    """
    script = """
SET advanced_mode=1;
SET enable_struct=1;
SET enable_create_table_v2=true;
SET enable_row_selection=true;
SET prevent_create_on_information_schema=true;
SET enable_create_table_with_struct_type=true;
DROP TABLE IF EXISTS test_struct;
DROP TABLE IF EXISTS test_struct_helper;
CREATE TABLE IF NOT EXISTS test_struct(id int not null, s struct(a array(int) not null, b datetime null) not null);
CREATE TABLE IF NOT EXISTS test_struct_helper(a array(int) not null, b datetime null);
INSERT INTO test_struct_helper(a, b) VALUES ([1, 2], '2019-07-31 01:01:01');
INSERT INTO test_struct(id, s) SELECT 1, test_struct_helper FROM test_struct_helper;
"""
    return script
230+
231+
232+
@fixture
def cleanup_struct_query() -> str:
    """SQL script that drops both tables created by the struct setup script."""
    script = """
DROP TABLE IF EXISTS test_struct;
DROP TABLE IF EXISTS test_struct_helper;
"""
    return script
238+
239+
240+
@fixture
def select_struct_query() -> str:
    """Query that selects each `test_struct` row as a single struct value."""
    query = "SELECT test_struct FROM test_struct"
    return query
243+
244+
245+
@fixture
def select_struct_description() -> List[Column]:
    """Expected cursor description for the struct select query."""
    # Type of the nested `s` column: an int array plus a datetime
    inner_type = STRUCT({"a": ARRAY(int), "b": datetime})
    row_type = STRUCT({"id": int, "s": inner_type})
    return [Column("test_struct", row_type, None, None, None, None, None)]
258+
259+
260+
@fixture
def select_struct_response() -> List[ColType]:
    """Expected rows returned by the struct select query."""
    inserted_at = datetime(2019, 7, 31, 1, 1, 1)
    only_row = [{"id": 1, "s": {"a": [1, 2], "b": inserted_at}}]
    return [only_row]

tests/integration/dbapi/sync/V2/test_queries.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,3 +512,29 @@ def test_select_geography(
512512
select_geography_response,
513513
"Invalid data returned by fetchall",
514514
)
515+
516+
517+
def test_select_struct(
    connection: Connection,
    setup_struct_query: str,
    cleanup_struct_query: str,
    select_struct_query: str,
    select_struct_description: List[Column],
    select_struct_response: List[ColType],
):
    """Select a struct column and check the reported description and data.

    The setup script runs inside the ``try`` so that the cleanup script is
    executed even when setup or any of the assertions fail.
    """
    with connection.cursor() as cursor:
        try:
            cursor.execute(setup_struct_query)
            cursor.execute(select_struct_query)

            description = cursor.description
            assert description == select_struct_description, "Invalid description value"

            rows = cursor.fetchall()
            assert len(rows) == 1, "Invalid data length"
            assert_deep_eq(
                rows,
                select_struct_response,
                "Invalid data returned by fetchall",
            )
        finally:
            cursor.execute(cleanup_struct_query)

0 commit comments

Comments
 (0)