Skip to content

Commit 83252b3

Browse files
terrycainjettify
authored andcommitted
Replace yield from with await... (#283)
* async/await utils.py * async/await pool.py * async/await cursors.py * async/await connection.py * async/await sa/* * Switch some tests to async def * Added async with pool.get() * Fixed up SA, though iter no longer works :/ * Flake8
1 parent e95df74 commit 83252b3

16 files changed

+1150
-1390
lines changed

aiomysql/connection.py

Lines changed: 113 additions & 150 deletions
Large diffs are not rendered by default.

aiomysql/cursors.py

Lines changed: 60 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import re
32
import warnings
43

@@ -8,7 +7,7 @@
87
NotSupportedError, ProgrammingError)
98

109
from .log import logger
11-
from .utils import PY_35, create_future
10+
from .utils import create_future
1211

1312

1413
# https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/cursors.py#L11-L18
@@ -149,14 +148,13 @@ def closed(self):
149148
"""
150149
return True if not self._connection else False
151150

152-
@asyncio.coroutine
153-
def close(self):
151+
async def close(self):
154152
"""Closing a cursor just exhausts all remaining data."""
155153
conn = self._connection
156154
if conn is None:
157155
return
158156
try:
159-
while (yield from self.nextset()):
157+
while (await self.nextset()):
160158
pass
161159
finally:
162160
self._connection = None
@@ -179,17 +177,16 @@ def setinputsizes(self, *args):
179177
def setoutputsizes(self, *args):
180178
"""Does nothing, required by DB API."""
181179

182-
@asyncio.coroutine
183-
def nextset(self):
180+
async def nextset(self):
184181
"""Get the next query set"""
185182
conn = self._get_db()
186183
current_result = self._result
187184
if current_result is None or current_result is not conn._result:
188185
return
189186
if not current_result.has_next:
190187
return
191-
yield from conn.next_result()
192-
yield from self._do_get_result()
188+
await conn.next_result()
189+
await self._do_get_result()
193190
return True
194191

195192
def _escape_args(self, args, conn):
@@ -215,8 +212,7 @@ def mogrify(self, query, args=None):
215212
query = query % self._escape_args(args, conn)
216213
return query
217214

218-
@asyncio.coroutine
219-
def execute(self, query, args=None):
215+
async def execute(self, query, args=None):
220216
"""Executes the given operation
221217
222218
Executes the given operation substituting any markers with
@@ -231,21 +227,20 @@ def execute(self, query, args=None):
231227
"""
232228
conn = self._get_db()
233229

234-
while (yield from self.nextset()):
230+
while (await self.nextset()):
235231
pass
236232

237233
if args is not None:
238234
query = query % self._escape_args(args, conn)
239235

240-
yield from self._query(query)
236+
await self._query(query)
241237
self._executed = query
242238
if self._echo:
243239
logger.info(query)
244240
logger.info("%r", args)
245241
return self._rowcount
246242

247-
@asyncio.coroutine
248-
def executemany(self, query, args):
243+
async def executemany(self, query, args):
249244
"""Execute the given operation multiple times
250245
251246
The executemany() method will execute the operation iterating
@@ -259,7 +254,7 @@ def executemany(self, query, args):
259254
('John', '555-003')
260255
]
261256
stmt = "INSERT INTO employees (name, phone) VALUES ('%s','%s')"
262-
yield from cursor.executemany(stmt, data)
257+
await cursor.executemany(stmt, data)
263258
264259
INSERT or REPLACE statements are optimized by batching the data,
265260
that is using the MySQL multiple rows syntax.
@@ -280,20 +275,19 @@ def executemany(self, query, args):
280275
q_values = m.group(2).rstrip()
281276
q_postfix = m.group(3) or ''
282277
assert q_values[0] == '(' and q_values[-1] == ')'
283-
return (yield from self._do_execute_many(
278+
return (await self._do_execute_many(
284279
q_prefix, q_values, q_postfix, args, self.max_stmt_length,
285280
self._get_db().encoding))
286281
else:
287282
rows = 0
288283
for arg in args:
289-
yield from self.execute(query, arg)
284+
await self.execute(query, arg)
290285
rows += self._rowcount
291286
self._rowcount = rows
292287
return self._rowcount
293288

294-
@asyncio.coroutine
295-
def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length,
296-
encoding):
289+
async def _do_execute_many(self, prefix, values, postfix, args,
290+
max_stmt_length, encoding):
297291
conn = self._get_db()
298292
escape = self._escape_args
299293
if isinstance(prefix, str):
@@ -312,19 +306,18 @@ def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length,
312306
if isinstance(v, str):
313307
v = v.encode(encoding, 'surrogateescape')
314308
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
315-
r = yield from self.execute(sql + postfix)
309+
r = await self.execute(sql + postfix)
316310
rows += r
317311
sql = bytearray(prefix)
318312
else:
319313
sql += b','
320314
sql += v
321-
r = yield from self.execute(sql + postfix)
315+
r = await self.execute(sql + postfix)
322316
rows += r
323317
self._rowcount = rows
324318
return rows
325319

326-
@asyncio.coroutine
327-
def callproc(self, procname, args=()):
320+
async def callproc(self, procname, args=()):
328321
"""Execute stored procedure procname with args
329322
330323
Compatibility warning: PEP-249 specifies that any modified
@@ -357,12 +350,12 @@ def callproc(self, procname, args=()):
357350

358351
for index, arg in enumerate(args):
359352
q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg))
360-
yield from self._query(q)
361-
yield from self.nextset()
353+
await self._query(q)
354+
await self.nextset()
362355

363356
_args = ','.join('@_%s_%d' % (procname, i) for i in range(len(args)))
364357
q = "CALL %s(%s)" % (procname, _args)
365-
yield from self._query(q)
358+
await self._query(q)
366359
self._executed = q
367360
return args
368361

@@ -454,15 +447,13 @@ def scroll(self, value, mode='relative'):
454447
fut.set_result(None)
455448
return fut
456449

457-
@asyncio.coroutine
458-
def _query(self, q):
450+
async def _query(self, q):
459451
conn = self._get_db()
460452
self._last_executed = q
461-
yield from conn.query(q)
462-
yield from self._do_get_result()
453+
await conn.query(q)
454+
await self._do_get_result()
463455

464-
@asyncio.coroutine
465-
def _do_get_result(self):
456+
async def _do_get_result(self):
466457
conn = self._get_db()
467458
self._rownumber = 0
468459
self._result = result = conn._result
@@ -472,13 +463,12 @@ def _do_get_result(self):
472463
self._rows = result.rows
473464

474465
if result.warning_count > 0:
475-
yield from self._show_warnings(conn)
466+
await self._show_warnings(conn)
476467

477-
@asyncio.coroutine
478-
def _show_warnings(self, conn):
468+
async def _show_warnings(self, conn):
479469
if self._result and self._result.has_next:
480470
return
481-
ws = yield from conn.show_warnings()
471+
ws = await conn.show_warnings()
482472
if ws is None:
483473
return
484474
for w in ws:
@@ -496,36 +486,30 @@ def _show_warnings(self, conn):
496486
ProgrammingError = ProgrammingError
497487
NotSupportedError = NotSupportedError
498488

499-
if PY_35: # pragma: no branch
500-
@asyncio.coroutine
501-
def __aiter__(self):
502-
return self
489+
async def __aiter__(self):
490+
return self
503491

504-
@asyncio.coroutine
505-
def __anext__(self):
506-
ret = yield from self.fetchone()
507-
if ret is not None:
508-
return ret
509-
else:
510-
raise StopAsyncIteration # noqa
492+
async def __anext__(self):
493+
ret = await self.fetchone()
494+
if ret is not None:
495+
return ret
496+
else:
497+
raise StopAsyncIteration # noqa
511498

512-
@asyncio.coroutine
513-
def __aenter__(self):
514-
return self
499+
async def __aenter__(self):
500+
return self
515501

516-
@asyncio.coroutine
517-
def __aexit__(self, exc_type, exc_val, exc_tb):
518-
yield from self.close()
519-
return
502+
async def __aexit__(self, exc_type, exc_val, exc_tb):
503+
await self.close()
504+
return
520505

521506

522507
class _DictCursorMixin:
523508
# You can override this to use OrderedDict or other dict-like types.
524509
dict_type = dict
525510

526-
@asyncio.coroutine
527-
def _do_get_result(self):
528-
yield from super()._do_get_result()
511+
async def _do_get_result(self):
512+
await super()._do_get_result()
529513
fields = []
530514
if self._description:
531515
for f in self._result.fields:
@@ -563,61 +547,55 @@ class SSCursor(Cursor):
563547
possible to scroll backwards, as only the current row is held in memory.
564548
"""
565549

566-
@asyncio.coroutine
567-
def close(self):
550+
async def close(self):
568551
conn = self._connection
569552
if conn is None:
570553
return
571554

572555
if self._result is not None and self._result is conn._result:
573-
yield from self._result._finish_unbuffered_query()
556+
await self._result._finish_unbuffered_query()
574557

575558
try:
576-
while (yield from self.nextset()):
559+
while (await self.nextset()):
577560
pass
578561
finally:
579562
self._connection = None
580563

581-
@asyncio.coroutine
582-
def _query(self, q):
564+
async def _query(self, q):
583565
conn = self._get_db()
584566
self._last_executed = q
585-
yield from conn.query(q, unbuffered=True)
586-
yield from self._do_get_result()
567+
await conn.query(q, unbuffered=True)
568+
await self._do_get_result()
587569
return self._rowcount
588570

589-
@asyncio.coroutine
590-
def _read_next(self):
571+
async def _read_next(self):
591572
"""Read next row """
592-
row = yield from self._result._read_rowdata_packet_unbuffered()
573+
row = await self._result._read_rowdata_packet_unbuffered()
593574
row = self._conv_row(row)
594575
return row
595576

596-
@asyncio.coroutine
597-
def fetchone(self):
577+
async def fetchone(self):
598578
""" Fetch next row """
599579
self._check_executed()
600-
row = yield from self._read_next()
580+
row = await self._read_next()
601581
if row is None:
602582
return
603583
self._rownumber += 1
604584
return row
605585

606-
@asyncio.coroutine
607-
def fetchall(self):
586+
async def fetchall(self):
608587
"""Fetch all, as per MySQLdb. Pretty useless for large queries, as
609588
it is buffered.
610589
"""
611590
rows = []
612591
while True:
613-
row = yield from self.fetchone()
592+
row = await self.fetchone()
614593
if row is None:
615594
break
616595
rows.append(row)
617596
return rows
618597

619-
@asyncio.coroutine
620-
def fetchmany(self, size=None):
598+
async def fetchmany(self, size=None):
621599
"""Returns the next set of rows of a query result, returning a
622600
list of tuples. When no more rows are available, it returns an
623601
empty list.
@@ -634,15 +612,14 @@ def fetchmany(self, size=None):
634612

635613
rows = []
636614
for i in range(size):
637-
row = yield from self._read_next()
615+
row = await self._read_next()
638616
if row is None:
639617
break
640618
rows.append(row)
641619
self._rownumber += 1
642620
return rows
643621

644-
@asyncio.coroutine
645-
def scroll(self, value, mode='relative'):
622+
async def scroll(self, value, mode='relative'):
646623
"""Scroll the cursor in the result set to a new position
647624
according to mode . Same as :meth:`Cursor.scroll`, but move cursor
648625
on server side one by one row. If you want to move 20 rows forward
@@ -661,7 +638,7 @@ def scroll(self, value, mode='relative'):
661638
"by this cursor")
662639

663640
for _ in range(value):
664-
yield from self._read_next()
641+
await self._read_next()
665642
self._rownumber += value
666643
elif mode == 'absolute':
667644
if value < self._rownumber:
@@ -670,7 +647,7 @@ def scroll(self, value, mode='relative'):
670647

671648
end = value - self._rownumber
672649
for _ in range(end):
673-
yield from self._read_next()
650+
await self._read_next()
674651
self._rownumber = value
675652
else:
676653
raise ProgrammingError("unknown scroll mode %s" % mode)

0 commit comments

Comments
 (0)