Skip to content

Commit 189be8e

Browse files
committed
Add implementation of AsyncSession.run_sync()
Currently, `sqlmodel.ext.asyncio.session.AsyncSession` doesn't implement `run_sync()`, which means that any call to `run_sync()` on a sqlmodel `AsyncSession` will be dispatched to the parent `sqlalchemy.ext.asyncio.AsyncSession`. The first argument to sqlalchemy's `AsyncSession.run_sync()` is a callable whose first argument is a `sqlalchemy.orm.Session` object. If we're using this in a repo that uses sqlmodel, we'll actually be passing a callable whose first argument is a `sqlmodel.orm.session.Session`. In practice this works fine - because `sqlmodel.orm.session.Session` is derived from `sqlalchemy.orm.Session`, the implementation of `sqlalchemy.ext.asyncio.AsyncSession.run_sync()` can use the sqlmodel `Session` object in place of the sqlalchemy `Session` object. However, static analysers will complain that the argument to `run_sync()` is of the wrong type. For example, here's a warning from pyright: ``` Pyright: Error: Argument of type "(session: Session, id: UUID) -> int" cannot be assigned to parameter "fn" of type "(Session, **_P@run_sync) -> _T@run_sync" in function "run_sync"   Type "(session: Session, id: UUID) -> int" is not assignable to type "(Session, id: UUID) -> int"     Parameter 1: type "Session" is incompatible with type "Session"       "sqlalchemy.orm.session.Session" is not assignable to "sqlmodel.orm.session.Session" [reportArgumentType] ``` This commit implements a `run_sync()` method on `sqlmodel.ext.asyncio.session.AsyncSession`, which casts the callable to the correct type before dispatching it to the base class. This satisfies the static type checks.
1 parent 6c0410e commit 189be8e

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

sqlmodel/ext/asyncio/session.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import (
22
Any,
3+
Callable,
34
Dict,
45
Mapping,
56
Optional,
@@ -15,12 +16,14 @@
1516
from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams
1617
from sqlalchemy.engine.result import Result, ScalarResult, TupleResult
1718
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
19+
from sqlalchemy.orm import Session as _Session
1820
from sqlalchemy.ext.asyncio.result import _ensure_sync_result
1921
from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS
2022
from sqlalchemy.orm._typing import OrmExecuteOptionsParameter
2123
from sqlalchemy.sql.base import Executable as _Executable
2224
from sqlalchemy.util.concurrency import greenlet_spawn
2325
from typing_extensions import deprecated
26+
from typing import Concatenate
2427

2528
from ...orm.session import Session
2629
from ...sql.base import Executable
@@ -29,6 +32,7 @@
2932
_TSelectParam = TypeVar("_TSelectParam", bound=Any)
3033

3134

35+
3236
class AsyncSession(_AsyncSession):
3337
sync_session_class: Type[Session] = Session
3438
sync_session: Session
@@ -148,3 +152,18 @@ async def execute( # type: ignore
148152
_parent_execute_state=_parent_execute_state,
149153
_add_event=_add_event,
150154
)
155+
156+
async def run_sync[**P, T](
157+
self,
158+
fn: Callable[Concatenate[Session, P], T],
159+
*arg: P.args,
160+
**kw: P.kwargs,
161+
) -> T:
162+
163+
base_fn = cast(Callable[Concatenate[_Session, P], T], fn)
164+
165+
return await super().run_sync(
166+
base_fn,
167+
*arg,
168+
**kw,
169+
)

0 commit comments

Comments
 (0)