1+ import inspect
2+ import os
3+ from typing import Tuple , Type
14from uuid import uuid4
25
36import pytest
7+ from sqlalchemy import Column , ForeignKey , Integer , String
8+ from sqlalchemy .orm import clear_mappers , relationship
49
510from sqlalchemy_bind_manager import SQLAlchemyAsyncConfig , SQLAlchemyConfig
11+ from sqlalchemy_bind_manager ._bind_manager import SQLAlchemyBind , SQLAlchemyBindManager
612
713
814@pytest .fixture
@@ -25,3 +31,99 @@ def multiple_config():
2531 engine_options = dict (connect_args = {"check_same_thread" : False }),
2632 ),
2733 }
34+
35+
36+ @pytest .fixture ()
37+ def sync_async_wrapper ():
38+ """
39+ Tiny wrapper to allow calling sync and async methods using await.
40+
41+ :return:
42+ """
43+
44+ async def f (call ):
45+ return await call if inspect .iscoroutine (call ) else call
46+
47+ return f
48+
49+
50+ @pytest .fixture
51+ def sa_manager () -> SQLAlchemyBindManager :
52+ test_sync_db_path = f"./{ uuid4 ()} .db"
53+ test_async_db_path = f"./{ uuid4 ()} .db"
54+ config = {
55+ "sync" : SQLAlchemyConfig (
56+ engine_url = f"sqlite:///{ test_sync_db_path } " ,
57+ engine_options = dict (connect_args = {"check_same_thread" : False }),
58+ ),
59+ "async" : SQLAlchemyAsyncConfig (
60+ engine_url = f"sqlite+aiosqlite:///{ test_sync_db_path } " ,
61+ engine_options = dict (connect_args = {"check_same_thread" : False }),
62+ ),
63+ }
64+
65+ yield SQLAlchemyBindManager (config )
66+ try :
67+ os .unlink (test_sync_db_path )
68+ except FileNotFoundError :
69+ pass
70+
71+ try :
72+ os .unlink (test_async_db_path )
73+ except FileNotFoundError :
74+ pass
75+
76+ clear_mappers ()
77+
78+
79+ @pytest .fixture (params = ["sync" , "async" ])
80+ def sa_bind (request , sa_manager ):
81+ return sa_manager .get_bind (request .param )
82+
83+
84+ @pytest .fixture
85+ async def model_classes (sa_bind ) -> Tuple [Type , Type ]:
86+ class ParentModel (sa_bind .model_declarative_base ):
87+ __tablename__ = "parent_model"
88+ # required in order to access columns with server defaults
89+ # or SQL expression defaults, subsequent to a flush, without
90+ # triggering an expired load
91+ __mapper_args__ = {"eager_defaults" : True }
92+
93+ model_id = Column (Integer , primary_key = True , autoincrement = True )
94+ name = Column (String )
95+
96+ children = relationship (
97+ "ChildModel" ,
98+ back_populates = "parent" ,
99+ cascade = "all, delete-orphan" ,
100+ lazy = "selectin" ,
101+ )
102+
103+ class ChildModel (sa_bind .model_declarative_base ):
104+ __tablename__ = "child_model"
105+ # required in order to access columns with server defaults
106+ # or SQL expression defaults, subsequent to a flush, without
107+ # triggering an expired load
108+ __mapper_args__ = {"eager_defaults" : True }
109+
110+ model_id = Column (Integer , primary_key = True , autoincrement = True )
111+ parent_model_id = Column (
112+ Integer , ForeignKey ("parent_model.model_id" ), nullable = False
113+ )
114+ name = Column (String )
115+
116+ parent = relationship ("ParentModel" , back_populates = "children" , lazy = "selectin" )
117+
118+ if isinstance (sa_bind , SQLAlchemyBind ):
119+ sa_bind .registry_mapper .metadata .create_all (sa_bind .engine )
120+ else :
121+ async with sa_bind .engine .begin () as conn :
122+ await conn .run_sync (sa_bind .registry_mapper .metadata .create_all )
123+
124+ return ParentModel , ChildModel
125+
126+
127+ @pytest .fixture
128+ async def model_class (model_classes : Tuple [Type , Type ]) -> Type :
129+ return model_classes [0 ]
0 commit comments