Skip to content

Commit 4588e84

Browse files
authored
Merge pull request #297 from wwwjfy/t261
fixed #261, map model and db column names
2 parents 3b11fd8 + 31a3878 commit 4588e84

14 files changed

+171
-73
lines changed

docs/schema.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,19 @@ more declarative. Instead of ``users.c.name``, you can now access the column by
197197
available at ``User.__table__`` and ``Address.__table__``. You can use anything
198198
that works in GINO core here.
199199

200+
.. note::
201+
202+
Column names can be different as a class property and database column.
203+
For example, name can be declared as
204+
``nickname = db.Column('name', db.Unicode(), default='noname')``. In this
205+
example, ``User.nickname`` is used to access the column, while in database,
206+
the column name is ``name``.
207+
208+
What's worth mentioning is where raw SQL statements are used, or
209+
``TableClause`` is involved, like ``User.insert()``, the original name is
210+
required to be used, because in this case, GINO has no knowledge about the
211+
mappings.
212+
200213
.. tip::
201214

202215
``db.Model`` is a dynamically created parent class for your models. It is

gino/crud.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sqlalchemy.sql import ClauseElement
66

77
from . import json_support
8-
from .declarative import Model
8+
from .declarative import Model, InvertDict
99
from .exceptions import NoSuchRowError
1010
from .loader import AliasLoader, ModelLoader
1111

@@ -78,7 +78,7 @@ class UpdateRequest:
7878
specific model instance and its database row.
7979
8080
"""
81-
def __init__(self, instance):
81+
def __init__(self, instance: 'CRUDModel'):
8282
self._instance = instance
8383
self._values = {}
8484
self._props = {}
@@ -88,7 +88,7 @@ def __init__(self, instance):
8888
try:
8989
self._locator = instance.lookup()
9090
except LookupError:
91-
# apply() will fail anyway, but still allow updates()
91+
# apply() will fail anyway, but still allow update()
9292
pass
9393

9494
def _set(self, key, value):
@@ -124,7 +124,7 @@ async def apply(self, bind=None, timeout=DEFAULT):
124124
json_updates = {}
125125
for prop, value in self._props.items():
126126
value = prop.save(self._instance, value)
127-
updates = json_updates.setdefault(prop.column_name, {})
127+
updates = json_updates.setdefault(prop.prop_name, {})
128128
if self._literal:
129129
updates[prop.name] = value
130130
else:
@@ -133,26 +133,28 @@ async def apply(self, bind=None, timeout=DEFAULT):
133133
elif not isinstance(value, ClauseElement):
134134
value = sa.cast(value, sa.Unicode)
135135
updates[sa.cast(prop.name, sa.Unicode)] = value
136-
for column_name, updates in json_updates.items():
137-
column = getattr(cls, column_name)
136+
for prop_name, updates in json_updates.items():
137+
prop = getattr(cls, prop_name)
138138
from .dialects.asyncpg import JSONB
139-
if isinstance(column.type, JSONB):
139+
if isinstance(prop.type, JSONB):
140140
if self._literal:
141-
values[column_name] = column.concat(updates)
141+
values[prop_name] = prop.concat(updates)
142142
else:
143-
values[column_name] = column.concat(
143+
values[prop_name] = prop.concat(
144144
sa.func.jsonb_build_object(
145145
*itertools.chain(*updates.items())))
146146
else:
147-
raise TypeError('{} is not supported.'.format(column.type))
147+
raise TypeError('{} is not supported to update json '
148+
'properties in Gino. Please consider using '
149+
'JSONB.'.format(prop.type))
148150

149151
opts = dict(return_model=False)
150152
if timeout is not DEFAULT:
151153
opts['timeout'] = timeout
152154
clause = type(self._instance).update.where(
153155
self._locator,
154156
).values(
155-
**values,
157+
**self._instance._get_sa_values(values),
156158
).returning(
157159
*[getattr(cls, key) for key in values],
158160
).execution_options(**opts)
@@ -161,7 +163,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
161163
row = await bind.first(clause)
162164
if not row:
163165
raise NoSuchRowError()
164-
self._instance.__values__.update(row)
166+
for k, v in row.items():
167+
self._instance.__values__[
168+
self._instance._column_name_map.invert_get(k)] = v
165169
for prop in self._props:
166170
prop.reload(self._instance)
167171
return self
@@ -409,6 +413,7 @@ class CRUDModel(Model):
409413
"""
410414

411415
_update_request_cls = UpdateRequest
416+
_column_name_map = InvertDict()
412417

413418
def __init__(self, **values):
414419
super().__init__()
@@ -421,10 +426,10 @@ def _init_table(cls, sub_cls):
421426
for each_cls in sub_cls.__mro__[::-1]:
422427
for k, v in each_cls.__dict__.items():
423428
if isinstance(v, json_support.JSONProperty):
424-
if not hasattr(sub_cls, v.column_name):
429+
if not hasattr(sub_cls, v.prop_name):
425430
raise AttributeError(
426431
'Requires "{}" JSON[B] column.'.format(
427-
v.column_name))
432+
v.prop_name))
428433
v.name = k
429434
if rv is not None:
430435
rv.__model__ = weakref.ref(sub_cls)
@@ -440,12 +445,12 @@ async def _create(self, bind=None, timeout=DEFAULT):
440445
cls = type(self)
441446
# noinspection PyUnresolvedReferences,PyProtectedMember
442447
cls._check_abstract()
443-
keys = set(self.__profile__.keys() if self.__profile__ else [])
444-
for key in keys:
448+
profile_keys = set(self.__profile__.keys() if self.__profile__ else [])
449+
for key in profile_keys:
445450
cls.__dict__.get(key).save(self)
446451
# initialize default values
447452
for key, prop in cls.__dict__.items():
448-
if key in keys:
453+
if key in profile_keys:
449454
continue
450455
if isinstance(prop, json_support.JSONProperty):
451456
if prop.default is None or prop.after_get.method is not None:
@@ -458,15 +463,25 @@ async def _create(self, bind=None, timeout=DEFAULT):
458463
if timeout is not DEFAULT:
459464
opts['timeout'] = timeout
460465
# noinspection PyArgumentList
461-
q = cls.__table__.insert().values(**self.__values__).returning(
462-
*cls).execution_options(**opts)
466+
q = cls.__table__.insert().values(
467+
**self._get_sa_values(self.__values__)
468+
).returning(
469+
*cls
470+
).execution_options(**opts)
463471
if bind is None:
464472
bind = cls.__metadata__.bind
465473
row = await bind.first(q)
466-
self.__values__.update(row)
474+
for k, v in row.items():
475+
self.__values__[self._column_name_map.invert_get(k)] = v
467476
self.__profile__ = None
468477
return self
469478

479+
def _get_sa_values(self, instance_values: dict) -> dict:
480+
values = {}
481+
for k, v in instance_values.items():
482+
values[self._column_name_map[k]] = v
483+
return values
484+
470485
@classmethod
471486
async def get(cls, ident, bind=None, timeout=DEFAULT):
472487
"""
@@ -592,11 +607,12 @@ def to_dict(self):
592607
593608
"""
594609
cls = type(self)
595-
keys = set(c.name for c in cls)
610+
# noinspection PyTypeChecker
611+
keys = set(cls._column_name_map.invert_get(c.name) for c in cls)
596612
for key, prop in cls.__dict__.items():
597613
if isinstance(prop, json_support.JSONProperty):
598614
keys.add(key)
599-
keys.discard(prop.column_name)
615+
keys.discard(prop.prop_name)
600616
return dict((k, getattr(self, k)) for k in keys)
601617

602618
@classmethod

gino/declarative.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,50 @@
33
import sqlalchemy as sa
44
from sqlalchemy.exc import InvalidRequestError
55

6+
from .exceptions import GinoException
7+
68

79
class ColumnAttribute:
8-
def __init__(self, column):
9-
self.name = column.name
10+
def __init__(self, prop_name, column):
11+
self.prop_name = prop_name
1012
self.column = column
1113

1214
def __get__(self, instance, owner):
1315
if instance is None:
1416
return self.column
1517
else:
16-
return instance.__values__.get(self.name)
18+
return instance.__values__.get(self.prop_name)
1719

1820
def __set__(self, instance, value):
19-
instance.__values__[self.name] = value
21+
instance.__values__[self.prop_name] = value
2022

2123
def __delete__(self, instance):
2224
raise AttributeError('Cannot delete value.')
2325

2426

27+
class InvertDict(dict):
28+
def __init__(self, *args, **kwargs):
29+
super().__init__(*args, **kwargs)
30+
self._inverted_dict = dict()
31+
for k, v in self.items():
32+
if v in self._inverted_dict:
33+
raise GinoException(
34+
'Column name {} already maps to {}'.format(
35+
v, self._inverted_dict[v]))
36+
self._inverted_dict[v] = k
37+
38+
def __setitem__(self, key, value):
39+
if value in self._inverted_dict and self._inverted_dict[value] != key:
40+
raise GinoException(
41+
'Column name {} already maps to {}'.format(
42+
value, self._inverted_dict[value]))
43+
super().__setitem__(key, value)
44+
self._inverted_dict[value] = key
45+
46+
def invert_get(self, key, default=None):
47+
return self._inverted_dict.get(key, default)
48+
49+
2550
class ModelType(type):
2651
def _check_abstract(self):
2752
if self.__table__ is None:
@@ -119,18 +144,22 @@ def _init_table(cls, sub_cls):
119144
columns = []
120145
inspected_args = []
121146
updates = {}
147+
column_name_map = InvertDict()
122148
for each_cls in sub_cls.__mro__[::-1]:
123149
for k, v in getattr(each_cls, '__namespace__',
124150
each_cls.__dict__).items():
125151
if callable(v) and getattr(v, '__declared_attr__', False):
126152
v = updates[k] = v(sub_cls)
127153
if isinstance(v, sa.Column):
128154
v = v.copy()
129-
v.name = k
155+
if not v.name:
156+
v.name = k
157+
column_name_map[k] = v.name
130158
columns.append(v)
131-
updates[k] = sub_cls.__attr_factory__(v)
159+
updates[k] = sub_cls.__attr_factory__(k, v)
132160
elif isinstance(v, (sa.Index, sa.Constraint)):
133161
inspected_args.append(v)
162+
sub_cls._column_name_map = column_name_map
134163

135164
# handle __table_args__
136165
table_args = updates.get('__table_args__',
@@ -173,4 +202,5 @@ def inspect_model_type(target):
173202
return sa.inspection.inspect(target.__table__)
174203

175204

176-
__all__ = ['ColumnAttribute', 'Model', 'declarative_base', 'declared_attr']
205+
__all__ = ['ColumnAttribute', 'Model', 'declarative_base', 'declared_attr',
206+
'InvertDict']

gino/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from contextvars import ContextVar
1616
else:
1717
# noinspection PyPackageRequirements
18-
from aiocontextvars import ContextVar
18+
from aiocontextvars import ContextVar # pragma: no cover
1919

2020

2121
class _BaseDBAPIConnection:

gino/json_support.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,18 @@ def call(self, instance, val):
2222

2323

2424
class JSONProperty:
25-
def __init__(self, default=None, column_name='profile'):
25+
def __init__(self, default=None, prop_name='profile'):
2626
self.name = None
2727
self.default = default
28-
self.column_name = column_name
28+
self.prop_name = prop_name
2929
self.expression = Hook(self)
3030
self.after_get = Hook(self)
3131
self.before_set = Hook(self)
3232

3333
def __get__(self, instance, owner):
3434
if instance is None:
3535
exp = self.make_expression(
36-
getattr(owner, self.column_name)[self.name])
36+
getattr(owner, self.prop_name)[self.name])
3737
return self.expression.call(owner, exp)
3838
val = self.get_profile(instance).get(self.name, NONE)
3939
if val is NONE:
@@ -54,16 +54,16 @@ def get_profile(self, instance):
5454
if instance.__profile__ is None:
5555
props = type(instance).__dict__
5656
instance.__profile__ = {}
57-
for key, value in (getattr(instance, self.column_name, None)
57+
for key, value in (getattr(instance, self.prop_name, None)
5858
or {}).items():
5959
instance.__profile__[key] = props[key].decode(value)
6060
return instance.__profile__
6161

6262
def save(self, instance, value=NONE):
63-
profile = getattr(instance, self.column_name, None)
63+
profile = getattr(instance, self.prop_name, None)
6464
if profile is None:
6565
profile = {}
66-
setattr(instance, self.column_name, profile)
66+
setattr(instance, self.prop_name, profile)
6767
if value is NONE:
6868
value = instance.__profile__[self.name]
6969
if not isinstance(value, sa.sql.ClauseElement):
@@ -74,7 +74,7 @@ def save(self, instance, value=NONE):
7474
def reload(self, instance):
7575
if instance.__profile__ is None:
7676
return
77-
profile = getattr(instance, self.column_name, None) or {}
77+
profile = getattr(instance, self.prop_name, None) or {}
7878
value = profile.get(self.name, NONE)
7979
if value is NONE:
8080
instance.__profile__.pop(self.name, None)

gino/loader.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,11 @@ def _do_load(self, row, *, none_as_none=None):
7575
if none_as_none and all((v is None) for v in values.values()):
7676
return None
7777
rv = self.model()
78-
rv.__values__.update(values)
78+
for c in self.columns:
79+
if c in row:
80+
# noinspection PyProtectedMember
81+
instance_key = self.model._column_name_map.invert_get(c.name)
82+
rv.__values__[instance_key] = row[c]
7983
return rv
8084

8185
def do_load(self, row, context):

tests/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@ class User(db.Model):
3434
__tablename__ = 'gino_users'
3535

3636
id = db.Column(db.BigInteger(), primary_key=True)
37-
nickname = db.Column(db.Unicode(), default='noname')
38-
profile = db.Column(JSONB(), nullable=False, server_default='{}')
37+
nickname = db.Column('name', db.Unicode(), default='noname')
38+
profile = db.Column('props', JSONB(), nullable=False, server_default='{}')
3939
type = db.Column(
4040
db.Enum(UserType),
4141
nullable=False,
4242
default=UserType.USER,
4343
)
44-
name = db.StringProperty()
44+
realname = db.StringProperty()
4545
age = db.IntegerProperty(default=18)
4646
balance = db.IntegerProperty(default=0)
4747
birthday = db.DateTimeProperty(

tests/test_aiohttp.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class User(db.Model):
2222
__tablename__ = 'gino_users'
2323

2424
id = db.Column(db.BigInteger(), primary_key=True)
25-
nickname = db.Column(db.Unicode(), default='noname')
25+
nickname = db.Column('name', db.Unicode(), default='noname')
2626

2727
routes = web.RouteTableDef()
2828

@@ -107,8 +107,7 @@ async def _test(test_client):
107107
response = await test_client.get('/users/1?method=' + method)
108108
assert response.status == 404
109109

110-
response = await test_client.post('/users',
111-
data=dict(name='fantix'))
110+
response = await test_client.post('/users', data=dict(name='fantix'))
112111
assert response.status == 200
113112
assert await response.json() == dict(id=1, nickname='fantix')
114113

tests/test_bind.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ async def test_unbind(asyncpg_pool):
4040

4141
async def test_db_api(bind, random_name):
4242
assert await db.scalar(
43-
User.insert().values(nickname=random_name).returning(
43+
User.insert().values(name=random_name).returning(
4444
User.nickname)) == random_name
4545
assert (await db.first(User.query.where(
4646
User.nickname == random_name))).nickname == random_name

0 commit comments

Comments
 (0)