Skip to content

Commit 7a8081f

Browse files
committed
Prepare fixtures for combinated sync/async tests
1 parent 000e20b commit 7a8081f

File tree

2 files changed

+121
-1
lines changed

2 files changed

+121
-1
lines changed

tests/conftest.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
import inspect
2+
import os
3+
from typing import Tuple, Type
14
from uuid import uuid4
25

36
import pytest
7+
from sqlalchemy import Column, ForeignKey, Integer, String
8+
from sqlalchemy.orm import clear_mappers, relationship
49

510
from sqlalchemy_bind_manager import SQLAlchemyAsyncConfig, SQLAlchemyConfig
11+
from sqlalchemy_bind_manager._bind_manager import SQLAlchemyBind, SQLAlchemyBindManager
612

713

814
@pytest.fixture
@@ -25,3 +31,99 @@ def multiple_config():
2531
engine_options=dict(connect_args={"check_same_thread": False}),
2632
),
2733
}
34+
35+
36+
@pytest.fixture()
37+
def sync_async_wrapper():
38+
"""
39+
Tiny wrapper to allow calling sync and async methods using await.
40+
41+
:return:
42+
"""
43+
44+
async def f(call):
45+
return await call if inspect.iscoroutine(call) else call
46+
47+
return f
48+
49+
50+
@pytest.fixture
51+
def sa_manager() -> SQLAlchemyBindManager:
52+
test_sync_db_path = f"./{uuid4()}.db"
53+
test_async_db_path = f"./{uuid4()}.db"
54+
config = {
55+
"sync": SQLAlchemyConfig(
56+
engine_url=f"sqlite:///{test_sync_db_path}",
57+
engine_options=dict(connect_args={"check_same_thread": False}),
58+
),
59+
"async": SQLAlchemyAsyncConfig(
60+
engine_url=f"sqlite+aiosqlite:///{test_sync_db_path}",
61+
engine_options=dict(connect_args={"check_same_thread": False}),
62+
),
63+
}
64+
65+
yield SQLAlchemyBindManager(config)
66+
try:
67+
os.unlink(test_sync_db_path)
68+
except FileNotFoundError:
69+
pass
70+
71+
try:
72+
os.unlink(test_async_db_path)
73+
except FileNotFoundError:
74+
pass
75+
76+
clear_mappers()
77+
78+
79+
@pytest.fixture(params=["sync", "async"])
80+
def sa_bind(request, sa_manager):
81+
return sa_manager.get_bind(request.param)
82+
83+
84+
@pytest.fixture
85+
async def model_classes(sa_bind) -> Tuple[Type, Type]:
86+
class ParentModel(sa_bind.model_declarative_base):
87+
__tablename__ = "parent_model"
88+
# required in order to access columns with server defaults
89+
# or SQL expression defaults, subsequent to a flush, without
90+
# triggering an expired load
91+
__mapper_args__ = {"eager_defaults": True}
92+
93+
model_id = Column(Integer, primary_key=True, autoincrement=True)
94+
name = Column(String)
95+
96+
children = relationship(
97+
"ChildModel",
98+
back_populates="parent",
99+
cascade="all, delete-orphan",
100+
lazy="selectin",
101+
)
102+
103+
class ChildModel(sa_bind.model_declarative_base):
104+
__tablename__ = "child_model"
105+
# required in order to access columns with server defaults
106+
# or SQL expression defaults, subsequent to a flush, without
107+
# triggering an expired load
108+
__mapper_args__ = {"eager_defaults": True}
109+
110+
model_id = Column(Integer, primary_key=True, autoincrement=True)
111+
parent_model_id = Column(
112+
Integer, ForeignKey("parent_model.model_id"), nullable=False
113+
)
114+
name = Column(String)
115+
116+
parent = relationship("ParentModel", back_populates="children", lazy="selectin")
117+
118+
if isinstance(sa_bind, SQLAlchemyBind):
119+
sa_bind.registry_mapper.metadata.create_all(sa_bind.engine)
120+
else:
121+
async with sa_bind.engine.begin() as conn:
122+
await conn.run_sync(sa_bind.registry_mapper.metadata.create_all)
123+
124+
return ParentModel, ChildModel
125+
126+
127+
@pytest.fixture
128+
async def model_class(model_classes: Tuple[Type, Type]) -> Type:
129+
return model_classes[0]

tests/repository/conftest.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Iterator
2+
from typing import Iterator, Union
33
from uuid import uuid4
44

55
import pytest
@@ -10,6 +10,11 @@
1010
SQLAlchemyBindManager,
1111
SQLAlchemyConfig,
1212
)
13+
from sqlalchemy_bind_manager._bind_manager import SQLAlchemyAsyncBind, SQLAlchemyBind
14+
from sqlalchemy_bind_manager._repository import (
15+
SQLAlchemyAsyncRepository,
16+
SQLAlchemyRepository,
17+
)
1318

1419

1520
@pytest.fixture
@@ -39,3 +44,16 @@ def sync_async_sa_manager(multiple_config) -> Iterator[SQLAlchemyBindManager]:
3944
pass
4045

4146
clear_mappers()
47+
48+
49+
@pytest.fixture
50+
def repository_class(
51+
sa_bind: Union[SQLAlchemyBind, SQLAlchemyAsyncBind]
52+
) -> Union[SQLAlchemyAsyncRepository, SQLAlchemyRepository]:
53+
base_class = (
54+
SQLAlchemyRepository
55+
if isinstance(sa_bind, SQLAlchemyBind)
56+
else SQLAlchemyAsyncRepository
57+
)
58+
59+
return base_class

0 commit comments

Comments
 (0)