5
5
from sqlalchemy .sql import ClauseElement
6
6
7
7
from . import json_support
8
- from .declarative import Model
8
+ from .declarative import Model , InvertDict
9
9
from .exceptions import NoSuchRowError
10
10
from .loader import AliasLoader , ModelLoader
11
11
@@ -78,7 +78,7 @@ class UpdateRequest:
78
78
specific model instance and its database row.
79
79
80
80
"""
81
- def __init__ (self , instance ):
81
+ def __init__ (self , instance : 'CRUDModel' ):
82
82
self ._instance = instance
83
83
self ._values = {}
84
84
self ._props = {}
@@ -124,7 +124,7 @@ async def apply(self, bind=None, timeout=DEFAULT):
124
124
json_updates = {}
125
125
for prop , value in self ._props .items ():
126
126
value = prop .save (self ._instance , value )
127
- updates = json_updates .setdefault (prop .column_name , {})
127
+ updates = json_updates .setdefault (prop .prop_name , {})
128
128
if self ._literal :
129
129
updates [prop .name ] = value
130
130
else :
@@ -133,26 +133,26 @@ async def apply(self, bind=None, timeout=DEFAULT):
133
133
elif not isinstance (value , ClauseElement ):
134
134
value = sa .cast (value , sa .Unicode )
135
135
updates [sa .cast (prop .name , sa .Unicode )] = value
136
- for column_name , updates in json_updates .items ():
137
- column = getattr (cls , column_name )
136
+ for prop_name , updates in json_updates .items ():
137
+ prop = getattr (cls , prop_name )
138
138
from .dialects .asyncpg import JSONB
139
- if isinstance (column .type , JSONB ):
139
+ if isinstance (prop .type , JSONB ):
140
140
if self ._literal :
141
- values [column_name ] = column .concat (updates )
141
+ values [prop_name ] = prop .concat (updates )
142
142
else :
143
- values [column_name ] = column .concat (
143
+ values [prop_name ] = prop .concat (
144
144
sa .func .jsonb_build_object (
145
145
* itertools .chain (* updates .items ())))
146
146
else :
147
- raise TypeError ('{} is not supported.' .format (column .type ))
147
+ raise TypeError ('{} is not supported.' .format (prop .type ))
148
148
149
149
opts = dict (return_model = False )
150
150
if timeout is not DEFAULT :
151
151
opts ['timeout' ] = timeout
152
152
clause = type (self ._instance ).update .where (
153
153
self ._locator ,
154
154
).values (
155
- ** values ,
155
+ ** self . _instance . _get_sa_values ( values ) ,
156
156
).returning (
157
157
* [getattr (cls , key ) for key in values ],
158
158
).execution_options (** opts )
@@ -161,7 +161,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
161
161
row = await bind .first (clause )
162
162
if not row :
163
163
raise NoSuchRowError ()
164
- self ._instance .__values__ .update (row )
164
+ for k , v in row .items ():
165
+ self ._instance .__values__ [
166
+ self ._instance ._column_name_map .invert_get (k )] = v
165
167
for prop in self ._props :
166
168
prop .reload (self ._instance )
167
169
return self
@@ -409,6 +411,7 @@ class CRUDModel(Model):
409
411
"""
410
412
411
413
_update_request_cls = UpdateRequest
414
+ _column_name_map = InvertDict ()
412
415
413
416
def __init__ (self , ** values ):
414
417
super ().__init__ ()
@@ -421,10 +424,10 @@ def _init_table(cls, sub_cls):
421
424
for each_cls in sub_cls .__mro__ [::- 1 ]:
422
425
for k , v in each_cls .__dict__ .items ():
423
426
if isinstance (v , json_support .JSONProperty ):
424
- if not hasattr (sub_cls , v .column_name ):
427
+ if not hasattr (sub_cls , v .prop_name ):
425
428
raise AttributeError (
426
429
'Requires "{}" JSON[B] column.' .format (
427
- v .column_name ))
430
+ v .prop_name ))
428
431
v .name = k
429
432
if rv is not None :
430
433
rv .__model__ = weakref .ref (sub_cls )
@@ -440,12 +443,12 @@ async def _create(self, bind=None, timeout=DEFAULT):
440
443
cls = type (self )
441
444
# noinspection PyUnresolvedReferences,PyProtectedMember
442
445
cls ._check_abstract ()
443
- keys = set (self .__profile__ .keys () if self .__profile__ else [])
444
- for key in keys :
446
+ profile_keys = set (self .__profile__ .keys () if self .__profile__ else [])
447
+ for key in profile_keys :
445
448
cls .__dict__ .get (key ).save (self )
446
449
# initialize default values
447
450
for key , prop in cls .__dict__ .items ():
448
- if key in keys :
451
+ if key in profile_keys :
449
452
continue
450
453
if isinstance (prop , json_support .JSONProperty ):
451
454
if prop .default is None or prop .after_get .method is not None :
@@ -458,15 +461,25 @@ async def _create(self, bind=None, timeout=DEFAULT):
458
461
if timeout is not DEFAULT :
459
462
opts ['timeout' ] = timeout
460
463
# noinspection PyArgumentList
461
- q = cls .__table__ .insert ().values (** self .__values__ ).returning (
462
- * cls ).execution_options (** opts )
464
+ q = cls .__table__ .insert ().values (
465
+ ** self ._get_sa_values (self .__values__ )
466
+ ).returning (
467
+ * cls
468
+ ).execution_options (** opts )
463
469
if bind is None :
464
470
bind = cls .__metadata__ .bind
465
471
row = await bind .first (q )
466
- self .__values__ .update (row )
472
+ for k , v in row .items ():
473
+ self .__values__ [self ._column_name_map .invert_get (k )] = v
467
474
self .__profile__ = None
468
475
return self
469
476
477
+ def _get_sa_values (self , instance_values : dict ) -> dict :
478
+ values = {}
479
+ for k , v in instance_values .items ():
480
+ values [self ._column_name_map [k ]] = v
481
+ return values
482
+
470
483
@classmethod
471
484
async def get (cls , ident , bind = None , timeout = DEFAULT ):
472
485
"""
@@ -592,11 +605,12 @@ def to_dict(self):
592
605
593
606
"""
594
607
cls = type (self )
595
- keys = set (c .name for c in cls )
608
+ # noinspection PyTypeChecker
609
+ keys = set (cls ._column_name_map .invert_get (c .name ) for c in cls )
596
610
for key , prop in cls .__dict__ .items ():
597
611
if isinstance (prop , json_support .JSONProperty ):
598
612
keys .add (key )
599
- keys .discard (prop .column_name )
613
+ keys .discard (prop .prop_name )
600
614
return dict ((k , getattr (self , k )) for k in keys )
601
615
602
616
@classmethod
0 commit comments