Skip to content

Commit f2dccd1

Browse files
committed
fixed #261, map model and db column names
1 parent a05c82a commit f2dccd1

12 files changed

+123
-57
lines changed

gino/crud.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sqlalchemy.sql import ClauseElement
66

77
from . import json_support
8-
from .declarative import Model
8+
from .declarative import Model, InvertDict
99
from .exceptions import NoSuchRowError
1010
from .loader import AliasLoader, ModelLoader
1111

@@ -78,7 +78,7 @@ class UpdateRequest:
7878
specific model instance and its database row.
7979
8080
"""
81-
def __init__(self, instance):
81+
def __init__(self, instance: 'CRUDModel'):
8282
self._instance = instance
8383
self._values = {}
8484
self._props = {}
@@ -124,7 +124,7 @@ async def apply(self, bind=None, timeout=DEFAULT):
124124
json_updates = {}
125125
for prop, value in self._props.items():
126126
value = prop.save(self._instance, value)
127-
updates = json_updates.setdefault(prop.column_name, {})
127+
updates = json_updates.setdefault(prop.prop_name, {})
128128
if self._literal:
129129
updates[prop.name] = value
130130
else:
@@ -133,26 +133,26 @@ async def apply(self, bind=None, timeout=DEFAULT):
133133
elif not isinstance(value, ClauseElement):
134134
value = sa.cast(value, sa.Unicode)
135135
updates[sa.cast(prop.name, sa.Unicode)] = value
136-
for column_name, updates in json_updates.items():
137-
column = getattr(cls, column_name)
136+
for prop_name, updates in json_updates.items():
137+
prop = getattr(cls, prop_name)
138138
from .dialects.asyncpg import JSONB
139-
if isinstance(column.type, JSONB):
139+
if isinstance(prop.type, JSONB):
140140
if self._literal:
141-
values[column_name] = column.concat(updates)
141+
values[prop_name] = prop.concat(updates)
142142
else:
143-
values[column_name] = column.concat(
143+
values[prop_name] = prop.concat(
144144
sa.func.jsonb_build_object(
145145
*itertools.chain(*updates.items())))
146146
else:
147-
raise TypeError('{} is not supported.'.format(column.type))
147+
raise TypeError('{} is not supported.'.format(prop.type))
148148

149149
opts = dict(return_model=False)
150150
if timeout is not DEFAULT:
151151
opts['timeout'] = timeout
152152
clause = type(self._instance).update.where(
153153
self._locator,
154154
).values(
155-
**values,
155+
**self._instance._get_sa_values(values),
156156
).returning(
157157
*[getattr(cls, key) for key in values],
158158
).execution_options(**opts)
@@ -161,7 +161,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
161161
row = await bind.first(clause)
162162
if not row:
163163
raise NoSuchRowError()
164-
self._instance.__values__.update(row)
164+
for k, v in row.items():
165+
self._instance.__values__[
166+
self._instance._column_name_map.invert_get(k)] = v
165167
for prop in self._props:
166168
prop.reload(self._instance)
167169
return self
@@ -409,6 +411,7 @@ class CRUDModel(Model):
409411
"""
410412

411413
_update_request_cls = UpdateRequest
414+
_column_name_map = InvertDict()
412415

413416
def __init__(self, **values):
414417
super().__init__()
@@ -421,10 +424,10 @@ def _init_table(cls, sub_cls):
421424
for each_cls in sub_cls.__mro__[::-1]:
422425
for k, v in each_cls.__dict__.items():
423426
if isinstance(v, json_support.JSONProperty):
424-
if not hasattr(sub_cls, v.column_name):
427+
if not hasattr(sub_cls, v.prop_name):
425428
raise AttributeError(
426429
'Requires "{}" JSON[B] column.'.format(
427-
v.column_name))
430+
v.prop_name))
428431
v.name = k
429432
if rv is not None:
430433
rv.__model__ = weakref.ref(sub_cls)
@@ -440,12 +443,12 @@ async def _create(self, bind=None, timeout=DEFAULT):
440443
cls = type(self)
441444
# noinspection PyUnresolvedReferences,PyProtectedMember
442445
cls._check_abstract()
443-
keys = set(self.__profile__.keys() if self.__profile__ else [])
444-
for key in keys:
446+
profile_keys = set(self.__profile__.keys() if self.__profile__ else [])
447+
for key in profile_keys:
445448
cls.__dict__.get(key).save(self)
446449
# initialize default values
447450
for key, prop in cls.__dict__.items():
448-
if key in keys:
451+
if key in profile_keys:
449452
continue
450453
if isinstance(prop, json_support.JSONProperty):
451454
if prop.default is None or prop.after_get.method is not None:
@@ -458,15 +461,25 @@ async def _create(self, bind=None, timeout=DEFAULT):
458461
if timeout is not DEFAULT:
459462
opts['timeout'] = timeout
460463
# noinspection PyArgumentList
461-
q = cls.__table__.insert().values(**self.__values__).returning(
462-
*cls).execution_options(**opts)
464+
q = cls.__table__.insert().values(
465+
**self._get_sa_values(self.__values__)
466+
).returning(
467+
*cls
468+
).execution_options(**opts)
463469
if bind is None:
464470
bind = cls.__metadata__.bind
465471
row = await bind.first(q)
466-
self.__values__.update(row)
472+
for k, v in row.items():
473+
self.__values__[self._column_name_map.invert_get(k)] = v
467474
self.__profile__ = None
468475
return self
469476

477+
def _get_sa_values(self, instance_values: dict) -> dict:
478+
values = {}
479+
for k, v in instance_values.items():
480+
values[self._column_name_map[k]] = v
481+
return values
482+
470483
@classmethod
471484
async def get(cls, ident, bind=None, timeout=DEFAULT):
472485
"""
@@ -592,11 +605,12 @@ def to_dict(self):
592605
593606
"""
594607
cls = type(self)
595-
keys = set(c.name for c in cls)
608+
# noinspection PyTypeChecker
609+
keys = set(cls._column_name_map.invert_get(c.name) for c in cls)
596610
for key, prop in cls.__dict__.items():
597611
if isinstance(prop, json_support.JSONProperty):
598612
keys.add(key)
599-
keys.discard(prop.column_name)
613+
keys.discard(prop.prop_name)
600614
return dict((k, getattr(self, k)) for k in keys)
601615

602616
@classmethod

gino/declarative.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import sqlalchemy as sa
44
from sqlalchemy.exc import InvalidRequestError
55

6+
from .exceptions import GinoException
7+
68

79
class ColumnAttribute:
8-
def __init__(self, column):
9-
self.name = column.name
10+
def __init__(self, name, column):
11+
self.name = name
1012
self.column = column
1113

1214
def __get__(self, instance, owner):
@@ -22,6 +24,29 @@ def __delete__(self, instance):
2224
raise AttributeError('Cannot delete value.')
2325

2426

27+
class InvertDict(dict):
28+
def __init__(self, *args, **kwargs):
29+
super().__init__(*args, **kwargs)
30+
self._inverted_dict = dict()
31+
for k, v in self.items():
32+
if v in self._inverted_dict:
33+
raise GinoException(
34+
'Column name {} already maps to {}'.format(
35+
v, self._inverted_dict[v]))
36+
self._inverted_dict[v] = k
37+
38+
def __setitem__(self, key, value):
39+
if value in self._inverted_dict:
40+
raise GinoException(
41+
'Column name {} already maps to {}'.format(
42+
value, self._inverted_dict[value]))
43+
super().__setitem__(key, value)
44+
self._inverted_dict[value] = key
45+
46+
def invert_get(self, key, default=None):
47+
return self._inverted_dict.get(key, default)
48+
49+
2550
class ModelType(type):
2651
def _check_abstract(self):
2752
if self.__table__ is None:
@@ -119,18 +144,22 @@ def _init_table(cls, sub_cls):
119144
columns = []
120145
inspected_args = []
121146
updates = {}
147+
column_name_map = InvertDict()
122148
for each_cls in sub_cls.__mro__[::-1]:
123149
for k, v in getattr(each_cls, '__namespace__',
124150
each_cls.__dict__).items():
125151
if callable(v) and getattr(v, '__declared_attr__', False):
126152
v = updates[k] = v(sub_cls)
127153
if isinstance(v, sa.Column):
128154
v = v.copy()
129-
v.name = k
155+
if not v.name:
156+
v.name = k
157+
column_name_map[k] = v.name
130158
columns.append(v)
131-
updates[k] = sub_cls.__attr_factory__(v)
159+
updates[k] = sub_cls.__attr_factory__(k, v)
132160
elif isinstance(v, (sa.Index, sa.Constraint)):
133161
inspected_args.append(v)
162+
sub_cls._column_name_map = column_name_map
134163

135164
# handle __table_args__
136165
table_args = updates.get('__table_args__',
@@ -173,4 +202,5 @@ def inspect_model_type(target):
173202
return sa.inspection.inspect(target.__table__)
174203

175204

176-
__all__ = ['ColumnAttribute', 'Model', 'declarative_base', 'declared_attr']
205+
__all__ = ['ColumnAttribute', 'Model', 'declarative_base', 'declared_attr',
206+
'InvertDict']

gino/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from contextvars import ContextVar
1616
else:
1717
# noinspection PyPackageRequirements
18-
from aiocontextvars import ContextVar
18+
from aiocontextvars import ContextVar # pragma: no cover
1919

2020

2121
class _BaseDBAPIConnection:

gino/json_support.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,18 @@ def call(self, instance, val):
2222

2323

2424
class JSONProperty:
25-
def __init__(self, default=None, column_name='profile'):
25+
def __init__(self, default=None, prop_name='profile'):
2626
self.name = None
2727
self.default = default
28-
self.column_name = column_name
28+
self.prop_name = prop_name
2929
self.expression = Hook(self)
3030
self.after_get = Hook(self)
3131
self.before_set = Hook(self)
3232

3333
def __get__(self, instance, owner):
3434
if instance is None:
3535
exp = self.make_expression(
36-
getattr(owner, self.column_name)[self.name])
36+
getattr(owner, self.prop_name)[self.name])
3737
return self.expression.call(owner, exp)
3838
val = self.get_profile(instance).get(self.name, NONE)
3939
if val is NONE:
@@ -54,16 +54,16 @@ def get_profile(self, instance):
5454
if instance.__profile__ is None:
5555
props = type(instance).__dict__
5656
instance.__profile__ = {}
57-
for key, value in (getattr(instance, self.column_name, None)
57+
for key, value in (getattr(instance, self.prop_name, None)
5858
or {}).items():
5959
instance.__profile__[key] = props[key].decode(value)
6060
return instance.__profile__
6161

6262
def save(self, instance, value=NONE):
63-
profile = getattr(instance, self.column_name, None)
63+
profile = getattr(instance, self.prop_name, None)
6464
if profile is None:
6565
profile = {}
66-
setattr(instance, self.column_name, profile)
66+
setattr(instance, self.prop_name, profile)
6767
if value is NONE:
6868
value = instance.__profile__[self.name]
6969
if not isinstance(value, sa.sql.ClauseElement):
@@ -74,7 +74,7 @@ def save(self, instance, value=NONE):
7474
def reload(self, instance):
7575
if instance.__profile__ is None:
7676
return
77-
profile = getattr(instance, self.column_name, None) or {}
77+
profile = getattr(instance, self.prop_name, None) or {}
7878
value = profile.get(self.name, NONE)
7979
if value is NONE:
8080
instance.__profile__.pop(self.name, None)

gino/loader.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,11 @@ def _do_load(self, row, *, none_as_none=None):
7575
if none_as_none and all((v is None) for v in values.values()):
7676
return None
7777
rv = self.model()
78-
rv.__values__.update(values)
78+
for c in self.columns:
79+
if c in row:
80+
# noinspection PyProtectedMember
81+
instance_key = self.model._column_name_map.invert_get(c.name)
82+
rv.__values__[instance_key] = row[c]
7983
return rv
8084

8185
def do_load(self, row, context):

tests/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ class User(db.Model):
3131
__tablename__ = 'gino_users'
3232

3333
id = db.Column(db.BigInteger(), primary_key=True)
34-
nickname = db.Column(db.Unicode(), default='noname')
35-
profile = db.Column(JSONB(), nullable=False, server_default='{}')
34+
nickname = db.Column('name', db.Unicode(), default='noname')
35+
profile = db.Column('props', JSONB(), nullable=False, server_default='{}')
3636
type = db.Column(
3737
db.Enum(UserType),
3838
nullable=False,
3939
default=UserType.USER,
4040
)
41-
name = db.StringProperty()
41+
realname = db.StringProperty()
4242
age = db.IntegerProperty(default=18)
4343
balance = db.IntegerProperty(default=0)
4444
birthday = db.DateTimeProperty(

tests/test_bind.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async def test_unbind(asyncpg_pool):
3636

3737
async def test_db_api(bind, random_name):
3838
assert await db.scalar(
39-
User.insert().values(nickname=random_name).returning(
39+
User.insert().values(name=random_name).returning(
4040
User.nickname)) == random_name
4141
assert (await db.first(User.query.where(
4242
User.nickname == random_name))).nickname == random_name

tests/test_declarative.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
import gino
3+
from gino.declarative import InvertDict
34
from asyncpg.exceptions import (
45
UniqueViolationError, ForeignKeyViolationError, CheckViolationError)
56

@@ -54,7 +55,7 @@ class Model(db.Model):
5455

5556

5657
async def test_inline_constraints_and_indexes(bind, engine):
57-
u = await User.create(name='test')
58+
u = await User.create(nickname='test')
5859
us1 = await UserSetting.create(user_id=u.id, setting='skin', value='blue')
5960

6061
# PrimaryKeyConstraint
@@ -222,3 +223,20 @@ class AbstractModel(db.Model):
222223

223224
with pytest.raises(TypeError, match='AbstractModel is abstract'):
224225
await AbstractModel.get(1)
226+
227+
228+
async def test_invert_dict():
229+
with pytest.raises(gino.GinoException,
230+
match=r'Column name c1 already maps to \w+'):
231+
InvertDict({'col1': 'c1', 'col2': 'c1'})
232+
with pytest.raises(gino.GinoException,
233+
match=r'Column name c1 already maps to \w+'):
234+
d = InvertDict()
235+
d['col1'] = 'c1'
236+
d['col2'] = 'c1'
237+
238+
d = InvertDict()
239+
d['col1'] = 'c1'
240+
d['col2'] = 'c2'
241+
assert d.invert_get('c1') == 'col1'
242+
assert d.invert_get('c2') == 'col2'

tests/test_executemany.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88
# noinspection PyUnusedLocal
99
async def test_status(bind):
1010
statement, params = db.compile(User.insert(),
11-
[dict(nickname='1'), dict(nickname='2')])
12-
assert statement == ('INSERT INTO gino_users (nickname, type) '
11+
[dict(name='1'), dict(name='2')])
12+
assert statement == ('INSERT INTO gino_users (name, type) '
1313
'VALUES ($1, $2)')
1414
assert params == (('1', 'USER'), ('2', 'USER'))
15-
result = await User.insert().gino.status(dict(nickname='1'),
16-
dict(nickname='2'))
15+
result = await User.insert().gino.status(dict(name='1'), dict(name='2'))
1716
assert result is None
1817
assert len(await User.query.gino.all()) == 2
1918

tests/test_iterate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
def names(sa_engine):
1010
rv = {'11', '22', '33'}
1111
sa_engine.execute(User.__table__.insert(),
12-
[dict(nickname=name) for name in rv])
12+
[dict(name=name) for name in rv])
1313
yield rv
1414
sa_engine.execute('DELETE FROM gino_users')
1515

0 commit comments

Comments
 (0)