Skip to content

Commit b5e850b

Browse files
committed
update async extension and add tests
1 parent 088164e commit b5e850b

File tree

5 files changed

+69
-2
lines changed

5 files changed

+69
-2
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ fastapi = "^0.68.1"
5050
requests = "^2.26.0"
5151
autoflake = "^1.4"
5252
isort = "^5.9.3"
53+
testcontainers = "^3.7.1"
54+
psycopg2-binary = "^2.9.7"
55+
asyncpg = "^0.28.0"
5356

5457
[build-system]
5558
requires = ["poetry-core"]

sqlmodel/ext/asyncio/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .engine import create_async_engine as create_async_engine
2+
from .session import AsyncSession as AsyncSession

sqlmodel/ext/asyncio/engine.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from typing import Any
2+
3+
from sqlalchemy.ext.asyncio import AsyncEngine
4+
from sqlalchemy.ext.asyncio import create_async_engine as _create_async_engine
5+
6+
7+
# create_async_engine by default already has future set to be true.
8+
# Porting this over to sqlmodel to make it easier to use.
9+
def create_async_engine(*args: Any, **kwargs: Any) -> AsyncEngine:
10+
return _create_async_engine(*args, **kwargs)

sqlmodel/ext/asyncio/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from ...engine.result import ScalarResult
1111
from ...orm.session import Session
12-
from ...sql.expression import Select
12+
from ...sql.expression import Select, SelectOfScalar
1313

1414
_T = TypeVar("_T")
1515

@@ -42,7 +42,7 @@ def __init__(
4242

4343
async def exec(
4444
self,
45-
statement: Union[Select[_T], Executable[_T]],
45+
statement: Union[Select[_T], SelectOfScalar[_T], Executable[_T]],
4646
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
4747
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
4848
bind_arguments: Optional[Mapping[str, Any]] = None,

tests/test_async.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import asyncio
2+
from typing import Generator, Optional
3+
4+
import pytest
5+
from sqlmodel import Field, SQLModel, select
6+
from sqlmodel.ext.asyncio import AsyncSession, create_async_engine
7+
from testcontainers.postgres import PostgresContainer
8+
9+
10+
# The first time this test is run, it will download the postgres image which can take
11+
# a while. Subsequent runs will be much faster.
12+
@pytest.fixture(scope="module")
13+
def postgres_container_url() -> Generator[str, None, None]:
14+
with PostgresContainer("postgres:13") as postgres:
15+
postgres.driver = "asyncpg"
16+
yield postgres.get_connection_url()
17+
18+
19+
async def _test_async_create(postgres_container_url: str) -> None:
20+
class Hero(SQLModel, table=True):
21+
# SQLModel.metadata is a singleton and the Hero Class has already been defined.
22+
# If I flush the metadata during this test, it will cause test_enum to fail
23+
# because in that file, the model isn't defined within a function. For now, the
24+
# workaround is to set extend_existing to True. In the future, test setup and
25+
# teardown should be refactored to avoid this issue.
26+
__table_args__ = {"extend_existing": True}
27+
28+
id: Optional[int] = Field(default=None, primary_key=True)
29+
name: str
30+
secret_name: str
31+
age: Optional[int] = None
32+
33+
hero_create = Hero(name="Deadpond", secret_name="Dive Wilson")
34+
35+
engine = create_async_engine(postgres_container_url)
36+
async with engine.begin() as conn:
37+
await conn.run_sync(SQLModel.metadata.create_all)
38+
39+
async with AsyncSession(engine) as session:
40+
session.add(hero_create)
41+
await session.commit()
42+
await session.refresh(hero_create)
43+
44+
async with AsyncSession(engine) as session:
45+
statement = select(Hero).where(Hero.name == "Deadpond")
46+
results = await session.exec(statement)
47+
hero_query = results.one()
48+
assert hero_create == hero_query
49+
50+
51+
def test_async_create(postgres_container_url: str) -> None:
52+
asyncio.run(_test_async_create(postgres_container_url))

0 commit comments

Comments
 (0)