Skip to content

Commit 4741210

Browse files
committed
connect loaders and fix tests:
tests/test_bind.py tests/test_core.py tests/test_engine.py
1 parent afd9dd3 commit 4741210

File tree

7 files changed

+163
-27
lines changed

7 files changed

+163
-27
lines changed

src/gino/api.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
import sqlalchemy as sa
44
from sqlalchemy.engine.url import make_url, URL
5-
from sqlalchemy.sql.base import Executable
5+
from sqlalchemy.sql.base import Executable, _bind_or_error
6+
from sqlalchemy.sql.schema import SchemaItem
67

78
from . import json_support
8-
from .engine import create_engine
99
from .crud import CRUDModel
1010
from .declarative import declarative_base, declared_attr
11+
from .engine import create_engine
1112
from .exceptions import UninitializedError
1213

1314

@@ -209,6 +210,23 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
209210
await self._args[0].pop_bind().close()
210211

211212

213+
class GinoSchemaVisitor:
214+
__slots__ = ("_item",)
215+
216+
def __init__(self, item):
217+
self._item = item
218+
219+
def __getattr__(self, item):
220+
sync_func = getattr(self._item, item)
221+
222+
async def _wrapper(bind=None, *args, **kwargs):
223+
if bind is None:
224+
bind = _bind_or_error(self._item)
225+
return await bind.run_sync(sync_func, *args, **kwargs)
226+
227+
return _wrapper
228+
229+
212230
class Gino(sa.MetaData):
213231
"""
214232
All-in-one API class of GINO, providing several shortcuts.
@@ -301,7 +319,7 @@ class Gino(sa.MetaData):
301319
302320
"""
303321

304-
# schema_visitor = GinoSchemaVisitor
322+
schema_visitor = GinoSchemaVisitor
305323
"""
306324
The overridable ``gino`` extension class on
307325
:class:`~sqlalchemy.schema.SchemaItem`.
@@ -375,9 +393,8 @@ def __init__(
375393
if ext:
376394
if query_ext:
377395
Executable.gino = property(self.query_executor)
378-
# if schema_ext:
379-
# SchemaItem.gino = property(self.schema_visitor)
380-
# patch_schema(self)
396+
if schema_ext:
397+
SchemaItem.gino = property(self.schema_visitor)
381398

382399
# noinspection PyPep8Naming
383400
@property

src/gino/crud.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,10 @@ async def apply(self, bind=None, timeout=DEFAULT):
170170
row = await bind.first(clause)
171171
if not row:
172172
raise NoSuchRowError()
173-
for k, v in row.items():
174-
self._instance.__values__[self._instance._column_name_map.invert_get(k)] = v
173+
for k in row.keys():
174+
self._instance.__values__[
175+
self._instance._column_name_map.invert_get(k)
176+
] = row[k]
175177
for prop in self._props:
176178
prop.reload(self._instance)
177179
return self
@@ -475,8 +477,8 @@ async def _create(self, bind=None, timeout=DEFAULT):
475477
if bind is None:
476478
bind = cls.__metadata__.bind
477479
row = await bind.first(q)
478-
for k, v in row.items():
479-
self.__values__[self._column_name_map.invert_get(k)] = v
480+
for k in row.keys():
481+
self.__values__[self._column_name_map.invert_get(k)] = row[k]
480482
self.__profile__ = None
481483
return self
482484

src/gino/engine.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import asyncio
44
import inspect
55
import warnings
6+
import weakref
67
from contextvars import ContextVar
7-
from typing import Optional
8+
from typing import Optional, Callable, Any
89

910
from sqlalchemy import text
1011
from sqlalchemy.engine import Engine
@@ -23,6 +24,8 @@
2324
from sqlalchemy.sql import ClauseElement, WARN_LINTING
2425
from sqlalchemy.util.concurrency import greenlet_spawn
2526

27+
from .loader import Loader, LoaderResult
28+
2629

2730
async def create_engine(
2831
url, *arg, isolation_level=None, min_size=1, max_size=None, **kw
@@ -212,6 +215,18 @@ async def _execute_sa10(self, object_, multiparams, params):
212215
return await asyncio.wait_for(coro, timeout)
213216
return await coro
214217

218+
def _load_result(self, result):
219+
options = result.context.execution_options
220+
loader = options.get("loader")
221+
model = options.get("model")
222+
if loader is None and model is not None:
223+
if isinstance(model, weakref.ref):
224+
model = model()
225+
loader = Loader.get(model)
226+
if loader is not None and options.get("return_model", True):
227+
result = LoaderResult(result, loader)
228+
return result
229+
215230
async def _execute(self, object_, params_20style):
216231
if isinstance(object_, str):
217232
return await self.exec_driver_sql(object_, params_20style)
@@ -281,6 +296,7 @@ async def all(self, clause, *multiparams, **params):
281296
282297
"""
283298
result = await self._execute_sa10(clause, multiparams, params)
299+
result = self._load_result(result)
284300
return result.all()
285301

286302
async def first(self, clause, *multiparams, **params):
@@ -293,6 +309,7 @@ async def first(self, clause, *multiparams, **params):
293309
294310
"""
295311
result = await self._execute_sa10(clause, multiparams, params)
312+
result = self._load_result(result)
296313
try:
297314
return result.first()
298315
except ResourceClosedError as e:
@@ -313,6 +330,7 @@ async def one_or_none(self, clause, *multiparams, **params):
313330
314331
"""
315332
result = await self._execute_sa10(clause, multiparams, params)
333+
result = self._load_result(result)
316334
return result.one_or_none()
317335

318336
async def one(self, clause, *multiparams, **params):
@@ -328,6 +346,7 @@ async def one(self, clause, *multiparams, **params):
328346
329347
"""
330348
result = await self._execute_sa10(clause, multiparams, params)
349+
result = self._load_result(result)
331350
return result.one()
332351

333352
async def scalar(self, clause, *multiparams, **params):
@@ -352,7 +371,7 @@ async def status(self, clause, *multiparams, **params):
352371
353372
"""
354373
result = await self._execute_sa10(clause, multiparams, params)
355-
return f"SELECT {result.rowcount}", result.all()
374+
return result.context
356375

357376
class _IterateResult(StartableContext):
358377
def __init__(self, conn, *args):
@@ -633,6 +652,10 @@ async def status(self, clause, *multiparams, **params):
633652
async with self.acquire(reuse=True) as conn:
634653
return await conn.status(clause, *multiparams, **params)
635654

655+
async def run_sync(self, fn: Callable, *arg, **kw) -> Any:
656+
async with self.acquire(reuse=True) as conn:
657+
return await conn.run_sync(fn, *arg, **kw)
658+
636659
class _CompileConnection:
637660
def __init__(self, dialect):
638661
self.dialect = dialect
@@ -737,7 +760,7 @@ def repr(self, color=False):
737760
return repr(self)
738761

739762
def __repr__(self):
740-
return f'{self.__class__.__name__}<{self.sync_engine.pool.status()}>'
763+
return f"{self.__class__.__name__}<{self.sync_engine.pool.status()}>"
741764

742765

743766
class AsyncOptionEngine(OptionEngineMixin, GinoEngine):

src/gino/loader.py

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33

44
from sqlalchemy import select
5+
from sqlalchemy.engine.result import FilterResult
56
from sqlalchemy.schema import Column
67
from sqlalchemy.sql.elements import Label
78

@@ -213,12 +214,13 @@ def __init__(self, model, *columns, **extras):
213214
self.on_clause = None
214215

215216
def _do_load(self, row):
216-
values = dict((c.name, row[c]) for c in self.columns if c in row)
217+
columns = row.keys()
218+
values = dict((c.name, row[c]) for c in self.columns if c in columns)
217219
if all((v is None) for v in values.values()):
218220
return None
219221
rv = self.model()
220222
for c in self.columns:
221-
if c in row:
223+
if c in columns:
222224
# noinspection PyProtectedMember
223225
instance_key = self.model._column_name_map.invert_get(c.name)
224226
rv.__values__[instance_key] = row[c]
@@ -424,3 +426,97 @@ def do_load(self, row, context):
424426
"""
425427

426428
return self.value, True
429+
430+
431+
class LoaderResult(FilterResult):
432+
def __init__(self, result, loader):
433+
self._real_result = result
434+
self._loader = loader
435+
self._ctx = {}
436+
self._metadata = result._metadata
437+
if result._source_supports_scalars:
438+
self._metadata = self._metadata._reduce([0])
439+
440+
def _post_creational_filter(self, row):
441+
obj, distinct = self._loader.do_load(row, self._ctx)
442+
return obj
443+
444+
def fetchall(self):
445+
# type: () -> List[Mapping]
446+
"""A synonym for the :meth:`_engine.MappingResult.all` method."""
447+
448+
return self._allrows()
449+
450+
def fetchone(self):
451+
# type: () -> Mapping
452+
"""Fetch one object.
453+
454+
Equivalent to :meth:`_result.Result.fetchone` except that
455+
mapping values, rather than :class:`_result.Row` objects,
456+
are returned.
457+
458+
"""
459+
460+
row = self._onerow_getter(self)
461+
if row is _NO_ROW:
462+
return None
463+
else:
464+
return row
465+
466+
def fetchmany(self, size=None):
467+
# type: (Optional[Int]) -> List[Mapping]
468+
"""Fetch many objects.
469+
470+
Equivalent to :meth:`_result.Result.fetchmany` except that
471+
mapping values, rather than :class:`_result.Row` objects,
472+
are returned.
473+
474+
"""
475+
476+
return self._manyrow_getter(self, size)
477+
478+
def all(self):
479+
# type: () -> List[Mapping]
480+
"""Return all scalar values in a list.
481+
482+
Equivalent to :meth:`_result.Result.all` except that
483+
mapping values, rather than :class:`_result.Row` objects,
484+
are returned.
485+
486+
"""
487+
488+
return self._allrows()
489+
490+
def first(self):
491+
# type: () -> Optional[Mapping]
492+
"""Fetch the first object or None if no object is present.
493+
494+
Equivalent to :meth:`_result.Result.first` except that
495+
mapping values, rather than :class:`_result.Row` objects,
496+
are returned.
497+
498+
499+
"""
500+
return self._only_one_row(False, False, False)
501+
502+
def one_or_none(self):
503+
# type: () -> Optional[Mapping]
504+
"""Return at most one object or raise an exception.
505+
506+
Equivalent to :meth:`_result.Result.one_or_none` except that
507+
mapping values, rather than :class:`_result.Row` objects,
508+
are returned.
509+
510+
"""
511+
return self._only_one_row(True, False, False)
512+
513+
def one(self):
514+
# type: () -> Mapping
515+
"""Return exactly one object or raise an exception.
516+
517+
Equivalent to :meth:`_result.Result.one` except that
518+
mapping values, rather than :class:`_result.Row` objects,
519+
are returned.
520+
521+
"""
522+
return self._only_one_row(True, True, False)

tests/test_bind.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@ async def test_db_api(bind, random_name):
4848
await db.first(User.query.where(User.nickname == random_name))
4949
).nickname == random_name
5050
assert len(await db.all(User.query.where(User.nickname == random_name))) == 1
51-
assert (await db.status(User.delete.where(User.nickname == random_name)))[
52-
0
53-
] == "DELETE 1"
51+
ctx = await db.status(User.delete.where(User.nickname == random_name))
52+
assert ctx.rowcount == 1
5453
stmt, params = db.compile(User.query.where(User.id == 3))
5554
assert params[0] == 3
5655

tests/test_core.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88

99
async def test_engine_only():
1010
import gino
11-
from gino.schema import GinoSchemaVisitor
12-
from sqlalchemy.engine.result import RowProxy
11+
from sqlalchemy.engine import Row
1312

1413
metadata = MetaData()
1514

@@ -30,18 +29,19 @@ async def test_engine_only():
3029
)
3130

3231
engine = await gino.create_engine(PG_URL)
33-
await GinoSchemaVisitor(metadata).create_all(engine)
32+
await engine.run_sync(metadata.create_all)
3433
try:
3534
ins = users.insert().values(name="jack", fullname="Jack Jones")
3635
await engine.status(ins)
3736
res = await engine.all(users.select())
38-
assert isinstance(res[0], RowProxy)
37+
assert isinstance(res[0], Row)
3938
finally:
40-
await GinoSchemaVisitor(metadata).drop_all(engine)
39+
await engine.run_sync(metadata.drop_all)
4140

4241

4342
async def test_core():
4443
from gino import Gino
44+
from sqlalchemy.engine import Row
4545

4646
db = Gino()
4747

@@ -68,7 +68,7 @@ async def test_core():
6868
name="jack", fullname="Jack Jones",
6969
).gino.status()
7070
res = await users.select().gino.all()
71-
assert isinstance(res[0], RowProxy)
71+
assert isinstance(res[0], Row)
7272
finally:
7373
await db.gino.drop_all()
7474

tests/test_engine.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@ async def test_basic(engine):
4848
assert isinstance((await engine.all("select now()"))[0][0], datetime)
4949
assert isinstance((await engine.one("select now()"))[0], datetime)
5050
assert isinstance((await engine.one_or_none("select now()"))[0], datetime)
51-
status, result = await engine.status("select now()")
52-
assert status == "SELECT 1"
53-
assert isinstance(result[0][0], datetime)
51+
ctx = await engine.status("select now()")
52+
assert ctx.rowcount == 1
5453
with pytest.raises(ObjectNotExecutableError):
5554
await engine.all(object())
5655

0 commit comments

Comments
 (0)