Skip to content

Commit c78a340

Browse files
authored
Merge pull request #323 from jekel/columns-loader
ModelLoader Columns objects support in column loader
2 parents ca283e0 + 03c9f90 commit c78a340

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

gino/loader.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self, model, *column_names, **extras):
5656
self.model = model
5757
self._distinct = None
5858
if column_names:
59-
self.columns = [getattr(model, name) for name in column_names]
59+
self.columns = self._column_loader(model, column_names)
6060
else:
6161
self.columns = model
6262
self.extras = dict((key, self.get(value))
@@ -123,11 +123,28 @@ def get_from(self):
123123

124124
def load(self, *column_names, **extras):
125125
if column_names:
126-
self.columns = [getattr(self.model, name) for name in column_names]
126+
self.columns = self._column_loader(self.model, column_names)
127+
127128
self.extras.update((key, self.get(value))
128129
for key, value in extras.items())
129130
return self
130131

132+
@classmethod
133+
def _column_loader(cls, model, column_names):
134+
def column_formatter(column_name):
135+
if isinstance(column_name, str):
136+
return getattr(model, column_name)
137+
elif isinstance(column_name, Column):
138+
if column_name not in model:
139+
raise AttributeError('Column {} does not belong '
140+
'to this model'.format(column_name))
141+
return column_name
142+
else:
143+
raise TypeError('Unknown column name {} type {}'.
144+
format(column_name, type(column_name)))
145+
146+
return [column_formatter(column_name) for column_name in column_names]
147+
131148
def on(self, on_clause):
132149
self.on_clause = on_clause
133150
return self

tests/test_loader.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,17 @@ async def test_scalar(user):
4343

4444

4545
async def test_model_load(user):
46-
u = await User.query.gino.load(User.load('nickname')).first()
46+
u = await User.query.gino.load(User.load('nickname', User.team_id)).first()
4747
assert isinstance(u, User)
4848
assert u.id is None
4949
assert u.nickname == user.nickname
50+
assert u.team_id == user.team.id
51+
52+
with pytest.raises(TypeError):
53+
await User.query.gino.load(User.load(123)).first()
54+
55+
with pytest.raises(AttributeError):
56+
await User.query.gino.load(User.load(Team.id)).first()
5057

5158

5259
async def test_216_model_load_passive_partial(user):

0 commit comments

Comments
 (0)