Skip to content

Commit 57621e1

Browse files
authored
feat: Allow set parameters as queries (#61)
1 parent 2bbc0c3 commit 57621e1

File tree

5 files changed

+32
-12
lines changed

5 files changed

+32
-12
lines changed

src/firebolt_db/firebolt_async_dialect.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,14 @@ def fetchall(self) -> List[List]:
102102
self._rows[:] = []
103103
return retval
104104

105+
@property
106+
def _set_parameters(self) -> Dict[str, Any]:
107+
return self._cursor._set_parameters
108+
109+
@_set_parameters.setter
110+
def _set_parameters(self, value: Dict[str, Any]) -> None:
111+
self._cursor._set_parameters = value
112+
105113

106114
class AsyncConnectionWrapper(AdaptedConnection):
107115
await_ = staticmethod(await_only)

src/firebolt_db/firebolt_dialect.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class FireboltDialect(default.DefaultDialect):
9090
returns_unicode_strings = True
9191
description_encoding = None
9292
supports_native_boolean = True
93-
_set_parameters: Optional[Dict[str, Any]] = None
93+
_set_parameters: Dict[str, Any] = dict()
9494

9595
def __init__(
9696
self, context: Optional[ExecutionContext] = None, *args: Any, **kwargs: Any
@@ -283,9 +283,10 @@ def do_execute(
283283
parameters: Tuple[str, Any],
284284
context: Optional[ExecutionContext] = None,
285285
) -> None:
286-
cursor.execute(
287-
statement, parameters=parameters, set_parameters=self._set_parameters
288-
)
286+
cursor._set_parameters = self._set_parameters
287+
cursor.execute(statement, parameters=parameters)
288+
# Persist set parameters across calls
289+
self._set_parameters = cursor._set_parameters
289290

290291
def do_rollback(self, dbapi_connection: AlchemyConnection) -> None:
291292
pass

tests/integration/test_sqlalchemy_async_integration.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,17 @@ async def test_data_write(self, async_connection: Connection, fact_table_name: s
5050
text(f"DELETE FROM {fact_table_name} WHERE idx=1")
5151
)
5252

53+
@pytest.mark.asyncio
54+
async def test_set_params(self, async_connection: Engine):
55+
await async_connection.execute(text(f"SET advanced_mode=1"))
56+
await async_connection.execute(text(f"SET use_standard_sql=0"))
57+
result = await async_connection.execute(
58+
text(f"SELECT sleepEachRow(1) from numbers(1)")
59+
)
60+
assert len(result.fetchall()) == 1
61+
await async_connection.execute(text(f"SET use_standard_sql=1"))
62+
await async_connection.execute(text(f"SET advanced_mode=0"))
63+
5364
@pytest.mark.asyncio
5465
async def test_get_table_names(self, async_connection: Connection):
5566
def get_table_names(conn: Connection) -> bool:

tests/integration/test_sqlalchemy_integration.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ def test_set_params(
2222
self, username: str, password: str, database_name: str, engine_name: str
2323
):
2424
engine = create_engine(
25-
f"firebolt://{username}:{password}@{database_name}/{engine_name}?"
26-
"advanced_mode=1&use_standard_sql=0"
25+
f"firebolt://{username}:{password}@{database_name}/{engine_name}"
2726
)
2827
with engine.connect() as connection:
28+
connection.execute("SET advanced_mode=1")
29+
connection.execute("SET use_standard_sql=0")
2930
result = connection.execute("SELECT sleepEachRow(1) from numbers(1)")
3031
assert len(result.fetchall()) == 1
3132
engine.dispose()

tests/unit/test_firebolt_dialect.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,13 @@ def test_do_execute(
104104
):
105105
dialect._set_parameters = {"a": "b"}
106106
dialect.do_execute(cursor, "SELECT *", None)
107-
cursor.execute.assert_called_once_with(
108-
"SELECT *", parameters=None, set_parameters={"a": "b"}
109-
)
107+
cursor.execute.assert_called_once_with("SELECT *", parameters=None)
108+
assert cursor._set_parameters == {"a": "b"}, "Set parameters were not set"
109+
cursor._set_parameters = {}
110110
cursor.execute.reset_mock()
111111
dialect.do_execute(cursor, "SELECT *", (1, 22), None)
112-
cursor.execute.assert_called_once_with(
113-
"SELECT *", parameters=(1, 22), set_parameters={"a": "b"}
114-
)
112+
cursor.execute.assert_called_once_with("SELECT *", parameters=(1, 22))
113+
assert cursor._set_parameters == {"a": "b"}, "Set parameters were not set"
115114

116115
def test_schema_names(
117116
self, dialect: FireboltDialect, connection: mock.Mock(spec=MockDBApi)

0 commit comments

Comments
 (0)