Skip to content

Commit 100baab

Browse files
committed
refetch after insert and update
MySQL doesn't support returning after insert and update
1 parent 3cf684b commit 100baab

File tree

6 files changed

+97
-76
lines changed

6 files changed

+97
-76
lines changed

mysql_tests/test_bind.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,9 @@ async def test_unbind(aiomysql_pool):
4040

4141

4242
async def test_db_api(bind, random_name):
43-
lastrowid = await db.all(
44-
User.insert().values(name=random_name).execution_options(
45-
return_lastrowid=True)
46-
)
47-
r = await db.scalar(User.select('nickname').where(User.id == lastrowid))
43+
result = await db.first(User.insert().values(name=random_name))
44+
assert result is None
45+
r = await db.scalar(User.select('nickname').where(User.nickname == random_name))
4846
assert r == random_name
4947
assert (
5048
await db.first(User.query.where(User.nickname == random_name))

mysql_tests/test_json.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class PropsTest(db.Model):
130130
assert t.obj["x"] == 1
131131
assert t.arr[-1] == 6
132132
assert await db.select(
133-
[PropsTest.profile, PropsTest.raw, PropsTest.bool,]
133+
[PropsTest.profile, PropsTest.raw, PropsTest.bool]
134134
).gino.first() == (
135135
{
136136
"arr": [3, 4, 5, 6],

src/gino/crud.py

Lines changed: 61 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -789,49 +789,69 @@ async def _query_and_update(bind, item, query, cols, execution_opts):
789789
if bind._dialect.support_returning:
790790
# noinspection PyArgumentList
791791
query = query.returning(*cols)
792-
row = await bind.first(query)
792+
793+
async def _execute_and_fetch(conn, query):
794+
context, row = await conn._first_with_context(query)
795+
if not bind._dialect.support_returning:
796+
if context.isinsert:
797+
table = context.compiled.statement.table
798+
key_getter = context.compiled._key_getters_for_crud_column[2]
799+
compiled_params = context.compiled_parameters[0]
800+
last_row_id = context.get_lastrowid()
801+
if last_row_id is not None:
802+
lookup_conds = [
803+
c == last_row_id
804+
if c is table._autoincrement_column
805+
else c == _cast_json(
806+
c, compiled_params.get(key_getter(c), None))
807+
for c in table.primary_key
808+
]
809+
else:
810+
lookup_conds = [
811+
c == _cast_json(
812+
c, compiled_params.get(key_getter(c), None))
813+
for c in table.columns
814+
]
815+
query = sa.select(table.columns).where(
816+
sa.and_(*lookup_conds)).execution_options(**execution_opts)
817+
row = await conn.first(query)
818+
elif context.isupdate:
819+
if context.get_affected_rows() == 0:
820+
raise NoSuchRowError()
821+
table = context.compiled.statement.table
822+
if len(table.primary_key) > 0:
823+
lookup_conds = [
824+
c == _cast_json(
825+
c, item.__values__[
826+
item._column_name_map.invert_get(c.name)])
827+
for c in table.primary_key
828+
]
829+
else:
830+
lookup_conds = [
831+
c == _cast_json(
832+
c, item.__values__[
833+
item._column_name_map.invert_get(c.name)])
834+
for c in table.columns
835+
]
836+
query = sa.select(table.columns).where(
837+
sa.and_(*lookup_conds)).execution_options(**execution_opts)
838+
row = await conn.first(query)
839+
return row
840+
841+
if isinstance(bind, GinoConnection):
842+
row = await _execute_and_fetch(bind, query)
793843
else:
794-
# CAVEAT: MySQL doesn't support RETURNING. The workaround here is
795-
# to get lastrowid and load it after insertion.
796-
# Note that this only works for tables with AUTO_INCREMENT column
797-
# For update queries, update using its primary key
798-
799-
# make insertion and select in one transaction to get the might-be
800-
# "dirty" row
801-
release_conn = False
802-
if not isinstance(bind, GinoConnection):
803-
conn = await bind.acquire(reuse=True)
804-
release_conn = True
805-
else:
806-
conn = bind
807-
try:
808-
lastrowid, affected_rows = await conn.all(
809-
query.execution_options(return_affected_rows=True)
810-
)
811-
if not lastrowid and not affected_rows:
812-
raise NoSuchRowError()
813-
# It's insertion and primary key is AUTO_INCREMENT
814-
if lastrowid:
815-
pkey = cls.__table__.primary_key
816-
query = (
817-
sa.select(cols)
818-
.where(pkey.columns.values()[0] == lastrowid)
819-
.execution_options(**execution_opts)
820-
)
821-
else:
822-
try:
823-
query = (
824-
sa.select(cols)
825-
.where(item.lookup())
826-
.execution_options(**execution_opts)
827-
)
828-
except LookupError: # no primary key
829-
return None
830-
row = await conn.first(query)
831-
finally:
832-
if release_conn:
833-
await conn.release()
844+
async with bind.acquire(reuse=True) as conn:
845+
row = await _execute_and_fetch(conn, query)
834846
if not row:
835847
raise NoSuchRowError()
836848
for k, v in row.items():
837849
item.__values__[item._column_name_map.invert_get(k)] = v
850+
851+
852+
def _cast_json(column, value):
853+
# FIXME: for MySQL, json string in WHERE clause needs to be cast to JSON type
854+
if (isinstance(column.type, sa.JSON) or
855+
isinstance(getattr(column.type, 'impl', None), sa.JSON)):
856+
return sa.cast(value, sa.JSON)
857+
return value

src/gino/dialects/aiomysql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ class AiomysqlDBAPI(base.BaseDBAPI):
4848
# noinspection PyAbstractClass
4949
class AiomysqlExecutionContext(base.ExecutionContextOverride, MySQLExecutionContext):
5050
def get_lastrowid(self):
51-
return self.cursor.last_row_id
51+
lastrowid = self.cursor.last_row_id
52+
return None if lastrowid == 0 else lastrowid
5253

5354
def get_affected_rows(self):
5455
return self.cursor.affected_rows

src/gino/dialects/base.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ def __init__(self, context):
196196
def context(self):
197197
return self._context
198198

199-
async def execute(self, one=False, return_model=True, status=False):
199+
async def execute(self, one=False, return_model=True, status=False,
200+
return_context=False):
200201
context = self._context
201202

202203
param_groups = []
@@ -213,31 +214,26 @@ async def execute(self, one=False, return_model=True, status=False):
213214
return await cursor.async_execute(
214215
context.statement, context.timeout, param_groups, many=True
215216
)
217+
args = param_groups[0]
218+
if context.baked_query:
219+
rows = await cursor.execute_baked(
220+
context.baked_query, context.timeout, args, one
221+
)
216222
else:
217-
args = param_groups[0]
218-
if context.baked_query:
219-
rows = await cursor.execute_baked(
220-
context.baked_query, context.timeout, args, one
221-
)
223+
rows = await cursor.async_execute(
224+
context.statement, context.timeout, args, 1 if one else 0
225+
)
226+
item = context.process_rows(rows, return_model=return_model)
227+
if one:
228+
if item:
229+
item = item[0]
222230
else:
223-
rows = await cursor.async_execute(
224-
context.statement, context.timeout, args, 1 if one else 0
225-
)
226-
if not self.context.dialect.support_returning and (
227-
self.context.isinsert or self.context.isupdate
228-
):
229-
if self.context.execution_options.get("return_affected_rows", False):
230-
return context.get_lastrowid(), context.get_affected_rows()
231-
return context.get_lastrowid()
232-
item = context.process_rows(rows, return_model=return_model)
233-
if one:
234-
if item:
235-
item = item[0]
236-
else:
237-
item = None
238-
if status:
239-
item = cursor.get_statusmsg(), item
240-
return item
231+
item = None
232+
if status:
233+
return cursor.get_statusmsg(), item
234+
if return_context:
235+
return context, item
236+
return item
241237

242238
def iterate(self):
243239
if self._context.executemany:
@@ -296,10 +292,6 @@ def timeout(self):
296292
def loader(self):
297293
return self._compiled_first_opt("loader", None)
298294

299-
@util.memoized_property
300-
def return_affected_rows(self):
301-
return self._compiled_first_opt("return_affected_rows", False)
302-
303295
def process_rows(self, rows, return_model=True):
304296
if not rows:
305297
return []
@@ -417,6 +409,12 @@ def _init_baked_query(cls, dialect, connection, dbapi_connection, bq, parameters
417409
self.baked_query = bq
418410
return self
419411

412+
def get_lastrowid(self):
413+
raise NotImplementedError
414+
415+
def get_affected_rows(self):
416+
raise NotImplementedError
417+
420418

421419
class AsyncDialectMixin:
422420
cursor_cls = DBAPICursor

src/gino/engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,10 @@ async def first(self, clause, *multiparams, **params):
348348
result = self._execute(clause, multiparams, params)
349349
return await result.execute(one=True)
350350

351+
async def _first_with_context(self, clause, *multiparams, **params):
352+
result = self._execute(clause, multiparams, params)
353+
return await result.execute(one=True, return_context=True)
354+
351355
async def one_or_none(self, clause, *multiparams, **params):
352356
"""
353357
Runs the given query in database, returns at most one result.

0 commit comments

Comments
 (0)