|
1 | 1 | # noinspection PyPackageRequirements
|
2 | 2 | from sanic.exceptions import NotFound
|
| 3 | +from sqlalchemy.engine.url import URL |
| 4 | +try: |
| 5 | + # noinspection PyPackageRequirements |
| 6 | + from aiocontextvars import enable_inherit, disable_inherit |
| 7 | +except ImportError: |
| 8 | + enable_inherit = disable_inherit = lambda: None |
3 | 9 |
|
4 | 10 | from ..api import Gino as _Gino, GinoExecutor as _Executor
|
5 |
| -from ..local import enable_task_local, disable_task_local |
6 |
| -from ..connection import GinoConnection as _Connection |
7 |
| -from ..pool import GinoPool as _Pool |
| 11 | +from ..engine import GinoConnection as _Connection, GinoEngine as _Engine |
8 | 12 |
|
9 | 13 |
|
10 | 14 | class SanicModelMixin:
|
@@ -36,7 +40,9 @@ async def first_or_404(self, *args, **kwargs):
|
36 | 40 |
|
37 | 41 |
|
38 | 42 | # noinspection PyClassHasNoInit
|
39 |
| -class GinoPool(_Pool): |
| 43 | +class GinoEngine(_Engine): |
| 44 | + connection_cls = GinoConnection |
| 45 | + |
40 | 46 | async def first_or_404(self, *args, **kwargs):
|
41 | 47 | rv = await self.first(*args, **kwargs)
|
42 | 48 | if rv is None:
|
@@ -65,51 +71,59 @@ class Gino(_Gino):
|
65 | 71 | """
|
66 | 72 | model_base_classes = _Gino.model_base_classes + (SanicModelMixin,)
|
67 | 73 | query_executor = GinoExecutor
|
68 |
| - connection_cls = GinoConnection |
69 |
| - pool_cls = GinoPool |
| 74 | + |
| 75 | + def __init__(self, app=None, *args, **kwargs): |
| 76 | + super().__init__(*args, **kwargs) |
| 77 | + if app is not None: |
| 78 | + self.init_app(app) |
70 | 79 |
|
71 | 80 | def init_app(self, app):
|
72 |
| - task_local_enabled = [False] |
| 81 | + inherit_enabled = [False] |
73 | 82 |
|
74 | 83 | if app.config.setdefault('DB_USE_CONNECTION_FOR_REQUEST', True):
|
75 | 84 | @app.middleware('request')
|
76 | 85 | async def on_request(request):
|
77 |
| - request['connection_ctx'] = ctx = self.acquire(lazy=True) |
78 |
| - request['connection'] = await ctx.__aenter__() |
| 86 | + request['connection'] = await self.acquire(lazy=True) |
79 | 87 |
|
80 | 88 | @app.middleware('response')
|
81 | 89 | async def on_response(request, _):
|
82 |
| - ctx = request.pop('connection_ctx', None) |
83 |
| - request.pop('connection', None) |
84 |
| - if ctx is not None: |
85 |
| - await ctx.__aexit__(None, None, None) |
| 90 | + conn = request.pop('connection', None) |
| 91 | + if conn is not None: |
| 92 | + await self.release(conn) |
86 | 93 |
|
87 | 94 | @app.listener('before_server_start')
|
88 | 95 | async def before_server_start(_, loop):
|
89 | 96 | if app.config.setdefault('DB_USE_CONNECTION_FOR_REQUEST', True):
|
90 |
| - enable_task_local(loop) |
91 |
| - task_local_enabled[0] = True |
92 |
| - |
93 |
| - await self.create_pool( |
94 |
| - host=app.config.setdefault('DB_HOST', 'localhost'), |
95 |
| - port=app.config.setdefault('DB_PORT', 5432), |
96 |
| - user=app.config.setdefault('DB_USER', 'postgres'), |
97 |
| - password=app.config.setdefault('DB_PASSWORD', ''), |
98 |
| - database=app.config.setdefault('DB_DATABASE', 'postgres'), |
| 97 | + enable_inherit(loop) |
| 98 | + inherit_enabled[0] = True |
| 99 | + |
| 100 | + await self.set_bind( |
| 101 | + URL( |
| 102 | + drivername=app.config.setdefault('DB_DRIVER', 'asyncpg'), |
| 103 | + host=app.config.setdefault('DB_HOST', 'localhost'), |
| 104 | + port=app.config.setdefault('DB_PORT', 5432), |
| 105 | + username=app.config.setdefault('DB_USER', 'postgres'), |
| 106 | + password=app.config.setdefault('DB_PASSWORD', ''), |
| 107 | + database=app.config.setdefault('DB_DATABASE', 'postgres'), |
| 108 | + ), |
99 | 109 | min_size=app.config.setdefault('DB_POOL_MIN_SIZE', 5),
|
100 | 110 | max_size=app.config.setdefault('DB_POOL_MAX_SIZE', 10),
|
101 | 111 | loop=loop,
|
102 | 112 | )
|
103 | 113 |
|
104 | 114 | @app.listener('after_server_stop')
|
105 | 115 | async def after_server_stop(_, loop):
|
106 |
| - await self._bind.close() |
107 |
| - if task_local_enabled[0]: |
108 |
| - disable_task_local(loop) |
109 |
| - task_local_enabled[0] = False |
| 116 | + await self.pop_bind().close() |
| 117 | + if inherit_enabled[0]: |
| 118 | + disable_inherit(loop) |
| 119 | + inherit_enabled[0] = False |
110 | 120 |
|
111 | 121 | async def first_or_404(self, *args, **kwargs):
|
112 | 122 | rv = await self.first(*args, **kwargs)
|
113 | 123 | if rv is None:
|
114 | 124 | raise NotFound('No such data')
|
115 | 125 | return rv
|
| 126 | + |
| 127 | + async def set_bind(self, bind, loop=None, **kwargs): |
| 128 | + kwargs.setdefault('engine_cls', GinoEngine) |
| 129 | + return await super().set_bind(bind, loop=loop, **kwargs) |
0 commit comments