Skip to content

Commit 9a3e76f

Browse files
feat:Fir 11238 python sdk support new data types (#148)
* add decimal support * add integration tests for decimal * add support for date32 * add support for datetime64 * revert typo * extend date tests * extend unit tests * support parsing milliseconds without ciso8601 * fix datetime test Co-authored-by: Stepan Burlakov <[email protected]>
1 parent 3aceafb commit 9a3e76f

File tree

9 files changed

+192
-37
lines changed

9 files changed

+192
-37
lines changed

src/firebolt/async_db/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
ARRAY,
33
BINARY,
44
DATETIME,
5+
DATETIME64,
6+
DECIMAL,
57
NUMBER,
68
ROWID,
79
STRING,

src/firebolt/async_db/_types.py

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,24 @@
1313
try:
1414
from ciso8601 import parse_datetime # type: ignore
1515
except ImportError:
16-
parse_datetime = datetime.fromisoformat # type: ignore
16+
# Unfortunately, there seems to be no support for optional bits in strptime
17+
def parse_datetime(date_string: str) -> datetime:  # type: ignore
    """Parse a date or datetime string when ciso8601 is not installed.

    ``datetime.fromisoformat`` is tried first. Before Python 3.11 it only
    accepts fractional seconds written with exactly 3 or 6 digits, so a
    value such as ``2019-07-31 01:01:01.1234`` is retried with
    ``strptime``, whose ``%f`` directive accepts one to six digits.

    Args:
        date_string: Date or datetime text, e.g. ``2021-03-28`` or
            ``2019-07-31 01:01:01.1234``.

    Returns:
        The parsed ``datetime``.

    Raises:
        ValueError: If neither parser recognizes ``date_string``.
    """
    try:
        # Covers date-only values, values without a fraction, and any
        # separator or fraction length that the running Python accepts.
        return datetime.fromisoformat(date_string)
    except ValueError:
        # Fallback for fraction lengths that older fromisoformat rejects.
        return datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S.%f")
1723

1824

1925
from firebolt.common.exception import DataError, NotSupportedError
2026
from firebolt.common.util import cached_property
2127

2228
_NoneType = type(None)
23-
_col_types = (int, float, str, datetime, date, bool, list, _NoneType)
29+
_col_types = (int, float, str, datetime, date, bool, list, Decimal, _NoneType)
2430
# duplicating this since 3.7 can't unpack Union
25-
ColType = Union[int, float, str, datetime, date, bool, list, _NoneType]
31+
ColType = Union[int, float, str, datetime, date, bool, list, Decimal, _NoneType]
2632
RawColType = Union[int, float, str, bool, list, _NoneType]
27-
ParameterType = Union[int, float, str, datetime, date, bool, Sequence]
33+
ParameterType = Union[int, float, str, datetime, date, bool, Decimal, Sequence]
2834

2935
# These definitions are required by PEP-249
3036
Date = date
@@ -78,9 +84,9 @@ class ARRAY:
7884

7985
_prefix = "Array("
8086

81-
def __init__(self, subtype: Union[type, ARRAY]):
87+
def __init__(self, subtype: Union[type, ARRAY, DECIMAL, DATETIME64]):
8288
assert (subtype in _col_types and subtype is not list) or isinstance(
83-
subtype, ARRAY
89+
subtype, (ARRAY, DECIMAL, DATETIME64)
8490
), f"Invalid array subtype: {str(subtype)}"
8591
self.subtype = subtype
8692

@@ -93,6 +99,41 @@ def __eq__(self, other: object) -> bool:
9399
return other.subtype == self.subtype
94100

95101

102+
class DECIMAL:
    """Class for holding information about decimal value in firebolt db.

    Instances compare equal when both ``precision`` and ``scale`` match,
    and hash consistently with that equality so they can be used in sets
    and as dict keys. Avoid mutating an instance while it is stored in a
    hash-based container.

    Args:
        precision: First number of the ``Decimal(precision, scale)`` type name.
        scale: Second number of the ``Decimal(precision, scale)`` type name.
    """

    # Leading part of the type name, used to recognize decimal type strings
    _prefix = "Decimal("

    def __init__(self, precision: int, scale: int):
        self.precision = precision
        self.scale = scale

    def __str__(self) -> str:
        return f"Decimal({self.precision}, {self.scale})"

    def __hash__(self) -> int:
        # Defining __eq__ alone would set __hash__ to None; keep both in sync.
        return hash((self.precision, self.scale))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DECIMAL):
            # Let Python try the reflected comparison / fall back to identity
            return NotImplemented
        return other.precision == self.precision and other.scale == self.scale
118+
119+
120+
class DATETIME64:
    """Class for holding information about datetime64 value in firebolt db.

    Instances compare equal when their ``precision`` matches, and hash
    consistently with that equality so they can be used in sets and as
    dict keys. Avoid mutating an instance while it is stored in a
    hash-based container.

    Args:
        precision: Number carried in the ``DateTime64(precision)`` type name.
    """

    # Leading part of the type name, used to recognize datetime64 type strings
    _prefix = "DateTime64("

    def __init__(self, precision: int):
        self.precision = precision

    def __str__(self) -> str:
        return f"DateTime64({self.precision})"

    def __hash__(self) -> int:
        # Defining __eq__ alone would set __hash__ to None; keep both in sync.
        return hash(self.precision)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DATETIME64):
            # Let Python try the reflected comparison / fall back to identity
            return NotImplemented
        return other.precision == self.precision
135+
136+
96137
NULLABLE_PREFIX = "Nullable("
97138

98139

@@ -122,6 +163,7 @@ class _InternalType(Enum):
122163

123164
# DATE
124165
Date = "Date"
166+
Date32 = "Date32"
125167

126168
# DATETIME, TIMESTAMP
127169
DateTime = "DateTime"
@@ -145,20 +187,38 @@ def python_type(self) -> type:
145187
_InternalType.Float64: float,
146188
_InternalType.String: str,
147189
_InternalType.Date: date,
190+
_InternalType.Date32: date,
148191
_InternalType.DateTime: datetime,
149192
# For simplicity, this could happen only during 'select null' query
150193
_InternalType.Nothing: str,
151194
}
152195
return types[self]
153196

154197

155-
def parse_type(raw_type: str) -> Union[type, ARRAY]:
198+
def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL, DATETIME64]:
156199
"""Parse typename, provided by query metadata into python type."""
157200
if not isinstance(raw_type, str):
158201
raise DataError(f"Invalid typename {str(raw_type)}: str expected")
159202
# Handle arrays
160203
if raw_type.startswith(ARRAY._prefix) and raw_type.endswith(")"):
161204
return ARRAY(parse_type(raw_type[len(ARRAY._prefix) : -1]))
205+
# Handle decimal
206+
if raw_type.startswith(DECIMAL._prefix) and raw_type.endswith(")"):
207+
try:
208+
prec_scale = raw_type[len(DECIMAL._prefix) : -1].split(",")
209+
precision, scale = int(prec_scale[0]), int(prec_scale[1])
210+
except (ValueError, IndexError):
211+
pass
212+
else:
213+
return DECIMAL(precision, scale)
214+
# Handle datetime64
215+
if raw_type.startswith(DATETIME64._prefix) and raw_type.endswith(")"):
216+
try:
217+
precision = int(raw_type[len(DATETIME64._prefix) : -1])
218+
except (ValueError, IndexError):
219+
pass
220+
else:
221+
return DATETIME64(precision)
162222
# Handle nullable
163223
if raw_type.startswith(NULLABLE_PREFIX) and raw_type.endswith(")"):
164224
return parse_type(raw_type[len(NULLABLE_PREFIX) : -1])
@@ -173,7 +233,7 @@ def parse_type(raw_type: str) -> Union[type, ARRAY]:
173233

174234
def parse_value(
175235
value: RawColType,
176-
ctype: Union[type, ARRAY],
236+
ctype: Union[type, ARRAY, DECIMAL, DATETIME64],
177237
) -> ColType:
178238
"""Provided raw value and python type, parses first into python value."""
179239
if value is None:
@@ -186,10 +246,13 @@ def parse_value(
186246
raise DataError(f"Invalid date value {value}: str expected")
187247
assert isinstance(value, str)
188248
return parse_datetime(value).date()
189-
if ctype is datetime:
249+
if ctype is datetime or isinstance(ctype, DATETIME64):
190250
if not isinstance(value, str):
191251
raise DataError(f"Invalid datetime value {value}: str expected")
192252
return parse_datetime(value)
253+
if isinstance(ctype, DECIMAL):
254+
assert isinstance(value, (str, int))
255+
return Decimal(value)
193256
if isinstance(ctype, ARRAY):
194257
assert isinstance(value, list)
195258
return [parse_value(it, ctype.subtype) for it in value]

src/firebolt/async_db/cursor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ def _append_query_data(self, response: Response) -> None:
182182
# Empty response is returned for insert query
183183
if response.headers.get("content-length", "") != "0":
184184
try:
185-
query_data = response.json()
185+
# Skip parsing floats to properly parse them later
186+
query_data = response.json(parse_float=str)
186187
rowcount = int(query_data["rows"])
187188
descriptions = [
188189
Column(

src/firebolt/db/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
ARRAY,
33
BINARY,
44
DATETIME,
5+
DATETIME64,
6+
DECIMAL,
57
NUMBER,
68
ROWID,
79
STRING,

tests/integration/dbapi/async/test_queries_async.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datetime import date, datetime
2+
from decimal import Decimal
23
from typing import Any, List
34

45
from pytest import mark, raises
@@ -39,8 +40,11 @@ async def test_select(
3940
all_types_query_response: List[ColType],
4041
) -> None:
4142
"""Select handles all data types properly"""
43+
set_params = {"firebolt_use_decimal": 1}
4244
with connection.cursor() as c:
43-
assert await c.execute(all_types_query) == 1, "Invalid row count returned"
45+
assert (
46+
await c.execute(all_types_query, set_parameters=set_params) == 1
47+
), "Invalid row count returned"
4448
assert c.rowcount == 1, "Invalid rowcount value"
4549
data = await c.fetchall()
4650
assert len(data) == c.rowcount, "Invalid data length"
@@ -50,13 +54,13 @@ async def test_select(
5054
assert len(await c.fetchall()) == 0, "Redundant data returned by fetchall"
5155

5256
# Different fetch types
53-
await c.execute(all_types_query)
57+
await c.execute(all_types_query, set_parameters=set_params)
5458
assert (
5559
await c.fetchone() == all_types_query_response[0]
5660
), "Invalid fetchone data"
5761
assert await c.fetchone() is None, "Redundant data returned by fetchone"
5862

59-
await c.execute(all_types_query)
63+
await c.execute(all_types_query, set_parameters=set_params)
6064
assert len(await c.fetchmany(0)) == 0, "Invalid data size returned by fetchmany"
6165
data = await c.fetchmany()
6266
assert len(data) == 1, "Invalid data size returned by fetchmany"
@@ -206,8 +210,12 @@ async def test_empty_query(c: Cursor, query: str) -> None:
206210
async def test_parameterized_query(connection: Connection) -> None:
207211
"""Query parameters are handled properly"""
208212

213+
set_params = {"firebolt_use_decimal": 1}
214+
209215
async def test_empty_query(c: Cursor, query: str, params: tuple) -> None:
210-
assert await c.execute(query, params) == -1, "Invalid row count returned"
216+
assert (
217+
await c.execute(query, params, set_params) == -1
218+
), "Invalid row count returned"
211219
assert c.rowcount == -1, "Invalid rowcount value"
212220
assert c.description is None, "Invalid description"
213221
with raises(DataError):
@@ -223,8 +231,9 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None:
223231
await c.execute("DROP TABLE IF EXISTS test_tb_async_parameterized")
224232
await c.execute(
225233
"CREATE FACT TABLE test_tb_async_parameterized(i int, f float, s string, sn"
226-
" string null, d date, dt datetime, b bool, a array(int), ss string)"
227-
" primary index i"
234+
" string null, d date, dt datetime, b bool, a array(int), "
235+
"dec decimal(38, 3), ss string) primary index i",
236+
set_parameters=set_params,
228237
)
229238

230239
params = [
@@ -236,12 +245,13 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None:
236245
datetime(2022, 1, 1, 1, 1, 1),
237246
True,
238247
[1, 2, 3],
248+
Decimal("123.456"),
239249
]
240250

241251
await test_empty_query(
242252
c,
243253
"INSERT INTO test_tb_async_parameterized VALUES "
244-
"(?, ?, ?, ?, ?, ?, ?, ?, '\\?')",
254+
"(?, ?, ?, ?, ?, ?, ?, ?, ?, '\\?')",
245255
params,
246256
)
247257

@@ -252,7 +262,10 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None:
252262
params[6] = 1
253263

254264
assert (
255-
await c.execute("SELECT * FROM test_tb_async_parameterized") == 1
265+
await c.execute(
266+
"SELECT * FROM test_tb_async_parameterized", set_parameters=set_params
267+
)
268+
== 1
256269
), "Invalid data length in table after parameterized insert"
257270

258271
assert_deep_eq(

tests/integration/dbapi/conftest.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,39 @@
11
from datetime import date, datetime
2+
from decimal import Decimal
23
from logging import getLogger
34
from typing import List
45

56
from pytest import fixture
67

78
from firebolt.async_db._types import ColType
89
from firebolt.async_db.cursor import Column
9-
from firebolt.db import ARRAY
10+
from firebolt.db import ARRAY, DATETIME64, DECIMAL
1011

1112
LOGGER = getLogger(__name__)
1213

1314

1415
@fixture
1516
def all_types_query() -> str:
1617
return (
17-
"select 1 as uint8, -1 as int8, 257 as uint16, -257 as int16, 80000 as uint32,"
18-
" -80000 as int32, 30000000000 as uint64, -30000000000 as int64, cast(1.23 AS"
19-
" FLOAT) as float32, 1.2345678901234 as float64, 'text' as \"string\","
20-
" CAST('2021-03-28' AS DATE) as \"date\", CAST('2019-07-31 01:01:01' AS"
21-
' DATETIME) as "datetime", true as "bool",[1,2,3,4] as "array", cast(null as'
22-
" int) as nullable"
18+
"select 1 as uint8, "
19+
"-1 as int8, "
20+
"257 as uint16, "
21+
"-257 as int16, "
22+
"80000 as uint32, "
23+
"-80000 as int32, "
24+
"30000000000 as uint64, "
25+
"-30000000000 as int64, "
26+
"cast(1.23 AS FLOAT) as float32, "
27+
"1.2345678901234 as float64, "
28+
"'text' as \"string\", "
29+
"CAST('2021-03-28' AS DATE) as \"date\", "
30+
"CAST('1860-03-04' AS DATE_EXT) as \"date32\","
31+
"CAST('2019-07-31 01:01:01' AS DATETIME) as \"datetime\", "
32+
"CAST('2019-07-31 01:01:01.1234' AS TIMESTAMP_EXT(4)) as \"datetime64\", "
33+
'true as "bool",'
34+
'[1,2,3,4] as "array", cast(1231232.123459999990457054844258706536 as '
35+
'decimal(38,30)) as "decimal", '
36+
"cast(null as int) as nullable"
2337
)
2438

2539

@@ -38,9 +52,12 @@ def all_types_query_description() -> List[Column]:
3852
Column("float64", float, None, None, None, None, None),
3953
Column("string", str, None, None, None, None, None),
4054
Column("date", date, None, None, None, None, None),
55+
Column("date32", date, None, None, None, None, None),
4156
Column("datetime", datetime, None, None, None, None, None),
57+
Column("datetime64", DATETIME64(4), None, None, None, None, None),
4258
Column("bool", int, None, None, None, None, None),
4359
Column("array", ARRAY(int), None, None, None, None, None),
60+
Column("decimal", DECIMAL(38, 30), None, None, None, None, None),
4461
Column("nullable", int, None, None, None, None, None),
4562
]
4663

@@ -61,9 +78,12 @@ def all_types_query_response() -> List[ColType]:
6178
1.23456789012,
6279
"text",
6380
date(2021, 3, 28),
81+
date(1860, 3, 4),
6482
datetime(2019, 7, 31, 1, 1, 1),
83+
datetime(2019, 7, 31, 1, 1, 1, 123400),
6584
1,
6685
[1, 2, 3, 4],
86+
Decimal("1231232.123459999990457054844258706536"),
6787
None,
6888
]
6989
]

0 commit comments

Comments
 (0)