1
- import asyncio
2
1
import re
3
2
import warnings
4
3
8
7
NotSupportedError , ProgrammingError )
9
8
10
9
from .log import logger
11
- from .utils import PY_35 , create_future
10
+ from .utils import create_future
12
11
13
12
14
13
# https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/cursors.py#L11-L18
@@ -149,14 +148,13 @@ def closed(self):
149
148
"""
150
149
return True if not self ._connection else False
151
150
152
- @asyncio .coroutine
153
- def close (self ):
151
+ async def close (self ):
154
152
"""Closing a cursor just exhausts all remaining data."""
155
153
conn = self ._connection
156
154
if conn is None :
157
155
return
158
156
try :
159
- while (yield from self .nextset ()):
157
+ while (await self .nextset ()):
160
158
pass
161
159
finally :
162
160
self ._connection = None
@@ -179,17 +177,16 @@ def setinputsizes(self, *args):
179
177
def setoutputsizes (self , * args ):
180
178
"""Does nothing, required by DB API."""
181
179
182
- @asyncio .coroutine
183
- def nextset (self ):
180
+ async def nextset (self ):
184
181
"""Get the next query set"""
185
182
conn = self ._get_db ()
186
183
current_result = self ._result
187
184
if current_result is None or current_result is not conn ._result :
188
185
return
189
186
if not current_result .has_next :
190
187
return
191
- yield from conn .next_result ()
192
- yield from self ._do_get_result ()
188
+ await conn .next_result ()
189
+ await self ._do_get_result ()
193
190
return True
194
191
195
192
def _escape_args (self , args , conn ):
@@ -215,8 +212,7 @@ def mogrify(self, query, args=None):
215
212
query = query % self ._escape_args (args , conn )
216
213
return query
217
214
218
- @asyncio .coroutine
219
- def execute (self , query , args = None ):
215
+ async def execute (self , query , args = None ):
220
216
"""Executes the given operation
221
217
222
218
Executes the given operation substituting any markers with
@@ -231,21 +227,20 @@ def execute(self, query, args=None):
231
227
"""
232
228
conn = self ._get_db ()
233
229
234
- while (yield from self .nextset ()):
230
+ while (await self .nextset ()):
235
231
pass
236
232
237
233
if args is not None :
238
234
query = query % self ._escape_args (args , conn )
239
235
240
- yield from self ._query (query )
236
+ await self ._query (query )
241
237
self ._executed = query
242
238
if self ._echo :
243
239
logger .info (query )
244
240
logger .info ("%r" , args )
245
241
return self ._rowcount
246
242
247
- @asyncio .coroutine
248
- def executemany (self , query , args ):
243
+ async def executemany (self , query , args ):
249
244
"""Execute the given operation multiple times
250
245
251
246
The executemany() method will execute the operation iterating
@@ -259,7 +254,7 @@ def executemany(self, query, args):
259
254
('John', '555-003')
260
255
]
261
256
stmt = "INSERT INTO employees (name, phone) VALUES ('%s','%s')"
262
- yield from cursor.executemany(stmt, data)
257
+ await cursor.executemany(stmt, data)
263
258
264
259
INSERT or REPLACE statements are optimized by batching the data,
265
260
that is using the MySQL multiple rows syntax.
@@ -280,20 +275,19 @@ def executemany(self, query, args):
280
275
q_values = m .group (2 ).rstrip ()
281
276
q_postfix = m .group (3 ) or ''
282
277
assert q_values [0 ] == '(' and q_values [- 1 ] == ')'
283
- return (yield from self ._do_execute_many (
278
+ return (await self ._do_execute_many (
284
279
q_prefix , q_values , q_postfix , args , self .max_stmt_length ,
285
280
self ._get_db ().encoding ))
286
281
else :
287
282
rows = 0
288
283
for arg in args :
289
- yield from self .execute (query , arg )
284
+ await self .execute (query , arg )
290
285
rows += self ._rowcount
291
286
self ._rowcount = rows
292
287
return self ._rowcount
293
288
294
- @asyncio .coroutine
295
- def _do_execute_many (self , prefix , values , postfix , args , max_stmt_length ,
296
- encoding ):
289
+ async def _do_execute_many (self , prefix , values , postfix , args ,
290
+ max_stmt_length , encoding ):
297
291
conn = self ._get_db ()
298
292
escape = self ._escape_args
299
293
if isinstance (prefix , str ):
@@ -312,19 +306,18 @@ def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length,
312
306
if isinstance (v , str ):
313
307
v = v .encode (encoding , 'surrogateescape' )
314
308
if len (sql ) + len (v ) + len (postfix ) + 1 > max_stmt_length :
315
- r = yield from self .execute (sql + postfix )
309
+ r = await self .execute (sql + postfix )
316
310
rows += r
317
311
sql = bytearray (prefix )
318
312
else :
319
313
sql += b','
320
314
sql += v
321
- r = yield from self .execute (sql + postfix )
315
+ r = await self .execute (sql + postfix )
322
316
rows += r
323
317
self ._rowcount = rows
324
318
return rows
325
319
326
- @asyncio .coroutine
327
- def callproc (self , procname , args = ()):
320
+ async def callproc (self , procname , args = ()):
328
321
"""Execute stored procedure procname with args
329
322
330
323
Compatibility warning: PEP-249 specifies that any modified
@@ -357,12 +350,12 @@ def callproc(self, procname, args=()):
357
350
358
351
for index , arg in enumerate (args ):
359
352
q = "SET @_%s_%d=%s" % (procname , index , conn .escape (arg ))
360
- yield from self ._query (q )
361
- yield from self .nextset ()
353
+ await self ._query (q )
354
+ await self .nextset ()
362
355
363
356
_args = ',' .join ('@_%s_%d' % (procname , i ) for i in range (len (args )))
364
357
q = "CALL %s(%s)" % (procname , _args )
365
- yield from self ._query (q )
358
+ await self ._query (q )
366
359
self ._executed = q
367
360
return args
368
361
@@ -454,15 +447,13 @@ def scroll(self, value, mode='relative'):
454
447
fut .set_result (None )
455
448
return fut
456
449
457
- @asyncio .coroutine
458
- def _query (self , q ):
450
+ async def _query (self , q ):
459
451
conn = self ._get_db ()
460
452
self ._last_executed = q
461
- yield from conn .query (q )
462
- yield from self ._do_get_result ()
453
+ await conn .query (q )
454
+ await self ._do_get_result ()
463
455
464
- @asyncio .coroutine
465
- def _do_get_result (self ):
456
+ async def _do_get_result (self ):
466
457
conn = self ._get_db ()
467
458
self ._rownumber = 0
468
459
self ._result = result = conn ._result
@@ -472,13 +463,12 @@ def _do_get_result(self):
472
463
self ._rows = result .rows
473
464
474
465
if result .warning_count > 0 :
475
- yield from self ._show_warnings (conn )
466
+ await self ._show_warnings (conn )
476
467
477
- @asyncio .coroutine
478
- def _show_warnings (self , conn ):
468
+ async def _show_warnings (self , conn ):
479
469
if self ._result and self ._result .has_next :
480
470
return
481
- ws = yield from conn .show_warnings ()
471
+ ws = await conn .show_warnings ()
482
472
if ws is None :
483
473
return
484
474
for w in ws :
@@ -496,36 +486,30 @@ def _show_warnings(self, conn):
496
486
ProgrammingError = ProgrammingError
497
487
NotSupportedError = NotSupportedError
498
488
499
- if PY_35 : # pragma: no branch
500
- @asyncio .coroutine
501
- def __aiter__ (self ):
502
- return self
489
+ async def __aiter__ (self ):
490
+ return self
503
491
504
- @asyncio .coroutine
505
- def __anext__ (self ):
506
- ret = yield from self .fetchone ()
507
- if ret is not None :
508
- return ret
509
- else :
510
- raise StopAsyncIteration # noqa
492
+ async def __anext__ (self ):
493
+ ret = await self .fetchone ()
494
+ if ret is not None :
495
+ return ret
496
+ else :
497
+ raise StopAsyncIteration # noqa
511
498
512
- @asyncio .coroutine
513
- def __aenter__ (self ):
514
- return self
499
+ async def __aenter__ (self ):
500
+ return self
515
501
516
- @asyncio .coroutine
517
- def __aexit__ (self , exc_type , exc_val , exc_tb ):
518
- yield from self .close ()
519
- return
502
+ async def __aexit__ (self , exc_type , exc_val , exc_tb ):
503
+ await self .close ()
504
+ return
520
505
521
506
522
507
class _DictCursorMixin :
523
508
# You can override this to use OrderedDict or other dict-like types.
524
509
dict_type = dict
525
510
526
- @asyncio .coroutine
527
- def _do_get_result (self ):
528
- yield from super ()._do_get_result ()
511
+ async def _do_get_result (self ):
512
+ await super ()._do_get_result ()
529
513
fields = []
530
514
if self ._description :
531
515
for f in self ._result .fields :
@@ -563,61 +547,55 @@ class SSCursor(Cursor):
563
547
possible to scroll backwards, as only the current row is held in memory.
564
548
"""
565
549
566
- @asyncio .coroutine
567
- def close (self ):
550
+ async def close (self ):
568
551
conn = self ._connection
569
552
if conn is None :
570
553
return
571
554
572
555
if self ._result is not None and self ._result is conn ._result :
573
- yield from self ._result ._finish_unbuffered_query ()
556
+ await self ._result ._finish_unbuffered_query ()
574
557
575
558
try :
576
- while (yield from self .nextset ()):
559
+ while (await self .nextset ()):
577
560
pass
578
561
finally :
579
562
self ._connection = None
580
563
581
- @asyncio .coroutine
582
- def _query (self , q ):
564
+ async def _query (self , q ):
583
565
conn = self ._get_db ()
584
566
self ._last_executed = q
585
- yield from conn .query (q , unbuffered = True )
586
- yield from self ._do_get_result ()
567
+ await conn .query (q , unbuffered = True )
568
+ await self ._do_get_result ()
587
569
return self ._rowcount
588
570
589
- @asyncio .coroutine
590
- def _read_next (self ):
571
+ async def _read_next (self ):
591
572
"""Read next row """
592
- row = yield from self ._result ._read_rowdata_packet_unbuffered ()
573
+ row = await self ._result ._read_rowdata_packet_unbuffered ()
593
574
row = self ._conv_row (row )
594
575
return row
595
576
596
- @asyncio .coroutine
597
- def fetchone (self ):
577
+ async def fetchone (self ):
598
578
""" Fetch next row """
599
579
self ._check_executed ()
600
- row = yield from self ._read_next ()
580
+ row = await self ._read_next ()
601
581
if row is None :
602
582
return
603
583
self ._rownumber += 1
604
584
return row
605
585
606
- @asyncio .coroutine
607
- def fetchall (self ):
586
+ async def fetchall (self ):
608
587
"""Fetch all, as per MySQLdb. Pretty useless for large queries, as
609
588
it is buffered.
610
589
"""
611
590
rows = []
612
591
while True :
613
- row = yield from self .fetchone ()
592
+ row = await self .fetchone ()
614
593
if row is None :
615
594
break
616
595
rows .append (row )
617
596
return rows
618
597
619
- @asyncio .coroutine
620
- def fetchmany (self , size = None ):
598
+ async def fetchmany (self , size = None ):
621
599
"""Returns the next set of rows of a query result, returning a
622
600
list of tuples. When no more rows are available, it returns an
623
601
empty list.
@@ -634,15 +612,14 @@ def fetchmany(self, size=None):
634
612
635
613
rows = []
636
614
for i in range (size ):
637
- row = yield from self ._read_next ()
615
+ row = await self ._read_next ()
638
616
if row is None :
639
617
break
640
618
rows .append (row )
641
619
self ._rownumber += 1
642
620
return rows
643
621
644
- @asyncio .coroutine
645
- def scroll (self , value , mode = 'relative' ):
622
+ async def scroll (self , value , mode = 'relative' ):
646
623
"""Scroll the cursor in the result set to a new position
647
624
according to mode . Same as :meth:`Cursor.scroll`, but move cursor
648
625
on server side one by one row. If you want to move 20 rows forward
@@ -661,7 +638,7 @@ def scroll(self, value, mode='relative'):
661
638
"by this cursor" )
662
639
663
640
for _ in range (value ):
664
- yield from self ._read_next ()
641
+ await self ._read_next ()
665
642
self ._rownumber += value
666
643
elif mode == 'absolute' :
667
644
if value < self ._rownumber :
@@ -670,7 +647,7 @@ def scroll(self, value, mode='relative'):
670
647
671
648
end = value - self ._rownumber
672
649
for _ in range (end ):
673
- yield from self ._read_next ()
650
+ await self ._read_next ()
674
651
self ._rownumber = value
675
652
else :
676
653
raise ProgrammingError ("unknown scroll mode %s" % mode )
0 commit comments