1010from collections .abc import Mapping
1111
1212import psycopg2
13- from psycopg2 import extras
14- from psycopg2 .extensions import POLL_ERROR , POLL_OK , POLL_READ , POLL_WRITE
13+ import psycopg2 . extensions
14+ import psycopg2 .extras
1515
1616from .cursor import Cursor
17- from .utils import _ContextManager , get_running_loop
17+ from .utils import _ContextManager , create_completed_future , get_running_loop
1818
1919__all__ = ('connect' ,)
2020
@@ -71,7 +71,7 @@ def __init__(
7171 self ._enable_json = enable_json
7272 self ._enable_hstore = enable_hstore
7373 self ._enable_uuid = enable_uuid
74- self ._loop = get_running_loop (kwargs . pop ( 'loop' , None ) is not None )
74+ self ._loop = get_running_loop ()
7575 self ._waiter = self ._loop .create_future ()
7676
7777 kwargs ['async_' ] = kwargs .pop ('async' , True )
@@ -84,7 +84,6 @@ def __init__(
8484 self ._last_usage = self ._loop .time ()
8585 self ._writing = False
8686 self ._echo = echo
87- self ._cursor_instance = None
8887 self ._notifies = asyncio .Queue ()
8988 self ._weakref = weakref .ref (self )
9089 self ._loop .add_reader (self ._fileno , self ._ready , self ._weakref )
@@ -136,21 +135,21 @@ def _ready(weak_self):
136135 if waiter is not None and not waiter .done ():
137136 waiter .set_exception (
138137 psycopg2 .OperationalError ("Connection closed" ))
139- if state == POLL_OK :
138+ if state == psycopg2 . extensions . POLL_OK :
140139 if self ._writing :
141140 self ._loop .remove_writer (self ._fileno )
142141 self ._writing = False
143142 if waiter is not None and not waiter .done ():
144143 waiter .set_result (None )
145- elif state == POLL_READ :
144+ elif state == psycopg2 . extensions . POLL_READ :
146145 if self ._writing :
147146 self ._loop .remove_writer (self ._fileno )
148147 self ._writing = False
149- elif state == POLL_WRITE :
148+ elif state == psycopg2 . extensions . POLL_WRITE :
150149 if not self ._writing :
151150 self ._loop .add_writer (self ._fileno , self ._ready , weak_self )
152151 self ._writing = True
153- elif state == POLL_ERROR :
152+ elif state == psycopg2 . extensions . POLL_ERROR :
154153 self ._fatal_error ("Fatal error on aiopg connection: "
155154 "POLL_ERROR from underlying .poll() call" )
156155 else :
@@ -209,9 +208,8 @@ def cursor(self, name=None, cursor_factory=None,
209208 *name*, *scrollable* and *withhold* parameters are not supported by
210209 psycopg in asynchronous mode.
211210
212- NOTE: as of [TODO] any previously created created cursor from this
213- connection will be closed
214211 """
212+
215213 self ._last_usage = self ._loop .time ()
216214 coro = self ._cursor (name = name , cursor_factory = cursor_factory ,
217215 scrollable = scrollable , withhold = withhold ,
@@ -222,24 +220,17 @@ async def _cursor(self, name=None, cursor_factory=None,
222220 scrollable = None , withhold = False , timeout = None ,
223221 isolation_level = None ):
224222
225- if not self .closed_cursor :
226- warnings .warn (f'You can only have one cursor per connection. '
227- f'The cursor for connection will be closed forcibly'
228- f' { self !r} .' , ResourceWarning )
229-
230- self .free_cursor ()
231-
232223 if timeout is None :
233224 timeout = self ._timeout
234225
235226 impl = await self ._cursor_impl (name = name ,
236227 cursor_factory = cursor_factory ,
237228 scrollable = scrollable ,
238229 withhold = withhold )
239- self . _cursor_instance = Cursor (
230+ cursor = Cursor (
240231 self , impl , timeout , self ._echo , isolation_level
241232 )
242- return self . _cursor_instance
233+ return cursor
243234
244235 async def _cursor_impl (self , name = None , cursor_factory = None ,
245236 scrollable = None , withhold = False ):
@@ -262,29 +253,14 @@ def _close(self):
262253 self ._loop .remove_writer (self ._fileno )
263254
264255 self ._conn .close ()
265- self .free_cursor ()
266256
267257 if self ._waiter is not None and not self ._waiter .done ():
268258 self ._waiter .set_exception (
269259 psycopg2 .OperationalError ("Connection closed" ))
270260
271- @property
272- def closed_cursor (self ):
273- if not self ._cursor_instance :
274- return True
275-
276- return bool (self ._cursor_instance .closed )
277-
278- def free_cursor (self ):
279- if not self .closed_cursor :
280- self ._cursor_instance .close ()
281- self ._cursor_instance = None
282-
283261 def close (self ):
284262 self ._close ()
285- ret = self ._loop .create_future ()
286- ret .set_result (None )
287- return ret
263+ return create_completed_future (self ._loop )
288264
289265 @property
290266 def closed (self ):
@@ -455,7 +431,6 @@ def __repr__(self):
455431 f'isexecuting={ self ._isexecuting ()} , '
456432 f'closed={ self .closed } , '
457433 f'echo={ self .echo } , '
458- f'cursor={ self ._cursor_instance } '
459434 f'>'
460435 )
461436
@@ -505,18 +480,18 @@ async def _get_oids(self):
505480 async def _connect (self ):
506481 try :
507482 await self ._poll (self ._waiter , self ._timeout )
508- except Exception :
509- self .close ()
483+ except BaseException :
484+ await asyncio . shield ( self .close () )
510485 raise
511486 if self ._enable_json :
512- extras .register_default_json (self ._conn )
487+ psycopg2 . extras .register_default_json (self ._conn )
513488 if self ._enable_uuid :
514- extras .register_uuid (conn_or_curs = self ._conn )
489+ psycopg2 . extras .register_uuid (conn_or_curs = self ._conn )
515490 if self ._enable_hstore :
516491 oids = await self ._get_oids ()
517492 if oids is not None :
518493 oid , array_oid = oids
519- extras .register_hstore (
494+ psycopg2 . extras .register_hstore (
520495 self ._conn ,
521496 oid = oid ,
522497 array_oid = array_oid
@@ -531,4 +506,4 @@ async def __aenter__(self):
531506 return self
532507
533508 async def __aexit__ (self , exc_type , exc_val , exc_tb ):
534- self .close ()
509+ await self .close ()
0 commit comments