Skip to content

Commit 9aaa8c9

Browse files
fix: Multistatement fix idx (#118)
* fixes for multi-statement queries * extend unit tests * add integration tests
1 parent 94892c5 commit 9aaa8c9

File tree

6 files changed

+87
-10
lines changed

6 files changed

+87
-10
lines changed

src/firebolt/async_db/cursor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def _pop_next_set(self) -> Optional[bool]:
222222
self._rowcount, self._descriptions, self._rows = self._row_sets[
223223
self._next_set_idx
224224
]
225+
self._idx = 0
225226
self._next_set_idx += 1
226227
return True
227228

@@ -464,6 +465,11 @@ async def fetchall(self) -> List[List[ColType]]:
464465
return super().fetchall()
465466
"""Fetch all remaining rows of a query result"""
466467

468+
@wraps(BaseCursor.nextset)
469+
async def nextset(self) -> None:
470+
async with self._async_query_lock.reader:
471+
return super().nextset()
472+
467473
# Iteration support
468474
@check_not_closed
469475
@check_query_executed

src/firebolt/db/cursor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ def fetchall(self) -> List[List[ColType]]:
7575
with self._query_lock.gen_rlock():
7676
return super().fetchall()
7777

78+
@wraps(AsyncBaseCursor.nextset)
79+
def nextset(self) -> None:
80+
with self._query_lock.gen_rlock(), self._idx_lock:
81+
return super().nextset()
82+
7883
# Iteration support
7984
@check_not_closed
8085
@check_query_executed

tests/integration/dbapi/async/test_queries_async.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,14 +260,15 @@ async def test_multi_statement_query(connection: Connection) -> None:
260260
assert (
261261
await c.execute(
262262
"INSERT INTO test_tb_multi_statement values (1, 'a'), (2, 'b');"
263-
"SELECT * FROM test_tb_multi_statement"
263+
"SELECT * FROM test_tb_multi_statement;"
264+
"SELECT * FROM test_tb_multi_statement WHERE i <= 1"
264265
)
265266
== -1
266267
), "Invalid row count returned for insert"
267268
assert c.rowcount == -1, "Invalid row count"
268269
assert c.description is None, "Invalid description"
269270

270-
assert c.nextset()
271+
assert await c.nextset()
271272

272273
assert c.rowcount == 2, "Invalid select row count"
273274
assert_deep_eq(
@@ -285,4 +286,22 @@ async def test_multi_statement_query(connection: Connection) -> None:
285286
"Invalid data in table after parameterized insert",
286287
)
287288

288-
assert c.nextset() is None
289+
assert await c.nextset()
290+
291+
assert c.rowcount == 1, "Invalid select row count"
292+
assert_deep_eq(
293+
c.description,
294+
[
295+
Column("i", int, None, None, None, None, None),
296+
Column("s", str, None, None, None, None, None),
297+
],
298+
"Invalid select query description",
299+
)
300+
301+
assert_deep_eq(
302+
await c.fetchall(),
303+
[[1, "a"]],
304+
"Invalid data in table after parameterized insert",
305+
)
306+
307+
assert await c.nextset() is None

tests/integration/dbapi/sync/test_queries.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,8 @@ def test_multi_statement_query(connection: Connection) -> None:
251251
assert (
252252
c.execute(
253253
"INSERT INTO test_tb_multi_statement values (1, 'a'), (2, 'b');"
254-
"SELECT * FROM test_tb_multi_statement"
254+
"SELECT * FROM test_tb_multi_statement;"
255+
"SELECT * FROM test_tb_multi_statement WHERE i <= 1"
255256
)
256257
== -1
257258
), "Invalid row count returned for insert"
@@ -276,4 +277,22 @@ def test_multi_statement_query(connection: Connection) -> None:
276277
"Invalid data in table after parameterized insert",
277278
)
278279

280+
assert c.nextset()
281+
282+
assert c.rowcount == 1, "Invalid select row count"
283+
assert_deep_eq(
284+
c.description,
285+
[
286+
Column("i", int, None, None, None, None, None),
287+
Column("s", str, None, None, None, None, None),
288+
],
289+
"Invalid select query description",
290+
)
291+
292+
assert_deep_eq(
293+
c.fetchall(),
294+
[[1, "a"]],
295+
"Invalid data in table after parameterized insert",
296+
)
297+
279298
assert c.nextset() is None

tests/unit/async_db/test_cursor.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ async def test_closed_cursor(cursor: Cursor):
6363
("fetchone", ()),
6464
("fetchmany", ()),
6565
("fetchall", ()),
66+
("nextset", ()),
6667
)
67-
methods = ("setinputsizes", "setoutputsize", "nextset")
68+
methods = ("setinputsizes", "setoutputsize")
6869

6970
cursor.close()
7071

@@ -439,8 +440,11 @@ async def test_cursor_multi_statement(
439440
httpx_mock.add_callback(auth_callback, url=auth_url)
440441
httpx_mock.add_callback(query_callback, url=query_url)
441442
httpx_mock.add_callback(insert_query_callback, url=query_url)
443+
httpx_mock.add_callback(query_callback, url=query_url)
442444

443-
rc = await cursor.execute("select * from t; insert into t values (1, 2)")
445+
rc = await cursor.execute(
446+
"select * from t; insert into t values (1, 2); select * from t"
447+
)
444448
assert rc == len(python_query_data), "Invalid row count returned"
445449
assert cursor.rowcount == len(python_query_data), "Invalid cursor row count"
446450
for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)):
@@ -451,12 +455,23 @@ async def test_cursor_multi_statement(
451455
await cursor.fetchone() == python_query_data[i]
452456
), f"Invalid data row at position {i}"
453457

454-
assert cursor.nextset()
458+
assert await cursor.nextset()
455459
assert cursor.rowcount == -1, "Invalid cursor row count"
456460
assert cursor.description is None, "Invalid cursor description"
457461
with raises(DataError) as exc_info:
458462
await cursor.fetchall()
459463

460464
assert str(exc_info.value) == "no rows to fetch", "Invalid error message"
461465

462-
assert cursor.nextset() is None
466+
assert await cursor.nextset()
467+
468+
assert cursor.rowcount == len(python_query_data), "Invalid cursor row count"
469+
for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)):
470+
assert desc == exp, f"Invalid column description at position {i}"
471+
472+
for i in range(cursor.rowcount):
473+
assert (
474+
await cursor.fetchone() == python_query_data[i]
475+
), f"Invalid data row at position {i}"
476+
477+
assert await cursor.nextset() is None

tests/unit/db/test_cursor.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_closed_cursor(cursor: Cursor):
6060
("fetchall", ()),
6161
("setinputsizes", (cursor, [0])),
6262
("setoutputsize", (cursor, 0)),
63-
("nextset", (cursor, [])),
63+
("nextset", ()),
6464
)
6565

6666
cursor.close()
@@ -71,6 +71,7 @@ def test_closed_cursor(cursor: Cursor):
7171

7272
for method, args in methods:
7373
with raises(CursorClosedError):
74+
print(method, args)
7475
getattr(cursor, method)(*args)
7576

7677
with raises(CursorClosedError):
@@ -386,8 +387,9 @@ def test_cursor_multi_statement(
386387
httpx_mock.add_callback(auth_callback, url=auth_url)
387388
httpx_mock.add_callback(query_callback, url=query_url)
388389
httpx_mock.add_callback(insert_query_callback, url=query_url)
390+
httpx_mock.add_callback(query_callback, url=query_url)
389391

390-
rc = cursor.execute("select * from t; insert into t values (1, 2)")
392+
rc = cursor.execute("select * from t; insert into t values (1, 2); select * from t")
391393
assert rc == len(python_query_data), "Invalid row count returned"
392394
assert cursor.rowcount == len(python_query_data), "Invalid cursor row count"
393395
for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)):
@@ -406,4 +408,15 @@ def test_cursor_multi_statement(
406408

407409
assert str(exc_info.value) == "no rows to fetch", "Invalid error message"
408410

411+
assert cursor.nextset()
412+
413+
assert cursor.rowcount == len(python_query_data), "Invalid cursor row count"
414+
for i, (desc, exp) in enumerate(zip(cursor.description, python_query_description)):
415+
assert desc == exp, f"Invalid column description at position {i}"
416+
417+
for i in range(cursor.rowcount):
418+
assert (
419+
cursor.fetchone() == python_query_data[i]
420+
), f"Invalid data row at position {i}"
421+
409422
assert cursor.nextset() is None

0 commit comments

Comments
 (0)