|
1 | 1 | import inspect |
2 | 2 | import os |
3 | 3 | from contextlib import _AsyncGeneratorContextManager, asynccontextmanager |
4 | | -from typing import Tuple, Type |
| 4 | +from typing import Tuple, Type, Union |
5 | 5 | from uuid import uuid4 |
6 | 6 |
|
7 | 7 | import pytest |
8 | 8 | from sqlalchemy import Column, ForeignKey, Integer, String |
9 | 9 | from sqlalchemy.orm import clear_mappers, relationship |
10 | 10 |
|
11 | 11 | from sqlalchemy_bind_manager import SQLAlchemyAsyncConfig, SQLAlchemyConfig |
12 | | -from sqlalchemy_bind_manager._bind_manager import SQLAlchemyBind, SQLAlchemyBindManager |
| 12 | +from sqlalchemy_bind_manager._bind_manager import ( |
| 13 | + SQLAlchemyAsyncBind, |
| 14 | + SQLAlchemyBind, |
| 15 | + SQLAlchemyBindManager, |
| 16 | +) |
| 17 | +from sqlalchemy_bind_manager._repository import ( |
| 18 | + SQLAlchemyAsyncRepository, |
| 19 | + SQLAlchemyRepository, |
| 20 | +) |
| 21 | +from sqlalchemy_bind_manager._session_handler import AsyncSessionHandler, SessionHandler |
| 22 | +from sqlalchemy_bind_manager.repository import AsyncUnitOfWork, UnitOfWork |
13 | 23 |
|
14 | 24 |
|
15 | 25 | @pytest.fixture |
@@ -148,3 +158,30 @@ class ChildModel(sa_bind.model_declarative_base): |
148 | 158 | @pytest.fixture |
149 | 159 | async def model_class(model_classes: Tuple[Type, Type]) -> Type: |
150 | 160 | return model_classes[0] |
| 161 | + |
| 162 | + |
| 163 | +@pytest.fixture |
| 164 | +def session_handler_class(sa_bind): |
| 165 | + return ( |
| 166 | + AsyncSessionHandler |
| 167 | + if isinstance(sa_bind, SQLAlchemyAsyncBind) |
| 168 | + else SessionHandler |
| 169 | + ) |
| 170 | + |
| 171 | + |
| 172 | +@pytest.fixture |
| 173 | +def repository_class( |
| 174 | + sa_bind: Union[SQLAlchemyBind, SQLAlchemyAsyncBind] |
| 175 | +) -> Type[Union[SQLAlchemyAsyncRepository, SQLAlchemyRepository]]: |
| 176 | + base_class = ( |
| 177 | + SQLAlchemyRepository |
| 178 | + if isinstance(sa_bind, SQLAlchemyBind) |
| 179 | + else SQLAlchemyAsyncRepository |
| 180 | + ) |
| 181 | + |
| 182 | + return base_class |
| 183 | + |
| 184 | + |
| 185 | +@pytest.fixture |
| 186 | +def uow_class(sa_bind): |
| 187 | + return AsyncUnitOfWork if isinstance(sa_bind, SQLAlchemyAsyncBind) else UnitOfWork |
0 commit comments