Skip to content

Commit 39670ab

Browse files
authored
feat: New firebolt types (#74)
1 parent b53d779 commit 39670ab

File tree

3 files changed

+92
-13
lines changed

3 files changed

+92
-13
lines changed

src/firebolt_db/firebolt_dialect.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sqlalchemy.engine.url import URL
1313
from sqlalchemy.sql import compiler, text
1414
from sqlalchemy.types import (
15+
ARRAY,
1516
BIGINT,
1617
BOOLEAN,
1718
CHAR,
@@ -24,8 +25,8 @@
2425
)
2526

2627

27-
class ARRAY(sqltypes.TypeEngine):
28-
__visit_name__ = "ARRAY"
28+
class BYTEA(sqltypes.LargeBinary):
29+
__visit_name__ = "BYTEA"
2930

3031

3132
# Firebolt data types compatibility with sqlalchemy.sql.types
@@ -37,18 +38,47 @@ class ARRAY(sqltypes.TypeEngine):
3738
"float": FLOAT,
3839
"double": FLOAT,
3940
"double precision": FLOAT,
41+
"real": FLOAT,
4042
"boolean": BOOLEAN,
4143
"int": INTEGER,
4244
"integer": INTEGER,
4345
"bigint": BIGINT,
4446
"long": BIGINT,
4547
"timestamp": TIMESTAMP,
48+
"timestamptz": TIMESTAMP,
49+
"timestampntz": TIMESTAMP,
4650
"datetime": DATETIME,
4751
"date": DATE,
48-
"array": ARRAY,
52+
"bytea": BYTEA,
4953
}
5054

5155

56+
def resolve_type(fb_type: str) -> sqltypes.TypeEngine:
57+
def removesuffix(s: str, suffix: str) -> str:
58+
"""Python < 3.9 compatibility"""
59+
if s.endswith(suffix):
60+
s = s[: -len(suffix)]
61+
return s
62+
63+
result: sqltypes.TypeEngine
64+
if fb_type.startswith("array"):
65+
# Nested arrays not supported
66+
dimensions = 0
67+
while fb_type.startswith("array"):
68+
dimensions += 1
69+
fb_type = fb_type[6:-1] # Strip ARRAY()
70+
fb_type = removesuffix(removesuffix(fb_type, " not null"), " null")
71+
result = ARRAY(resolve_type(fb_type), dimensions=dimensions)
72+
else:
73+
# Strip complex type info e.g. DECIMAL(8,23) -> DECIMAL
74+
fb_type = fb_type[: fb_type.find("(")] if "(" in fb_type else fb_type
75+
result = type_map.get(fb_type, DEFAULT_TYPE) # type: ignore
76+
return result
77+
78+
79+
DEFAULT_TYPE = VARCHAR
80+
81+
5282
class UniversalSet(set):
5383
def __contains__(self, item: Any) -> bool:
5484
return True
@@ -193,6 +223,7 @@ def get_columns(
193223
schema: Optional[str] = None,
194224
**kwargs: Any
195225
) -> List[Dict]:
226+
196227
query = """
197228
select column_name,
198229
data_type,
@@ -212,7 +243,7 @@ def get_columns(
212243
return [
213244
{
214245
"name": row[0],
215-
"type": type_map[row[1].lower()],
246+
"type": resolve_type(row[1].lower()),
216247
"nullable": get_is_nullable(row[2]),
217248
"default": None,
218249
}

tests/integration/conftest.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
from logging import getLogger
33
from os import environ
4+
from typing import List
45

56
from pytest import fixture
67
from sqlalchemy import create_engine, text
@@ -139,6 +140,44 @@ def ex_table_query(ex_table_name: str) -> str:
139140
"""
140141

141142

143+
@fixture(scope="class")
144+
def type_table_name() -> str:
145+
return "types_alchemy"
146+
147+
148+
@fixture(scope="class")
149+
def firebolt_columns() -> List[str]:
150+
return [
151+
"INTEGER",
152+
"NUMERIC",
153+
"BIGINT",
154+
"REAL",
155+
"DOUBLE PRECISION",
156+
"TEXT",
157+
"TIMESTAMPNTZ",
158+
"TIMESTAMPTZ",
159+
"DATE",
160+
"TIMESTAMP",
161+
"BOOLEAN",
162+
"BYTEA",
163+
]
164+
165+
166+
@fixture(scope="class")
167+
def type_table_query(firebolt_columns: List[str], type_table_name: str) -> str:
168+
col_names = [c.replace(" ", "_").lower() for c in firebolt_columns]
169+
cols = ",\n".join(
170+
[f"c_{name} {c_type}" for name, c_type in zip(col_names, firebolt_columns)]
171+
)
172+
return f"""
173+
CREATE DIMENSION TABLE {type_table_name}
174+
(
175+
{cols},
176+
c_array ARRAY(ARRAY(INTEGER))
177+
);
178+
"""
179+
180+
142181
@fixture(scope="class")
143182
def fact_table_name() -> str:
144183
return "test_alchemy"
@@ -155,6 +194,8 @@ def setup_test_tables(
155194
engine: Engine,
156195
fact_table_name: str,
157196
dimension_table_name: str,
197+
type_table_query: str,
198+
type_table_name: str,
158199
):
159200
connection.execute(
160201
text(
@@ -178,11 +219,15 @@ def setup_test_tables(
178219
"""
179220
)
180221
)
222+
connection.execute(text(type_table_query))
181223
assert engine.dialect.has_table(connection, fact_table_name)
182224
assert engine.dialect.has_table(connection, dimension_table_name)
225+
assert engine.dialect.has_table(connection, type_table_name)
183226
yield
184227
# Teardown
185228
connection.execute(text(f"DROP TABLE IF EXISTS {fact_table_name} CASCADE;"))
186229
connection.execute(text(f"DROP TABLE IF EXISTS {dimension_table_name} CASCADE;"))
230+
connection.execute(text(f"DROP TABLE IF EXISTS {type_table_name} CASCADE;"))
187231
assert not engine.dialect.has_table(connection, fact_table_name)
188232
assert not engine.dialect.has_table(connection, dimension_table_name)
233+
assert not engine.dialect.has_table(connection, type_table_name)

tests/integration/test_sqlalchemy_integration.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sqlalchemy import create_engine, text
66
from sqlalchemy.engine.base import Connection, Engine
77
from sqlalchemy.exc import OperationalError
8+
from sqlalchemy.types import ARRAY, INTEGER, TypeEngine
89

910

1011
class TestFireboltDialect:
@@ -114,17 +115,19 @@ def test_get_table_names(self, engine: Engine, connection: Connection):
114115
assert len(results) == 0
115116

116117
def test_get_columns(
117-
self, engine: Engine, connection: Connection, fact_table_name: str
118+
self, engine: Engine, connection: Connection, type_table_name: str
118119
):
119-
results = engine.dialect.get_columns(connection, fact_table_name)
120+
results = engine.dialect.get_columns(connection, type_table_name)
120121
assert len(results) > 0
121-
row = results[0]
122-
assert isinstance(row, dict)
123-
row_keys = list(row.keys())
124-
assert row_keys[0] == "name"
125-
assert row_keys[1] == "type"
126-
assert row_keys[2] == "nullable"
127-
assert row_keys[3] == "default"
122+
for column in results:
123+
assert isinstance(column, dict)
124+
# Check only works for basic types
125+
if type(column["type"]) == ARRAY:
126+
# ARRAY[[INT]]
127+
assert column["type"].dimensions == 2
128+
assert type(column["type"].item_type) == INTEGER
129+
else:
130+
assert issubclass(column["type"], TypeEngine)
128131

129132
def test_service_account_connect(self, connection_service_account: Connection):
130133
result = connection_service_account.execute(text("SELECT 1"))

0 commit comments

Comments
 (0)