Skip to content

Commit e15dfda

Browse files
committed
Fix Sanic support
1 parent 09d7f84 commit e15dfda

File tree

2 files changed

+113
-26
lines changed

2 files changed

+113
-26
lines changed

gino/ext/sanic.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
# noinspection PyPackageRequirements
22
from sanic.exceptions import NotFound
3+
from sqlalchemy.engine.url import URL
4+
try:
5+
# noinspection PyPackageRequirements
6+
from aiocontextvars import enable_inherit, disable_inherit
7+
except ImportError:
8+
enable_inherit = disable_inherit = lambda: None
39

410
from ..api import Gino as _Gino, GinoExecutor as _Executor
5-
from ..local import enable_task_local, disable_task_local
6-
from ..connection import GinoConnection as _Connection
7-
from ..pool import GinoPool as _Pool
11+
from ..engine import GinoConnection as _Connection, GinoEngine as _Engine
812

913

1014
class SanicModelMixin:
@@ -36,7 +40,9 @@ async def first_or_404(self, *args, **kwargs):
3640

3741

3842
# noinspection PyClassHasNoInit
39-
class GinoPool(_Pool):
43+
class GinoEngine(_Engine):
44+
connection_cls = GinoConnection
45+
4046
async def first_or_404(self, *args, **kwargs):
4147
rv = await self.first(*args, **kwargs)
4248
if rv is None:
@@ -65,51 +71,59 @@ class Gino(_Gino):
6571
"""
6672
model_base_classes = _Gino.model_base_classes + (SanicModelMixin,)
6773
query_executor = GinoExecutor
68-
connection_cls = GinoConnection
69-
pool_cls = GinoPool
74+
75+
def __init__(self, app=None, *args, **kwargs):
76+
super().__init__(*args, **kwargs)
77+
if app is not None:
78+
self.init_app(app)
7079

7180
def init_app(self, app):
72-
task_local_enabled = [False]
81+
inherit_enabled = [False]
7382

7483
if app.config.setdefault('DB_USE_CONNECTION_FOR_REQUEST', True):
7584
@app.middleware('request')
7685
async def on_request(request):
77-
request['connection_ctx'] = ctx = self.acquire(lazy=True)
78-
request['connection'] = await ctx.__aenter__()
86+
request['connection'] = await self.acquire(lazy=True)
7987

8088
@app.middleware('response')
8189
async def on_response(request, _):
82-
ctx = request.pop('connection_ctx', None)
83-
request.pop('connection', None)
84-
if ctx is not None:
85-
await ctx.__aexit__(None, None, None)
90+
conn = request.pop('connection', None)
91+
if conn is not None:
92+
await self.release(conn)
8693

8794
@app.listener('before_server_start')
8895
async def before_server_start(_, loop):
8996
if app.config.setdefault('DB_USE_CONNECTION_FOR_REQUEST', True):
90-
enable_task_local(loop)
91-
task_local_enabled[0] = True
92-
93-
await self.create_pool(
94-
host=app.config.setdefault('DB_HOST', 'localhost'),
95-
port=app.config.setdefault('DB_PORT', 5432),
96-
user=app.config.setdefault('DB_USER', 'postgres'),
97-
password=app.config.setdefault('DB_PASSWORD', ''),
98-
database=app.config.setdefault('DB_DATABASE', 'postgres'),
97+
enable_inherit(loop)
98+
inherit_enabled[0] = True
99+
100+
await self.set_bind(
101+
URL(
102+
drivername=app.config.setdefault('DB_DRIVER', 'asyncpg'),
103+
host=app.config.setdefault('DB_HOST', 'localhost'),
104+
port=app.config.setdefault('DB_PORT', 5432),
105+
username=app.config.setdefault('DB_USER', 'postgres'),
106+
password=app.config.setdefault('DB_PASSWORD', ''),
107+
database=app.config.setdefault('DB_DATABASE', 'postgres'),
108+
),
99109
min_size=app.config.setdefault('DB_POOL_MIN_SIZE', 5),
100110
max_size=app.config.setdefault('DB_POOL_MAX_SIZE', 10),
101111
loop=loop,
102112
)
103113

104114
@app.listener('after_server_stop')
105115
async def after_server_stop(_, loop):
106-
await self._bind.close()
107-
if task_local_enabled[0]:
108-
disable_task_local(loop)
109-
task_local_enabled[0] = False
116+
await self.pop_bind().close()
117+
if inherit_enabled[0]:
118+
disable_inherit(loop)
119+
inherit_enabled[0] = False
110120

111121
async def first_or_404(self, *args, **kwargs):
112122
rv = await self.first(*args, **kwargs)
113123
if rv is None:
114124
raise NotFound('No such data')
115125
return rv
126+
127+
async def set_bind(self, bind, loop=None, **kwargs):
128+
kwargs.setdefault('engine_cls', GinoEngine)
129+
return await super().set_bind(bind, loop=loop, **kwargs)

tests/test_sanic.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import gino
2+
import sanic
3+
import pytest
4+
from sanic.response import text, json
5+
from gino.ext.sanic import Gino
6+
7+
from .models import DB_ARGS, PG_URL
8+
9+
10+
# noinspection PyShadowingNames
11+
@pytest.fixture
12+
async def app():
13+
app = sanic.Sanic()
14+
app.config['DB_HOST'] = DB_ARGS['host']
15+
app.config['DB_PORT'] = DB_ARGS['port']
16+
app.config['DB_USER'] = DB_ARGS['user']
17+
app.config['DB_PASSWORD'] = DB_ARGS['password']
18+
app.config['DB_DATABASE'] = DB_ARGS['database']
19+
20+
db = Gino(app)
21+
22+
class User(db.Model):
23+
__tablename__ = 'gino_users'
24+
25+
id = db.Column(db.BigInteger(), primary_key=True)
26+
nickname = db.Column(db.Unicode(), default='noname')
27+
28+
@app.route('/')
29+
async def root(request):
30+
return text('Hello, world!')
31+
32+
@app.route('/users/<uid:int>')
33+
async def get_user(request, uid):
34+
return json((await User.get_or_404(uid)).to_dict())
35+
36+
@app.route('/users', methods=['POST'])
37+
async def add_user(request):
38+
u = await User.create(nickname=request.form.get('name'))
39+
await u.query.gino.first_or_404()
40+
await db.first_or_404(u.query)
41+
await db.bind.first_or_404(u.query)
42+
await request['connection'].first_or_404(u.query)
43+
return json(u.to_dict())
44+
45+
e = await gino.create_engine(PG_URL)
46+
try:
47+
try:
48+
await db.gino.create_all(e)
49+
yield app
50+
finally:
51+
await db.gino.drop_all(e)
52+
finally:
53+
await e.close()
54+
55+
56+
def test_index_returns_200(app):
57+
request, response = app.test_client.get('/')
58+
assert response.status == 200
59+
assert response.text == 'Hello, world!'
60+
61+
62+
def test(app):
63+
request, response = app.test_client.get('/users/1')
64+
assert response.status == 404
65+
66+
request, response = app.test_client.post('/users',
67+
data=dict(name='fantix'))
68+
assert response.status == 200
69+
assert response.json == dict(id=1, nickname='fantix')
70+
71+
request, response = app.test_client.get('/users/1')
72+
assert response.status == 200
73+
assert response.json == dict(id=1, nickname='fantix')

0 commit comments

Comments
 (0)