1+ import inspect
2+ import os
3+ from contextlib import _AsyncGeneratorContextManager , asynccontextmanager
4+ from typing import Tuple , Type , Union
15from uuid import uuid4
26
37import pytest
8+ from sqlalchemy import Column , ForeignKey , Integer , String
9+ from sqlalchemy .orm import clear_mappers , relationship
410
511from sqlalchemy_bind_manager import SQLAlchemyAsyncConfig , SQLAlchemyConfig
12+ from sqlalchemy_bind_manager ._bind_manager import (
13+ SQLAlchemyAsyncBind ,
14+ SQLAlchemyBind ,
15+ SQLAlchemyBindManager ,
16+ )
17+ from sqlalchemy_bind_manager ._repository import (
18+ SQLAlchemyAsyncRepository ,
19+ SQLAlchemyRepository ,
20+ )
21+ from sqlalchemy_bind_manager ._session_handler import AsyncSessionHandler , SessionHandler
22+ from sqlalchemy_bind_manager .repository import AsyncUnitOfWork , UnitOfWork
623
724
825@pytest .fixture
@@ -25,3 +42,146 @@ def multiple_config():
2542 engine_options = dict (connect_args = {"check_same_thread" : False }),
2643 ),
2744 }
45+
46+
47+ @pytest .fixture ()
48+ def sync_async_wrapper ():
49+ """
50+ Tiny wrapper to allow calling sync and async methods using await.
51+
52+ :return:
53+ """
54+
55+ async def f (call ):
56+ return await call if inspect .iscoroutine (call ) else call
57+
58+ return f
59+
60+
61+ @pytest .fixture ()
62+ def sync_async_cm_wrapper ():
63+ """
64+ Tiny wrapper to allow calling sync and async methods using await.
65+
66+ :return:
67+ """
68+
69+ @asynccontextmanager
70+ async def f (cm ):
71+ if isinstance (cm , _AsyncGeneratorContextManager ):
72+ async with cm as c :
73+ yield c
74+ else :
75+ with cm as c :
76+ yield c
77+
78+ return f
79+
80+
81+ @pytest .fixture
82+ def sa_manager () -> SQLAlchemyBindManager :
83+ test_sync_db_path = f"./{ uuid4 ()} .db"
84+ test_async_db_path = f"./{ uuid4 ()} .db"
85+ config = {
86+ "sync" : SQLAlchemyConfig (
87+ engine_url = f"sqlite:///{ test_sync_db_path } " ,
88+ engine_options = dict (connect_args = {"check_same_thread" : False }),
89+ ),
90+ "async" : SQLAlchemyAsyncConfig (
91+ engine_url = f"sqlite+aiosqlite:///{ test_sync_db_path } " ,
92+ engine_options = dict (connect_args = {"check_same_thread" : False }),
93+ ),
94+ }
95+
96+ yield SQLAlchemyBindManager (config )
97+ try :
98+ os .unlink (test_sync_db_path )
99+ except FileNotFoundError :
100+ pass
101+
102+ try :
103+ os .unlink (test_async_db_path )
104+ except FileNotFoundError :
105+ pass
106+
107+ clear_mappers ()
108+
109+
110+ @pytest .fixture (params = ["sync" , "async" ])
111+ def sa_bind (request , sa_manager ):
112+ return sa_manager .get_bind (request .param )
113+
114+
115+ @pytest .fixture
116+ async def model_classes (sa_bind ) -> Tuple [Type , Type ]:
117+ class ParentModel (sa_bind .model_declarative_base ):
118+ __tablename__ = "parent_model"
119+ # required in order to access columns with server defaults
120+ # or SQL expression defaults, subsequent to a flush, without
121+ # triggering an expired load
122+ __mapper_args__ = {"eager_defaults" : True }
123+
124+ model_id = Column (Integer , primary_key = True , autoincrement = True )
125+ name = Column (String )
126+
127+ children = relationship (
128+ "ChildModel" ,
129+ back_populates = "parent" ,
130+ cascade = "all, delete-orphan" ,
131+ lazy = "selectin" ,
132+ )
133+
134+ class ChildModel (sa_bind .model_declarative_base ):
135+ __tablename__ = "child_model"
136+ # required in order to access columns with server defaults
137+ # or SQL expression defaults, subsequent to a flush, without
138+ # triggering an expired load
139+ __mapper_args__ = {"eager_defaults" : True }
140+
141+ model_id = Column (Integer , primary_key = True , autoincrement = True )
142+ parent_model_id = Column (
143+ Integer , ForeignKey ("parent_model.model_id" ), nullable = False
144+ )
145+ name = Column (String )
146+
147+ parent = relationship ("ParentModel" , back_populates = "children" , lazy = "selectin" )
148+
149+ if isinstance (sa_bind , SQLAlchemyBind ):
150+ sa_bind .registry_mapper .metadata .create_all (sa_bind .engine )
151+ else :
152+ async with sa_bind .engine .begin () as conn :
153+ await conn .run_sync (sa_bind .registry_mapper .metadata .create_all )
154+
155+ return ParentModel , ChildModel
156+
157+
158+ @pytest .fixture
159+ async def model_class (model_classes : Tuple [Type , Type ]) -> Type :
160+ return model_classes [0 ]
161+
162+
163+ @pytest .fixture
164+ def session_handler_class (sa_bind ):
165+ return (
166+ AsyncSessionHandler
167+ if isinstance (sa_bind , SQLAlchemyAsyncBind )
168+ else SessionHandler
169+ )
170+
171+
172+ @pytest .fixture
173+ def repository_class (
174+ sa_bind : Union [SQLAlchemyBind , SQLAlchemyAsyncBind ]
175+ ) -> Type [Union [SQLAlchemyAsyncRepository , SQLAlchemyRepository ]]:
176+ base_class = (
177+ SQLAlchemyRepository
178+ if isinstance (sa_bind , SQLAlchemyBind )
179+ else SQLAlchemyAsyncRepository
180+ )
181+
182+ return base_class
183+
184+
185+ @pytest .fixture
186+ def uow_class (sa_bind ):
187+ return AsyncUnitOfWork if isinstance (sa_bind , SQLAlchemyAsyncBind ) else UnitOfWork
0 commit comments