@@ -789,49 +789,69 @@ async def _query_and_update(bind, item, query, cols, execution_opts):
789
789
if bind ._dialect .support_returning :
790
790
# noinspection PyArgumentList
791
791
query = query .returning (* cols )
792
- row = await bind .first (query )
792
+
793
+ async def _execute_and_fetch (conn , query ):
794
+ context , row = await conn ._first_with_context (query )
795
+ if not bind ._dialect .support_returning :
796
+ if context .isinsert :
797
+ table = context .compiled .statement .table
798
+ key_getter = context .compiled ._key_getters_for_crud_column [2 ]
799
+ compiled_params = context .compiled_parameters [0 ]
800
+ last_row_id = context .get_lastrowid ()
801
+ if last_row_id is not None :
802
+ lookup_conds = [
803
+ c == last_row_id
804
+ if c is table ._autoincrement_column
805
+ else c == _cast_json (
806
+ c , compiled_params .get (key_getter (c ), None ))
807
+ for c in table .primary_key
808
+ ]
809
+ else :
810
+ lookup_conds = [
811
+ c == _cast_json (
812
+ c , compiled_params .get (key_getter (c ), None ))
813
+ for c in table .columns
814
+ ]
815
+ query = sa .select (table .columns ).where (
816
+ sa .and_ (* lookup_conds )).execution_options (** execution_opts )
817
+ row = await conn .first (query )
818
+ elif context .isupdate :
819
+ if context .get_affected_rows () == 0 :
820
+ raise NoSuchRowError ()
821
+ table = context .compiled .statement .table
822
+ if len (table .primary_key ) > 0 :
823
+ lookup_conds = [
824
+ c == _cast_json (
825
+ c , item .__values__ [
826
+ item ._column_name_map .invert_get (c .name )])
827
+ for c in table .primary_key
828
+ ]
829
+ else :
830
+ lookup_conds = [
831
+ c == _cast_json (
832
+ c , item .__values__ [
833
+ item ._column_name_map .invert_get (c .name )])
834
+ for c in table .columns
835
+ ]
836
+ query = sa .select (table .columns ).where (
837
+ sa .and_ (* lookup_conds )).execution_options (** execution_opts )
838
+ row = await conn .first (query )
839
+ return row
840
+
841
+ if isinstance (bind , GinoConnection ):
842
+ row = await _execute_and_fetch (bind , query )
793
843
else :
794
- # CAVEAT: MySQL doesn't support RETURNING. The workaround here is
795
- # to get lastrowid and load it after insertion.
796
- # Note that this only works for tables with AUTO_INCREMENT column
797
- # For update queries, update using its primary key
798
-
799
- # make insertion and select in one transaction to get the might-be
800
- # "dirty" row
801
- release_conn = False
802
- if not isinstance (bind , GinoConnection ):
803
- conn = await bind .acquire (reuse = True )
804
- release_conn = True
805
- else :
806
- conn = bind
807
- try :
808
- lastrowid , affected_rows = await conn .all (
809
- query .execution_options (return_affected_rows = True )
810
- )
811
- if not lastrowid and not affected_rows :
812
- raise NoSuchRowError ()
813
- # It's insertion and primary key is AUTO_INCREMENT
814
- if lastrowid :
815
- pkey = cls .__table__ .primary_key
816
- query = (
817
- sa .select (cols )
818
- .where (pkey .columns .values ()[0 ] == lastrowid )
819
- .execution_options (** execution_opts )
820
- )
821
- else :
822
- try :
823
- query = (
824
- sa .select (cols )
825
- .where (item .lookup ())
826
- .execution_options (** execution_opts )
827
- )
828
- except LookupError : # no primary key
829
- return None
830
- row = await conn .first (query )
831
- finally :
832
- if release_conn :
833
- await conn .release ()
844
+ async with bind .acquire (reuse = True ) as conn :
845
+ row = await _execute_and_fetch (conn , query )
834
846
if not row :
835
847
raise NoSuchRowError ()
836
848
for k , v in row .items ():
837
849
item .__values__ [item ._column_name_map .invert_get (k )] = v
850
+
851
+
852
+ def _cast_json (column , value ):
853
+ # FIXME: for MySQL, json string in WHERE clause needs to be cast to JSON type
854
+ if (isinstance (column .type , sa .JSON ) or
855
+ isinstance (getattr (column .type , 'impl' , None ), sa .JSON )):
856
+ return sa .cast (value , sa .JSON )
857
+ return value
0 commit comments