Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[pytest]
trio_mode = true
asyncio_mode = auto
markers =
core: mark test to run only on Firebolt Core.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ dev =
mock==4.0.3
mypy==0.910
pre-commit==3.5.0
pytest==7.2.0
pytest==9.0.1
pytest-cov==3.0.0
pytest-trio==0.8.0
sqlalchemy-stubs==0.4
Expand Down
182 changes: 83 additions & 99 deletions src/firebolt_db/firebolt_async_dialect.py
Original file line number Diff line number Diff line change
@@ -1,105 +1,49 @@
from __future__ import annotations

from asyncio import Lock
import asyncio
from functools import partial
from types import ModuleType
from typing import Any, Dict, Iterator, List, Optional, Tuple
from typing import Any, Dict

import firebolt.async_db as async_dbapi
from firebolt.async_db import Connection
from sqlalchemy.engine import AdaptedConnection # type: ignore[attr-defined]
from sqlalchemy.connectors.asyncio import (
AsyncAdapt_dbapi_connection,
AsyncAdapt_dbapi_cursor,
)

# Ignoring type since sqlalchemy-stubs doesn't cover AdaptedConnection
# and util.concurrency
from sqlalchemy.pool import AsyncAdaptedQueuePool # type: ignore[attr-defined]
from sqlalchemy.util.concurrency import await_only # type: ignore[import]
from sqlalchemy.util.concurrency import await_fallback
from trio import run

from firebolt_db.firebolt_dialect import FireboltDialect


class AsyncCursorWrapper:
__slots__ = (
"_adapt_connection",
"_connection",
"await_",
"_cursor",
"_rows",
)
def _is_trio_context() -> bool:
"""Check if we're currently in a Trio async context."""
try:
import trio

server_side = False

def __init__(self, adapt_connection: AsyncConnectionWrapper):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
self._rows: List[List] = []
self._cursor = self._connection.cursor()

def close(self) -> None:
self._rows[:] = []
self._cursor.close()
trio.lowlevel.current_task()
return True
except (ImportError, RuntimeError):
return False

@property
def description(self) -> str:
return self._cursor.description

@property
def arraysize(self) -> int:
return self._cursor.arraysize

@arraysize.setter
def arraysize(self, value: int) -> None:
self._cursor.arraysize = value

@property
def rowcount(self) -> int:
return self._cursor.rowcount

def execute(
self,
operation: str,
parameters: Optional[Tuple] = None,
) -> None:
self.await_(self._execute(operation, parameters))

async def _execute(
self,
operation: str,
parameters: Optional[Tuple] = None,
) -> None:
async with self._adapt_connection._execute_mutex:
await self._cursor.execute(operation, parameters)
if self._cursor.description:
self._rows = await self._cursor.fetchall()
else:
self._rows = []

def executemany(self, operation: str, seq_of_parameters: List[Tuple]) -> None:
raise NotImplementedError("executemany is not supported yet")

def __iter__(self) -> Iterator[List]:
while self._rows:
yield self._rows.pop(0)

def fetchone(self) -> Optional[List]:
if self._rows:
return self._rows.pop(0)
else:
return None
def _is_asyncio_context() -> bool:
"""Check if we're currently in an asyncio context."""
try:
loop = asyncio.get_running_loop()
return loop.is_running()
except RuntimeError:
return False

def fetchmany(self, size: int = None) -> List[List]:
if size is None:
size = self._cursor.arraysize

retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval
class AsyncCursorWrapper(AsyncAdapt_dbapi_cursor):
__slots__ = ()

def fetchall(self) -> List[List]:
retval = self._rows[:]
self._rows[:] = []
return retval
server_side = False

@property
def _set_parameters(self) -> Dict[str, Any]:
Expand All @@ -109,18 +53,46 @@ def _set_parameters(self) -> Dict[str, Any]:
def _set_parameters(self, value: Dict[str, Any]) -> None:
self._cursor._set_parameters = value

@property
def rowcount(self) -> int:
"""Return the rowcount, using memoized value if cursor is closed."""
# Use hasattr to check if attribute exists and has the value
memoized = getattr(self, "_soft_closed_memoized", {})
if "rowcount" in memoized:
return memoized["rowcount"] # type: ignore[return-value]
return self._cursor.rowcount

class AsyncConnectionWrapper(AdaptedConnection):
await_ = staticmethod(await_only)
__slots__ = ("dbapi", "_connection", "_execute_mutex")
async def _async_soft_close(self) -> None:
"""close the cursor but keep the results pending, and memoize the
description and rowcount.

Copied from SQLAlchemy's AsyncAdapted_dbapi_connection to use aclose()
instead of close().

"""
# Check if cursor can be closed asynchronously
awaitable_close = getattr(self, "_awaitable_cursor_close", False)
if not awaitable_close or self.server_side:
return

# Get or initialize memoized data
memoized = getattr(self, "_soft_closed_memoized", set())
if not isinstance(memoized, dict):
memoized = {}

memoized.update(
{
"description": self._cursor.description,
"rowcount": self._cursor.rowcount, # Memoize rowcount before closing
}
)
setattr(self, "_soft_closed_memoized", memoized)
await self._cursor.aclose()

def __init__(self, dbapi: AsyncAPIWrapper, connection: Connection):
self.dbapi = dbapi
self._connection = connection
self._execute_mutex = Lock()

def cursor(self) -> AsyncCursorWrapper:
return AsyncCursorWrapper(self)
class AsyncConnectionWrapper(AsyncAdapt_dbapi_connection):
_cursor_cls = AsyncCursorWrapper
__slots__ = ("dbapi", "_connection", "_execute_mutex")

def rollback(self) -> None:
pass
Expand All @@ -129,17 +101,22 @@ def commit(self) -> None:
self._connection.commit()

def close(self) -> None:
self.await_(self._connection.aclose())
if _is_trio_context() or _is_asyncio_context():
await_fallback(self._connection.aclose())
else:
# Fall back to sync close
self._connection.close()


class AsyncAPIWrapper(ModuleType):
class AsyncAPIWrapper:
"""Wrapper around Firebolt async dbapi that returns a similar wrapper for
Cursor on connect()"""

def __init__(self, dbapi: ModuleType):
self.dbapi = dbapi
self.paramstyle = dbapi.paramstyle # type: ignore[attr-defined] # noqa: F821
self._init_dbapi_attributes()
self.Cursor = AsyncCursorWrapper

def _init_dbapi_attributes(self) -> None:
for name in (
Expand All @@ -153,14 +130,21 @@ def _init_dbapi_attributes(self) -> None:
setattr(self, name, getattr(self.dbapi, name))

def connect(self, *arg: Any, **kw: Any) -> AsyncConnectionWrapper:
# Synchronously establish a connection that can execute
# asynchronous queries later
"""Create a connection, handling both sync and async contexts."""

# Helper function to create async connection
async def _create_async_connection() -> Any:
return await self.dbapi.connect(*arg, **kw) # type: ignore[attr-defined] # noqa: F821,E501

# Check if we're in an async context
if _is_trio_context() or _is_asyncio_context():
connection = await_fallback(_create_async_connection())
return AsyncConnectionWrapper(self, connection)

# No async context detected, use trio.run for synchronous connection creation
conn_func = partial(self.dbapi.connect, *arg, **kw) # type: ignore[attr-defined] # noqa: F821,E501
connection = run(conn_func)
return AsyncConnectionWrapper(
self,
connection,
)
return AsyncConnectionWrapper(self, connection)


class AsyncFireboltDialect(FireboltDialect):
Expand All @@ -171,7 +155,7 @@ class AsyncFireboltDialect(FireboltDialect):
poolclass = AsyncAdaptedQueuePool

@classmethod
def dbapi(cls) -> AsyncAPIWrapper:
def dbapi(cls) -> AsyncAPIWrapper: # type: ignore[override]
return AsyncAPIWrapper(async_dbapi)


Expand Down
14 changes: 13 additions & 1 deletion tests/integration/test_core_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from firebolt.client.auth import FireboltCore
from sqlalchemy import create_engine, text
from sqlalchemy import Column, Integer, MetaData, Table, create_engine, text
from sqlalchemy.engine.base import Connection, Engine
from sqlalchemy.exc import InterfaceError

Expand All @@ -18,6 +18,18 @@ def test_core_engine_auth(self, core_engine: Engine):
auth = connect_args[1]["auth"]
assert isinstance(auth, FireboltCore)

def test_core_table_with_column_int_special_case(self, core_connection: Connection):
"""Test handling of table with column named 'int' in Core."""
table = Table("test_int", MetaData(), Column("int", Integer))
table.create(core_connection)
stmt = table.insert().values(int=42)
core_connection.execute(stmt)
result = core_connection.execute(table.select())
rows = result.fetchall()
assert len(rows) == 1
assert rows[0][0] == 42
table.drop(core_connection)

def test_core_simple_query(self, core_connection: Connection):
"""Test executing a simple query against Core."""
result = core_connection.execute(text("SELECT 'Hello Core' as message"))
Expand Down
1 change: 0 additions & 1 deletion tests/integration/test_sqlalchemy_async_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from sqlalchemy.engine.base import Connection, Engine


@pytest.mark.skip("FIR-47589")
@pytest.mark.usefixtures("setup_test_tables")
class TestAsyncFireboltDialect:
async def test_create_ex_table(
Expand Down
59 changes: 53 additions & 6 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,20 @@ async def connect():


class MockAsyncConnection:
def cursor():
def cursor(self):
# Mock implementation for cursor creation
pass

def commit(self):
# Mock implementation for commit
pass

def rollback(self):
# Mock implementation for rollback
pass

async def aclose(self):
# Mock implementation for async close
pass


Expand All @@ -86,16 +99,31 @@ class MockAsyncCursor:
rowcount = -1
arraysize = 1

async def execute():
async def execute(self):
# Mock implementation for async execute
pass

async def executemany():
async def executemany(self, **kwargs):
# Mock implementation for async executemany
pass

async def fetchall():
async def fetchall(self):
# Mock implementation for async fetchall
pass

def close():
def close(self):
# Mock implementation for close
pass

async def aclose(self):
# Mock implementation for async close
pass

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
# Mock implementation for async context manager exit
pass


Expand Down Expand Up @@ -131,4 +159,23 @@ def async_connection() -> AsyncMock(spec=MockAsyncConnection):

@fixture
def async_cursor() -> AsyncMock(spec=MockAsyncCursor):
return AsyncMock(spec=MockAsyncCursor)
mock = AsyncMock(spec=MockAsyncCursor)
# Make sure the async context manager methods return the mock itself

async def aenter():
# Return the mock cursor for async context manager entry
return mock

async def aexit(*args):
# Mock implementation for async context manager exit
pass

# Make sure close() returns a coroutine that can be awaited
async def close_coro():
# Mock implementation for async close
pass

mock.__aenter__ = AsyncMock(side_effect=aenter)
mock.__aexit__ = AsyncMock(side_effect=aexit)
mock.close.return_value = close_coro()
return mock
Loading