Skip to content

Commit 10b0c50

Browse files
authored
feat: Adding set parameter support (#46)
1 parent 38ba50f commit 10b0c50

File tree

5 files changed

+69
-2
lines changed

5 files changed

+69
-2
lines changed

src/firebolt_db/firebolt_async_dialect.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from firebolt.async_db import Connection
99

1010
# Ignoring type since sqlalchemy-stubs doesn't cover AdaptedConnection
11+
# and util.concurrency
1112
from sqlalchemy.engine import AdaptedConnection # type: ignore[attr-defined]
12-
from sqlalchemy.util.concurrency import await_only
13+
from sqlalchemy.util.concurrency import await_only # type: ignore[import]
1314

1415
from firebolt_db.firebolt_dialect import FireboltDialect
1516

src/firebolt_db/firebolt_dialect.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import firebolt.db as dbapi
66
import sqlalchemy.types as sqltypes
7+
from firebolt.db import Cursor
78
from sqlalchemy.engine import Connection as AlchemyConnection
89
from sqlalchemy.engine import ExecutionContext, default
910
from sqlalchemy.engine.url import URL
@@ -87,6 +88,7 @@ class FireboltDialect(default.DefaultDialect):
8788
returns_unicode_strings = True
8889
description_encoding = None
8990
supports_native_boolean = True
91+
_set_parameters: Optional[Dict[str, Any]] = None
9092

9193
def __init__(
9294
self, context: Optional[ExecutionContext] = None, *args: Any, **kwargs: Any
@@ -107,6 +109,10 @@ def create_connect_args(self, url: URL) -> Tuple[List, Dict]:
107109
"password": url.password or None,
108110
"engine_name": url.database,
109111
}
112+
parameters = dict(url.query)
113+
if "account_name" in parameters:
114+
kwargs["account_name"] = parameters.pop("account_name")
115+
self._set_parameters = parameters
110116
# If URL override is not provided leave it to the sdk to determine the endpoint
111117
if "FIREBOLT_BASE_URL" in os.environ:
112118
kwargs["api_endpoint"] = os.environ["FIREBOLT_BASE_URL"]
@@ -257,6 +263,15 @@ def get_view_definition(
257263
) -> str:
258264
pass
259265

266+
def do_execute(
267+
self,
268+
cursor: Cursor,
269+
statement: str,
270+
parameters: Tuple[str, Any],
271+
context: Optional[ExecutionContext] = None,
272+
) -> None:
273+
cursor.execute(statement, set_parameters=self._set_parameters)
274+
260275
def do_rollback(self, dbapi_connection: AlchemyConnection) -> None:
261276
pass
262277

tests/integration/test_sqlalchemy_integration.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from sqlalchemy import create_engine
23
from sqlalchemy.engine.base import Connection, Engine
34
from sqlalchemy.exc import OperationalError
45

@@ -17,6 +18,18 @@ def test_create_ex_table(
1718
connection.execute(f"DROP TABLE {ex_table_name}")
1819
assert not engine.dialect.has_table(engine, ex_table_name)
1920

21+
def test_set_params(
22+
self, username: str, password: str, database_name: str, engine_name: str
23+
):
24+
engine = create_engine(
25+
f"firebolt://{username}:{password}@{database_name}/{engine_name}?"
26+
"advanced_mode=1&use_standard_sql=0"
27+
)
28+
with engine.connect() as connection:
29+
result = connection.execute("SELECT sleepEachRow(1) from numbers(1)")
30+
assert len(result.fetchall()) == 1
31+
engine.dispose()
32+
2033
def test_data_write(self, connection: Connection, fact_table_name: str):
2134
connection.execute(
2235
f"INSERT INTO {fact_table_name}(idx, dummy) VALUES (1, 'some_text')"

tests/unit/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,20 @@ def connect():
3737
pass
3838

3939

40+
class MockCursor:
41+
def execute():
42+
pass
43+
44+
def executemany():
45+
pass
46+
47+
def fetchall():
48+
pass
49+
50+
def close():
51+
pass
52+
53+
4054
class MockAsyncDBApi:
4155
class DatabaseError:
4256
pass
@@ -100,6 +114,11 @@ def connection() -> mock.Mock(spec=MockDBApi):
100114
return mock.Mock(spec=MockDBApi)
101115

102116

117+
@fixture
118+
def cursor() -> mock.Mock(spec=MockCursor):
119+
return mock.Mock(spec=MockCursor)
120+
121+
103122
@fixture
104123
def async_api() -> AsyncMock(spec=MockAsyncDBApi):
105124
return AsyncMock(spec=MockAsyncDBApi)

tests/unit/test_firebolt_dialect.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from unittest import mock
33

44
import sqlalchemy
5-
from conftest import MockDBApi
5+
from conftest import MockCursor, MockDBApi
66
from sqlalchemy.engine import url
77
from sqlalchemy.sql import text
88

@@ -48,6 +48,25 @@ def test_create_connect_args(self, dialect: FireboltDialect):
4848
assert "api_endpoint" not in result_dict
4949
assert result_list == []
5050

51+
def test_create_connect_args_set_params(self, dialect: FireboltDialect):
52+
connection_url = (
53+
"test_engine://test_user@email:test_password@test_db_name/test_engine_name"
54+
"?account_name=FB&param1=1&param2=2"
55+
)
56+
u = url.make_url(connection_url)
57+
result_list, result_dict = dialect.create_connect_args(u)
58+
assert (
59+
"account_name" in result_dict
60+
), "account_name was not parsed correctly from connection string"
61+
assert dialect._set_parameters == {"param1": "1", "param2": "2"}
62+
63+
def test_do_execute(
64+
self, dialect: FireboltDialect, cursor: mock.Mock(spec=MockCursor)
65+
):
66+
dialect._set_parameters = {"a": "b"}
67+
dialect.do_execute(cursor, "SELECT *", None, None)
68+
cursor.execute.assert_called_once_with("SELECT *", set_parameters={"a": "b"})
69+
5170
def test_schema_names(
5271
self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi)
5372
):

0 commit comments

Comments
 (0)