Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gino/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from sqlalchemy import select
from sqlalchemy.schema import Column
from sqlalchemy.sql.elements import Label

from .declarative import Model

Expand All @@ -19,6 +20,8 @@ def get(cls, value):
rv = AliasLoader(value)
elif isinstance(value, Column):
rv = ColumnLoader(value)
elif isinstance(value, Label):
rv = ColumnLoader(value.name)
elif isinstance(value, tuple):
rv = TupleLoader(value)
elif callable(value):
Expand Down
51 changes: 47 additions & 4 deletions tests/test_loader.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import random
from datetime import datetime

import pytest
from async_generator import yield_, async_generator
import pytest
from sqlalchemy import select
from sqlalchemy.sql.functions import count

from gino.loader import AliasLoader
from gino.loader import AliasLoader, ColumnLoader
from .models import db, User, Team, Company

pytestmark = pytest.mark.asyncio


@pytest.fixture
@async_generator
async def user(bind, random_name):
async def user(bind):
c = await Company.create()
t1 = await Team.create(company_id=c.id)
t2 = await Team.create(company_id=c.id, parent_id=t1.id)
u = await User.create(nickname=random_name, team_id=t2.id)
u = await User.create(team_id=t2.id)
u.team = t2
t2.parent = t1
t2.company = c
Expand Down Expand Up @@ -161,6 +163,47 @@ async def test_alias_loader_columns(user):
assert u.id is not None


async def test_multiple_models_in_one_query(bind):
for _ in range(3):
await User.create()

ua1 = User.alias()
ua2 = User.alias()
join_query = select([ua1, ua2]).where(ua1.id < ua2.id)
result = await join_query.gino.load((ua1.load('id'), ua2.load('id'))).all()
assert len(result) == 3
for u1, u2 in result:
assert u1.id is not None
assert u2.id is not None
assert u1.id < u2.id


async def test_loader_with_aggregation(user):
count_col = count().label('count')
user_count = select(
[User.team_id, count_col]
).group_by(
User.team_id
).alias()
query = Team.outerjoin(user_count).select()
result = await query.gino.load(
(Team.id, Team.name, user_count.columns.team_id, count_col)
).all()
assert len(result) == 2
# team 1 doesn't have users, team 2 has 1 user
# third and forth columns are None for team 1
for team_id, team_name, user_team_id, user_count in result:
if team_id == user.team_id:
assert team_name == user.team.name
assert user_team_id == user.team_id
assert user_count == 1
else:
assert team_id is not None
assert team_name is not None
assert user_team_id is None
assert user_count is None


async def test_adjacency_list_query_builder(user):
group = Team.alias()
u = await User.load(team=Team.load(parent=group.on(
Expand Down