From fd66a2a8041a1b31898bbfd10ecc8de70833253f Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 17 Nov 2025 10:58:14 +0000 Subject: [PATCH 1/2] fix(FIR-50880): Async close failure --- src/firebolt_db/firebolt_async_dialect.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/firebolt_db/firebolt_async_dialect.py b/src/firebolt_db/firebolt_async_dialect.py index 23946e3..e1fa5cd 100644 --- a/src/firebolt_db/firebolt_async_dialect.py +++ b/src/firebolt_db/firebolt_async_dialect.py @@ -109,6 +109,15 @@ def _set_parameters(self) -> Dict[str, Any]: def _set_parameters(self, value: Dict[str, Any]) -> None: self._cursor._set_parameters = value + async def _async_soft_close(self) -> None: + """close the cursor but keep the results pending, and memoize the + description. + + We don't have ability to memorize results with async driver yet so + keeping this a no-op. + + """ + class AsyncConnectionWrapper(AdaptedConnection): await_ = staticmethod(await_only) From 44d9d329d21388a9c35e3027593a53ab65189bd7 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 17 Nov 2025 14:50:40 +0000 Subject: [PATCH 2/2] refactor(FIR-50880): Migrate to SQLAlchemy 2.0 async --- pytest.ini | 2 +- setup.cfg | 2 +- src/firebolt_db/firebolt_async_dialect.py | 181 ++++++++---------- tests/integration/test_core_integration.py | 14 +- .../test_sqlalchemy_async_integration.py | 1 - tests/unit/conftest.py | 59 +++++- tests/unit/test_firebolt_async_dialect.py | 52 +++-- 7 files changed, 185 insertions(+), 126 deletions(-) diff --git a/pytest.ini b/pytest.ini index c2c378f..b5c2498 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,4 @@ [pytest] -trio_mode = true +asyncio_mode = auto markers = core: mark test to run only on Firebolt Core. \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 4f51f65..12c5bc5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,7 +48,7 @@ dev = mock==4.0.3 mypy==0.910 pre-commit==3.5.0 - pytest==7.2.0 + pytest==9.0.1 pytest-cov==3.0.0 pytest-trio==0.8.0 sqlalchemy-stubs==0.4 diff --git a/src/firebolt_db/firebolt_async_dialect.py b/src/firebolt_db/firebolt_async_dialect.py index e1fa5cd..67d5219 100644 --- a/src/firebolt_db/firebolt_async_dialect.py +++ b/src/firebolt_db/firebolt_async_dialect.py @@ -1,105 +1,49 @@ from __future__ import annotations -from asyncio import Lock +import asyncio from functools import partial from types import ModuleType -from typing import Any, Dict, Iterator, List, Optional, Tuple +from typing import Any, Dict import firebolt.async_db as async_dbapi -from firebolt.async_db import Connection -from sqlalchemy.engine import AdaptedConnection # type: ignore[attr-defined] +from sqlalchemy.connectors.asyncio import ( + AsyncAdapt_dbapi_connection, + AsyncAdapt_dbapi_cursor, +) # Ignoring type since sqlalchemy-stubs doesn't cover AdaptedConnection # and util.concurrency from sqlalchemy.pool import AsyncAdaptedQueuePool # type: ignore[attr-defined] -from sqlalchemy.util.concurrency import await_only # type: ignore[import] +from sqlalchemy.util.concurrency import await_fallback from trio import run from firebolt_db.firebolt_dialect import FireboltDialect -class AsyncCursorWrapper: - __slots__ = ( - "_adapt_connection", - "_connection", - "await_", - "_cursor", - "_rows", - ) +def _is_trio_context() -> bool: + """Check if we're currently in a Trio async context.""" + try: + import trio - server_side = False - - def __init__(self, adapt_connection: AsyncConnectionWrapper): - self._adapt_connection = adapt_connection - self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ - self._rows: List[List] = [] - self._cursor = self._connection.cursor() - - def close(self) -> None: - self._rows[:] = [] - self._cursor.close() - - @property - def description(self) -> str: - return self._cursor.description - - @property - def arraysize(self) -> int: - return self._cursor.arraysize + trio.lowlevel.current_task() + return True + except (ImportError, RuntimeError): + return False - @arraysize.setter - def arraysize(self, value: int) -> None: - self._cursor.arraysize = value - @property - def rowcount(self) -> int: - return self._cursor.rowcount +def _is_asyncio_context() -> bool: + """Check if we're currently in an asyncio context.""" + try: + loop = asyncio.get_running_loop() + return loop.is_running() + except RuntimeError: + return False - def execute( - self, - operation: str, - parameters: Optional[Tuple] = None, - ) -> None: - self.await_(self._execute(operation, parameters)) - - async def _execute( - self, - operation: str, - parameters: Optional[Tuple] = None, - ) -> None: - async with self._adapt_connection._execute_mutex: - await self._cursor.execute(operation, parameters) - if self._cursor.description: - self._rows = await self._cursor.fetchall() - else: - self._rows = [] - - def executemany(self, operation: str, seq_of_parameters: List[Tuple]) -> None: - raise NotImplementedError("executemany is not supported yet") - - def __iter__(self) -> Iterator[List]: - while self._rows: - yield self._rows.pop(0) - - def fetchone(self) -> Optional[List]: - if self._rows: - return self._rows.pop(0) - else: - return None - - def fetchmany(self, size: int = None) -> List[List]: - if size is None: - size = self._cursor.arraysize - retval = self._rows[0:size] - self._rows[:] = self._rows[size:] - return retval +class AsyncCursorWrapper(AsyncAdapt_dbapi_cursor): + __slots__ = () - def fetchall(self) -> List[List]: - retval = self._rows[:] - self._rows[:] = [] - return retval + server_side = False @property def _set_parameters(self) -> Dict[str, Any]: @@ -109,28 +53,47 @@ def _set_parameters(self) -> Dict[str, Any]: def _set_parameters(self, value: Dict[str, Any]) -> None: self._cursor._set_parameters = value + @property + def rowcount(self) -> int: + """Return the rowcount, using memoized value if cursor is closed.""" + # Use hasattr to check if attribute exists and has the value + memoized = getattr(self, "_soft_closed_memoized", {}) + if "rowcount" in memoized: + return memoized["rowcount"] # type: ignore[return-value] + return self._cursor.rowcount + async def _async_soft_close(self) -> None: """close the cursor but keep the results pending, and memoize the - description. + description and rowcount. - We don't have ability to memorize results with async driver yet so - keeping this a no-op. + Copied from SQLAlchemy's AsyncAdapted_dbapi_connection to use aclose() + instead of close(). """ + # Check if cursor can be closed asynchronously + awaitable_close = getattr(self, "_awaitable_cursor_close", False) + if not awaitable_close or self.server_side: + return + + # Get or initialize memoized data + memoized = getattr(self, "_soft_closed_memoized", set()) + if not isinstance(memoized, dict): + memoized = {} + + memoized.update( + { + "description": self._cursor.description, + "rowcount": self._cursor.rowcount, # Memoize rowcount before closing + } + ) + setattr(self, "_soft_closed_memoized", memoized) + await self._cursor.aclose() -class AsyncConnectionWrapper(AdaptedConnection): - await_ = staticmethod(await_only) +class AsyncConnectionWrapper(AsyncAdapt_dbapi_connection): + _cursor_cls = AsyncCursorWrapper __slots__ = ("dbapi", "_connection", "_execute_mutex") - def __init__(self, dbapi: AsyncAPIWrapper, connection: Connection): - self.dbapi = dbapi - self._connection = connection - self._execute_mutex = Lock() - - def cursor(self) -> AsyncCursorWrapper: - return AsyncCursorWrapper(self) - def rollback(self) -> None: pass @@ -138,10 +101,14 @@ def commit(self) -> None: self._connection.commit() def close(self) -> None: - self.await_(self._connection.aclose()) + if _is_trio_context() or _is_asyncio_context(): + await_fallback(self._connection.aclose()) + else: + # Fall back to sync close + self._connection.close() -class AsyncAPIWrapper(ModuleType): +class AsyncAPIWrapper: """Wrapper around Firebolt async dbapi that returns a similar wrapper for Cursor on connect()""" @@ -149,6 +116,7 @@ def __init__(self, dbapi: ModuleType): self.dbapi = dbapi self.paramstyle = dbapi.paramstyle # type: ignore[attr-defined] # noqa: F821 self._init_dbapi_attributes() + self.Cursor = AsyncCursorWrapper def _init_dbapi_attributes(self) -> None: for name in ( @@ -162,14 +130,21 @@ def _init_dbapi_attributes(self) -> None: setattr(self, name, getattr(self.dbapi, name)) def connect(self, *arg: Any, **kw: Any) -> AsyncConnectionWrapper: - # Synchronously establish a connection that can execute - # asynchronous queries later + """Create a connection, handling both sync and async contexts.""" + + # Helper function to create async connection + async def _create_async_connection() -> Any: + return await self.dbapi.connect(*arg, **kw) # type: ignore[attr-defined] # noqa: F821,E501 + + # Check if we're in an async context + if _is_trio_context() or _is_asyncio_context(): + connection = await_fallback(_create_async_connection()) + return AsyncConnectionWrapper(self, connection) + + # No async context detected, use trio.run for synchronous connection creation conn_func = partial(self.dbapi.connect, *arg, **kw) # type: ignore[attr-defined] # noqa: F821,E501 connection = run(conn_func) - return AsyncConnectionWrapper( - self, - connection, - ) + return AsyncConnectionWrapper(self, connection) class AsyncFireboltDialect(FireboltDialect): @@ -180,7 +155,7 @@ class AsyncFireboltDialect(FireboltDialect): poolclass = AsyncAdaptedQueuePool @classmethod - def dbapi(cls) -> AsyncAPIWrapper: + def dbapi(cls) -> AsyncAPIWrapper: # type: ignore[override] return AsyncAPIWrapper(async_dbapi) diff --git a/tests/integration/test_core_integration.py b/tests/integration/test_core_integration.py index 64018c5..b512108 100644 --- a/tests/integration/test_core_integration.py +++ b/tests/integration/test_core_integration.py @@ -1,6 +1,6 @@ import pytest from firebolt.client.auth import FireboltCore -from sqlalchemy import create_engine, text +from sqlalchemy import Column, Integer, MetaData, Table, create_engine, text from sqlalchemy.engine.base import Connection, Engine from sqlalchemy.exc import InterfaceError @@ -18,6 +18,18 @@ def test_core_engine_auth(self, core_engine: Engine): auth = connect_args[1]["auth"] assert isinstance(auth, FireboltCore) + def test_core_table_with_column_int_special_case(self, core_connection: Connection): + """Test handling of table with column named 'int' in Core.""" + table = Table("test_int", MetaData(), Column("int", Integer)) + table.create(core_connection) + stmt = table.insert().values(int=42) + core_connection.execute(stmt) + result = core_connection.execute(table.select()) + rows = result.fetchall() + assert len(rows) == 1 + assert rows[0][0] == 42 + table.drop(core_connection) + def test_core_simple_query(self, core_connection: Connection): """Test executing a simple query against Core.""" result = core_connection.execute(text("SELECT 'Hello Core' as message")) diff --git a/tests/integration/test_sqlalchemy_async_integration.py b/tests/integration/test_sqlalchemy_async_integration.py index 8573146..b93c964 100644 --- a/tests/integration/test_sqlalchemy_async_integration.py +++ b/tests/integration/test_sqlalchemy_async_integration.py @@ -5,7 +5,6 @@ from sqlalchemy.engine.base import Connection, Engine -@pytest.mark.skip("FIR-47589") @pytest.mark.usefixtures("setup_test_tables") class TestAsyncFireboltDialect: async def test_create_ex_table( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 497abb4..c82df22 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -77,7 +77,20 @@ async def connect(): class MockAsyncConnection: - def cursor(): + def cursor(self): + # Mock implementation for cursor creation + pass + + def commit(self): + # Mock implementation for commit + pass + + def rollback(self): + # Mock implementation for rollback + pass + + async def aclose(self): + # Mock implementation for async close pass @@ -86,16 +99,31 @@ class MockAsyncCursor: rowcount = -1 arraysize = 1 - async def execute(): + async def execute(self): + # Mock implementation for async execute pass - async def executemany(): + async def executemany(self, **kwargs): + # Mock implementation for async executemany pass - async def fetchall(): + async def fetchall(self): + # Mock implementation for async fetchall pass - def close(): + def close(self): + # Mock implementation for close + pass + + async def aclose(self): + # Mock implementation for async close + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Mock implementation for async context manager exit pass @@ -131,4 +159,23 @@ def async_connection() -> AsyncMock(spec=MockAsyncConnection): @fixture def async_cursor() -> AsyncMock(spec=MockAsyncCursor): - return AsyncMock(spec=MockAsyncCursor) + mock = AsyncMock(spec=MockAsyncCursor) + # Make sure the async context manager methods return the mock itself + + async def aenter(): + # Return the mock cursor for async context manager entry + return mock + + async def aexit(*args): + # Mock implementation for async context manager exit + pass + + # Make sure close() returns a coroutine that can be awaited + async def close_coro(): + # Mock implementation for async close + pass + + mock.__aenter__ = AsyncMock(side_effect=aenter) + mock.__aexit__ = AsyncMock(side_effect=aexit) + mock.close.return_value = close_coro() + return mock diff --git a/tests/unit/test_firebolt_async_dialect.py b/tests/unit/test_firebolt_async_dialect.py index 3753024..079150d 100644 --- a/tests/unit/test_firebolt_async_dialect.py +++ b/tests/unit/test_firebolt_async_dialect.py @@ -1,7 +1,9 @@ +from collections import deque + import pytest from conftest import MockAsyncConnection, MockAsyncCursor, MockAsyncDBApi from mock import AsyncMock -from sqlalchemy.util import await_only, greenlet_spawn +from sqlalchemy.util import greenlet_spawn from firebolt_db.firebolt_async_dialect import ( AsyncAPIWrapper, @@ -45,9 +47,26 @@ def test_connect() -> AsyncAPIWrapper: assert wrapper.paramstyle == "quoted" async_api.connect.assert_called_once_with("test arg") - async def test_connection_wrapper(self, async_api: AsyncMock(spec=MockAsyncDBApi)): + async def test_connection_wrapper( + self, + async_api, + async_connection, + async_cursor, + ): def test_connection() -> AsyncConnectionWrapper: - wrapper = AsyncConnectionWrapper(async_api, await_only(async_api.connect())) + # Set up the mock to return the async_connection when connect() is called + async_api.connect.return_value = async_connection + # Set up the connection to return the async_cursor when cursor() is called + async_connection.cursor.return_value = async_cursor + + wrapper = AsyncConnectionWrapper(async_api, async_connection) + + # Monkey patch the _aenter_cursor method to avoid the __aenter__ issue + def mock_aenter_cursor(self, cursor): + return cursor # Return cursor directly without calling __aenter__ + + wrapper._cursor_cls._aenter_cursor = mock_aenter_cursor + # Check call propagation wrapper.commit() wrapper.rollback() @@ -55,9 +74,12 @@ def test_connection() -> AsyncConnectionWrapper: return wrapper wrapper = await greenlet_spawn(test_connection) - assert isinstance(wrapper.cursor(), AsyncCursorWrapper) - async_api.connect.return_value.commit.assert_called_once() - async_api.connect.return_value.aclose.assert_awaited_once() + + cursor_wrapper = wrapper.cursor() + assert isinstance(cursor_wrapper, AsyncCursorWrapper) + + async_connection.commit.assert_called_once() + async_connection.aclose.assert_awaited_once() async def test_cursor_execute( self, @@ -116,9 +138,9 @@ def test_cursor(): async_connection.cursor.return_value = async_cursor conn_wrapper = AsyncConnectionWrapper(async_api, async_connection) wrapper = AsyncCursorWrapper(conn_wrapper) - wrapper._rows = [1, 2, 3] + wrapper._rows = deque([1, 2, 3]) wrapper.close() - assert wrapper._rows == [] + assert len(wrapper._rows) == 0 async_cursor.close.assert_called_once() await greenlet_spawn(test_cursor) @@ -133,12 +155,14 @@ def test_cursor(): async_connection.cursor.return_value = async_cursor conn_wrapper = AsyncConnectionWrapper(async_api, async_connection) wrapper = AsyncCursorWrapper(conn_wrapper) - with pytest.raises(NotImplementedError): - wrapper.executemany( - "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a"), (2, "b")] - ) + wrapper.executemany( + "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a"), (2, "b")] + ) await greenlet_spawn(test_cursor) + async_cursor.executemany.assert_awaited_once_with( + "INSERT INTO test(a, b) VALUES (?, ?)", [(1, "a"), (2, "b")] + ) async def test_cursor_fetch( self, @@ -148,9 +172,11 @@ async def test_cursor_fetch( ): def test_cursor(): async_connection.cursor.return_value = async_cursor + # Set arraysize to 1 initially + async_cursor.arraysize = 1 conn_wrapper = AsyncConnectionWrapper(async_api, async_connection) wrapper = AsyncCursorWrapper(conn_wrapper) - wrapper._rows = [1, 2, 3, 4, 5, 6, 7, 8] + wrapper._rows = deque([1, 2, 3, 4, 5, 6, 7, 8]) assert wrapper.fetchone() == 1 assert wrapper.fetchmany() == [2] async_cursor.arraysize = 2