5
5
from sqlalchemy .sql import ClauseElement
6
6
7
7
from . import json_support
8
- from .declarative import Model
8
+ from .declarative import Model , InvertDict
9
9
from .exceptions import NoSuchRowError
10
10
from .loader import AliasLoader , ModelLoader
11
11
@@ -78,7 +78,7 @@ class UpdateRequest:
78
78
specific model instance and its database row.
79
79
80
80
"""
81
- def __init__ (self , instance ):
81
+ def __init__ (self , instance : 'CRUDModel' ):
82
82
self ._instance = instance
83
83
self ._values = {}
84
84
self ._props = {}
@@ -88,7 +88,7 @@ def __init__(self, instance):
88
88
try :
89
89
self ._locator = instance .lookup ()
90
90
except LookupError :
91
- # apply() will fail anyway, but still allow updates ()
91
+ # apply() will fail anyway, but still allow update ()
92
92
pass
93
93
94
94
def _set (self , key , value ):
@@ -124,7 +124,7 @@ async def apply(self, bind=None, timeout=DEFAULT):
124
124
json_updates = {}
125
125
for prop , value in self ._props .items ():
126
126
value = prop .save (self ._instance , value )
127
- updates = json_updates .setdefault (prop .column_name , {})
127
+ updates = json_updates .setdefault (prop .prop_name , {})
128
128
if self ._literal :
129
129
updates [prop .name ] = value
130
130
else :
@@ -133,26 +133,28 @@ async def apply(self, bind=None, timeout=DEFAULT):
133
133
elif not isinstance (value , ClauseElement ):
134
134
value = sa .cast (value , sa .Unicode )
135
135
updates [sa .cast (prop .name , sa .Unicode )] = value
136
- for column_name , updates in json_updates .items ():
137
- column = getattr (cls , column_name )
136
+ for prop_name , updates in json_updates .items ():
137
+ prop = getattr (cls , prop_name )
138
138
from .dialects .asyncpg import JSONB
139
- if isinstance (column .type , JSONB ):
139
+ if isinstance (prop .type , JSONB ):
140
140
if self ._literal :
141
- values [column_name ] = column .concat (updates )
141
+ values [prop_name ] = prop .concat (updates )
142
142
else :
143
- values [column_name ] = column .concat (
143
+ values [prop_name ] = prop .concat (
144
144
sa .func .jsonb_build_object (
145
145
* itertools .chain (* updates .items ())))
146
146
else :
147
- raise TypeError ('{} is not supported.' .format (column .type ))
147
+ raise TypeError ('{} is not supported to update json '
148
+ 'properties in Gino. Please consider using '
149
+ 'JSONB.' .format (prop .type ))
148
150
149
151
opts = dict (return_model = False )
150
152
if timeout is not DEFAULT :
151
153
opts ['timeout' ] = timeout
152
154
clause = type (self ._instance ).update .where (
153
155
self ._locator ,
154
156
).values (
155
- ** values ,
157
+ ** self . _instance . _get_sa_values ( values ) ,
156
158
).returning (
157
159
* [getattr (cls , key ) for key in values ],
158
160
).execution_options (** opts )
@@ -161,7 +163,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
161
163
row = await bind .first (clause )
162
164
if not row :
163
165
raise NoSuchRowError ()
164
- self ._instance .__values__ .update (row )
166
+ for k , v in row .items ():
167
+ self ._instance .__values__ [
168
+ self ._instance ._column_name_map .invert_get (k )] = v
165
169
for prop in self ._props :
166
170
prop .reload (self ._instance )
167
171
return self
@@ -409,6 +413,7 @@ class CRUDModel(Model):
409
413
"""
410
414
411
415
_update_request_cls = UpdateRequest
416
+ _column_name_map = InvertDict ()
412
417
413
418
def __init__ (self , ** values ):
414
419
super ().__init__ ()
@@ -421,10 +426,10 @@ def _init_table(cls, sub_cls):
421
426
for each_cls in sub_cls .__mro__ [::- 1 ]:
422
427
for k , v in each_cls .__dict__ .items ():
423
428
if isinstance (v , json_support .JSONProperty ):
424
- if not hasattr (sub_cls , v .column_name ):
429
+ if not hasattr (sub_cls , v .prop_name ):
425
430
raise AttributeError (
426
431
'Requires "{}" JSON[B] column.' .format (
427
- v .column_name ))
432
+ v .prop_name ))
428
433
v .name = k
429
434
if rv is not None :
430
435
rv .__model__ = weakref .ref (sub_cls )
@@ -440,12 +445,12 @@ async def _create(self, bind=None, timeout=DEFAULT):
440
445
cls = type (self )
441
446
# noinspection PyUnresolvedReferences,PyProtectedMember
442
447
cls ._check_abstract ()
443
- keys = set (self .__profile__ .keys () if self .__profile__ else [])
444
- for key in keys :
448
+ profile_keys = set (self .__profile__ .keys () if self .__profile__ else [])
449
+ for key in profile_keys :
445
450
cls .__dict__ .get (key ).save (self )
446
451
# initialize default values
447
452
for key , prop in cls .__dict__ .items ():
448
- if key in keys :
453
+ if key in profile_keys :
449
454
continue
450
455
if isinstance (prop , json_support .JSONProperty ):
451
456
if prop .default is None or prop .after_get .method is not None :
@@ -458,15 +463,25 @@ async def _create(self, bind=None, timeout=DEFAULT):
458
463
if timeout is not DEFAULT :
459
464
opts ['timeout' ] = timeout
460
465
# noinspection PyArgumentList
461
- q = cls .__table__ .insert ().values (** self .__values__ ).returning (
462
- * cls ).execution_options (** opts )
466
+ q = cls .__table__ .insert ().values (
467
+ ** self ._get_sa_values (self .__values__ )
468
+ ).returning (
469
+ * cls
470
+ ).execution_options (** opts )
463
471
if bind is None :
464
472
bind = cls .__metadata__ .bind
465
473
row = await bind .first (q )
466
- self .__values__ .update (row )
474
+ for k , v in row .items ():
475
+ self .__values__ [self ._column_name_map .invert_get (k )] = v
467
476
self .__profile__ = None
468
477
return self
469
478
479
+ def _get_sa_values (self , instance_values : dict ) -> dict :
480
+ values = {}
481
+ for k , v in instance_values .items ():
482
+ values [self ._column_name_map [k ]] = v
483
+ return values
484
+
470
485
@classmethod
471
486
async def get (cls , ident , bind = None , timeout = DEFAULT ):
472
487
"""
@@ -592,11 +607,12 @@ def to_dict(self):
592
607
593
608
"""
594
609
cls = type (self )
595
- keys = set (c .name for c in cls )
610
+ # noinspection PyTypeChecker
611
+ keys = set (cls ._column_name_map .invert_get (c .name ) for c in cls )
596
612
for key , prop in cls .__dict__ .items ():
597
613
if isinstance (prop , json_support .JSONProperty ):
598
614
keys .add (key )
599
- keys .discard (prop .column_name )
615
+ keys .discard (prop .prop_name )
600
616
return dict ((k , getattr (self , k )) for k in keys )
601
617
602
618
@classmethod
0 commit comments