@@ -29,7 +29,7 @@ class Connection:
29
29
__slots__ = ('_protocol' , '_transport' , '_loop' , '_types_stmt' ,
30
30
'_type_by_name_stmt' , '_top_xact' , '_uid' , '_aborted' ,
31
31
'_stmt_cache_max_size' , '_stmt_cache' , '_stmts_to_close' ,
32
- '_addr' , '_opts' , '_command_timeout' )
32
+ '_addr' , '_opts' , '_command_timeout' , '_listeners' )
33
33
34
34
def __init__ (self , protocol , transport , loop , addr , opts , * ,
35
35
statement_cache_size , command_timeout ):
@@ -51,7 +51,44 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
51
51
52
52
self ._command_timeout = command_timeout
53
53
54
+ self ._listeners = {}
55
+
56
+ async def add_listener (self , channel , callback ):
57
+ """Add a listener for Postgres notifications.
58
+
59
+ :param str channel: Channel to listen on.
60
+ :param callable callback:
61
+ A callable receiving the following arguments:
62
+ **connection**: a Connection the callback is registered with;
63
+ **pid**: PID of the Postgres server that sent the notification;
64
+ **channel**: name of the channel the notification was sent to;
65
+ **payload**: the payload.
66
+ """
67
+ if channel not in self ._listeners :
68
+ await self .fetch ('LISTEN {}' .format (channel ))
69
+ self ._listeners [channel ] = set ()
70
+ self ._listeners [channel ].add (callback )
71
+
72
+ async def remove_listener (self , channel , callback ):
73
+ """Remove a listening callback on the specified channel."""
74
+ if channel not in self ._listeners :
75
+ return
76
+ if callback not in self ._listeners [channel ]:
77
+ return
78
+ self ._listeners [channel ].remove (callback )
79
+ if not self ._listeners [channel ]:
80
+ del self ._listeners [channel ]
81
+ await self .fetch ('UNLISTEN {}' .format (channel ))
82
+
83
+ def get_server_pid (self ):
84
+ """Return the PID of the Postgres server the connection is bound to."""
85
+ return self ._protocol .get_server_pid ()
86
+
54
87
def get_settings (self ):
88
+ """Return connection settings.
89
+
90
+ :return: :class:`~asyncpg.ConnectionSettings`.
91
+ """
55
92
return self ._protocol .get_settings ()
56
93
57
94
def transaction (self , * , isolation = 'read_committed' , readonly = False ,
@@ -269,17 +306,20 @@ async def close(self):
269
306
if self .is_closed ():
270
307
return
271
308
self ._close_stmts ()
309
+ self ._listeners = {}
272
310
self ._aborted = True
273
311
protocol = self ._protocol
274
312
await protocol .close ()
275
313
276
314
def terminate (self ):
277
315
"""Terminate the connection without waiting for pending data."""
278
316
self ._close_stmts ()
317
+ self ._listeners = {}
279
318
self ._aborted = True
280
319
self ._protocol .abort ()
281
320
282
321
async def reset (self ):
322
+ self ._listeners = {}
283
323
await self .execute ('''
284
324
SET SESSION AUTHORIZATION DEFAULT;
285
325
RESET ALL;
@@ -351,6 +391,20 @@ async def cancel():
351
391
352
392
self ._loop .create_task (cancel ())
353
393
394
+ def _notify (self , pid , channel , payload ):
395
+ if channel not in self ._listeners :
396
+ return
397
+
398
+ for cb in self ._listeners [channel ]:
399
+ try :
400
+ cb (self , pid , channel , payload )
401
+ except Exception as ex :
402
+ self ._loop .call_exception_handler ({
403
+ 'message' : 'Unhandled exception in asyncpg notification '
404
+ 'listener callback {!r}' .format (cb ),
405
+ 'exception' : ex
406
+ })
407
+
354
408
355
409
async def connect (dsn = None , * ,
356
410
host = None , port = None ,
0 commit comments