Skip to content

Commit 25cdf24

Browse files
authored
test: Compatibility with alchemy 2.0.1 (#70)
1 parent 00f47a7 commit 25cdf24

File tree

4 files changed

+57
-37
lines changed

4 files changed

+57
-37
lines changed

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ sqlalchemy.dialects =
4242
[options.extras_require]
4343
dev =
4444
devtools==0.7.0
45+
greenlet==2.0.2
4546
mock==4.0.3
4647
mypy==0.910
4748
pre-commit==2.15.0

src/firebolt_db/firebolt_dialect.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def has_table(
148148
connection: AlchemyConnection,
149149
table_name: str,
150150
schema: Optional[str] = None,
151+
**kw: Any
151152
) -> bool:
152153
query = """
153154
select count(*) > 0 as exists_

tests/integration/conftest.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from os import environ
44

55
from pytest import fixture
6-
from sqlalchemy import create_engine
6+
from sqlalchemy import create_engine, text
77
from sqlalchemy.engine.base import Connection, Engine
88
from sqlalchemy.ext.asyncio import create_async_engine
99

@@ -157,28 +157,32 @@ def setup_test_tables(
157157
dimension_table_name: str,
158158
):
159159
connection.execute(
160-
f"""
160+
text(
161+
f"""
161162
CREATE FACT TABLE IF NOT EXISTS {fact_table_name}
162163
(
163164
idx INT,
164165
dummy TEXT
165166
) PRIMARY INDEX idx;
166167
"""
168+
)
167169
)
168170
connection.execute(
169-
f"""
171+
text(
172+
f"""
170173
CREATE DIMENSION TABLE IF NOT EXISTS {dimension_table_name}
171174
(
172175
idx INT,
173176
dummy TEXT
174177
);
175178
"""
179+
)
176180
)
177-
assert engine.dialect.has_table(engine, fact_table_name)
178-
assert engine.dialect.has_table(engine, dimension_table_name)
181+
assert engine.dialect.has_table(connection, fact_table_name)
182+
assert engine.dialect.has_table(connection, dimension_table_name)
179183
yield
180184
# Teardown
181-
connection.execute(f"DROP TABLE IF EXISTS {fact_table_name} CASCADE;")
182-
connection.execute(f"DROP TABLE IF EXISTS {dimension_table_name} CASCADE;")
183-
assert not engine.dialect.has_table(engine, fact_table_name)
184-
assert not engine.dialect.has_table(engine, dimension_table_name)
185+
connection.execute(text(f"DROP TABLE IF EXISTS {fact_table_name} CASCADE;"))
186+
connection.execute(text(f"DROP TABLE IF EXISTS {dimension_table_name} CASCADE;"))
187+
assert not engine.dialect.has_table(connection, fact_table_name)
188+
assert not engine.dialect.has_table(connection, dimension_table_name)

tests/integration/test_sqlalchemy_integration.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from decimal import Decimal
33

44
import pytest
5-
from sqlalchemy import create_engine
5+
from sqlalchemy import create_engine, text
66
from sqlalchemy.engine.base import Connection, Engine
77
from sqlalchemy.exc import OperationalError
88

@@ -15,11 +15,11 @@ def test_create_ex_table(
1515
ex_table_query: str,
1616
ex_table_name: str,
1717
):
18-
connection.execute(ex_table_query)
19-
assert engine.dialect.has_table(engine, ex_table_name)
18+
connection.execute(text(ex_table_query))
19+
assert engine.dialect.has_table(connection, ex_table_name)
2020
# Cleanup
21-
connection.execute(f"DROP TABLE {ex_table_name}")
22-
assert not engine.dialect.has_table(engine, ex_table_name)
21+
connection.execute(text(f"DROP TABLE {ex_table_name}"))
22+
assert not engine.dialect.has_table(connection, ex_table_name)
2323

2424
def test_set_params(
2525
self, username: str, password: str, database_name: str, engine_name: str
@@ -28,81 +28,95 @@ def test_set_params(
2828
f"firebolt://{username}:{password}@{database_name}/{engine_name}"
2929
)
3030
with engine.connect() as connection:
31-
connection.execute("SET advanced_mode=1")
32-
connection.execute("SET use_standard_sql=0")
33-
result = connection.execute("SELECT sleepEachRow(1) from numbers(1)")
31+
connection.execute(text("SET advanced_mode=1"))
32+
connection.execute(text("SET use_standard_sql=0"))
33+
result = connection.execute(text("SELECT sleepEachRow(1) from numbers(1)"))
3434
assert len(result.fetchall()) == 1
3535
engine.dispose()
3636

3737
def test_data_write(self, connection: Connection, fact_table_name: str):
3838
connection.execute(
39-
f"INSERT INTO {fact_table_name}(idx, dummy) VALUES (1, 'some_text')"
39+
text(f"INSERT INTO {fact_table_name}(idx, dummy) VALUES (1, 'some_text')")
40+
)
41+
result = connection.execute(
42+
text(f"SELECT * FROM {fact_table_name} WHERE idx=1")
4043
)
41-
result = connection.execute(f"SELECT * FROM {fact_table_name} WHERE idx=?", 1)
4244
assert result.fetchall() == [(1, "some_text")]
43-
result = connection.execute(f"SELECT * FROM {fact_table_name}")
45+
result = connection.execute(text(f"SELECT * FROM {fact_table_name}"))
4446
assert len(result.fetchall()) == 1
4547
# Update not supported
4648
with pytest.raises(OperationalError):
4749
connection.execute(
48-
f"UPDATE {fact_table_name} SET dummy='some_other_text' WHERE idx=1"
50+
text(
51+
f"UPDATE {fact_table_name} SET dummy='some_other_text' WHERE idx=1"
52+
)
4953
)
5054
# Delete works but is not officially supported yet
5155
# with pytest.raises(OperationalError):
5256
# connection.execute(f"DELETE FROM {fact_table_name} WHERE idx=1")
5357

5458
def test_firebolt_types(self, connection: Connection):
55-
result = connection.execute("SELECT '1896-01-01' :: DATE_EXT")
59+
result = connection.execute(text("SELECT '1896-01-01' :: DATE_EXT"))
5660
assert result.fetchall() == [(date(1896, 1, 1),)]
57-
result = connection.execute("SELECT '1896-01-01 00:01:00' :: TIMESTAMP_EXT")
61+
result = connection.execute(
62+
text("SELECT '1896-01-01 00:01:00' :: TIMESTAMP_EXT")
63+
)
5864
assert result.fetchall() == [(datetime(1896, 1, 1, 0, 1, 0, 0),)]
59-
result = connection.execute("SELECT 100.76 :: DECIMAL(5, 2)")
65+
result = connection.execute(text("SELECT 100.76 :: DECIMAL(5, 2)"))
6066
assert result.fetchall() == [(Decimal("100.76"),)]
6167

6268
def test_agg_index(self, connection: Connection, fact_table_name: str):
6369
# Test if sql parsing allows it
6470
agg_index = "idx_agg_max"
6571
connection.execute(
66-
f"""
72+
text(
73+
f"""
6774
CREATE AGGREGATING INDEX {agg_index} ON {fact_table_name} (
6875
dummy,
6976
max(idx)
7077
);
7178
"""
79+
)
7280
)
73-
connection.execute(f"DROP AGGREGATING INDEX {agg_index}")
81+
connection.execute(text(f"DROP AGGREGATING INDEX {agg_index}"))
7482

7583
def test_join_index(self, connection: Connection, dimension_table_name: str):
7684
# Test if sql parsing allows it
7785
join_index = "idx_join"
7886
connection.execute(
79-
f"""
87+
text(
88+
f"""
8089
CREATE JOIN INDEX {join_index} ON {dimension_table_name} (
8190
idx,
8291
dummy
8392
);
8493
"""
94+
)
8595
)
86-
connection.execute(f"DROP JOIN INDEX {join_index}")
96+
connection.execute(text(f"DROP JOIN INDEX {join_index}"))
8797

8898
def test_get_schema_names(self, engine: Engine, database_name: str):
8999
results = engine.dialect.get_schema_names(engine)
90100
assert "public" in results
91101

92-
def test_has_table(self, engine: Engine, fact_table_name: str):
93-
results = engine.dialect.has_table(engine, fact_table_name)
102+
def test_has_table(
103+
self, engine: Engine, connection: Connection, fact_table_name: str
104+
):
105+
results = engine.dialect.has_table(connection, fact_table_name)
94106
assert results == 1
95107

96-
def test_get_table_names(self, engine: Engine):
97-
results = engine.dialect.get_table_names(engine)
108+
def test_get_table_names(self, engine: Engine, connection: Connection):
109+
results = engine.dialect.get_table_names(connection)
98110
assert len(results) > 0
99-
results = engine.dialect.get_table_names(engine, "public")
111+
results = engine.dialect.get_table_names(connection, "public")
100112
assert len(results) > 0
101-
results = engine.dialect.get_table_names(engine, "non_existing_schema")
113+
results = engine.dialect.get_table_names(connection, "non_existing_schema")
102114
assert len(results) == 0
103115

104-
def test_get_columns(self, engine: Engine, fact_table_name: str):
105-
results = engine.dialect.get_columns(engine, fact_table_name)
116+
def test_get_columns(
117+
self, engine: Engine, connection: Connection, fact_table_name: str
118+
):
119+
results = engine.dialect.get_columns(connection, fact_table_name)
106120
assert len(results) > 0
107121
row = results[0]
108122
assert isinstance(row, dict)
@@ -113,5 +127,5 @@ def test_get_columns(self, engine: Engine, fact_table_name: str):
113127
assert row_keys[3] == "default"
114128

115129
def test_service_account_connect(self, connection_service_account: Connection):
116-
result = connection_service_account.execute("SELECT 1")
130+
result = connection_service_account.execute(text("SELECT 1"))
117131
assert result.fetchall() == [(1,)]

0 commit comments

Comments
 (0)