1414from sqlspec .utils .serializers import from_json , to_json
1515
1616if TYPE_CHECKING :
17+ from psqlpy import Listener
18+
1719 from sqlspec .adapters .psqlpy .config import PsqlpyConfig
1820
1921
2729
2830
2931class PsqlpyEventsBackend :
30- """Native LISTEN/NOTIFY backend for psqlpy adapters."""
32+ """Native LISTEN/NOTIFY backend for psqlpy adapters.
33+
34+ Uses psqlpy's Listener API which provides a dedicated connection for
35+ receiving PostgreSQL NOTIFY messages via callbacks or async iteration.
36+ """
3137
3238 supports_sync = False
3339 supports_async = True
@@ -39,8 +45,8 @@ def __init__(self, config: "PsqlpyConfig") -> None:
3945 raise ImproperConfigurationError (msg )
4046 self ._config = config
4147 self ._runtime = config .get_observability_runtime ()
42- self ._listen_connection_async : Any | None = None
43- self ._listen_connection_async_cm : Any | None = None
48+ self ._listener : Any | None = None
49+ self ._listener_started : bool = False
4450
4551 async def publish_async (self , channel : str , payload : "dict[str, Any]" , metadata : "dict[str, Any] | None" = None ) -> str :
4652 event_id = uuid .uuid4 ().hex
@@ -56,25 +62,36 @@ def publish(self, *_: Any, **__: Any) -> str:
5662 raise ImproperConfigurationError (msg )
5763
5864 async def dequeue_async (self , channel : str , poll_interval : float ) -> EventMessage | None :
59- connection = await self ._ensure_async_listener (channel )
60- await connection . execute ( f"LISTEN { channel } " )
61- future : asyncio . Future [ str ] = asyncio .get_running_loop (). create_future ()
65+ listener = await self ._ensure_listener (channel )
66+ received_payload : str | None = None
67+ event = asyncio .Event ()
6268
63- def _callback (_conn : Any , _pid : int , notified_channel : str , payload : str ) -> None :
64- if notified_channel != channel or future .done ():
65- return
66- future .set_result (payload )
69+ async def _callback (
70+ _connection : Any ,
71+ payload : str ,
72+ notified_channel : str ,
73+ _process_id : int ,
74+ ) -> None :
75+ nonlocal received_payload
76+ if notified_channel == channel and received_payload is None :
77+ received_payload = payload
78+ event .set ()
6779
68- connection .add_listener (channel , _callback )
69- try :
70- try :
71- payload = await asyncio .wait_for (future , timeout = poll_interval )
72- except asyncio .TimeoutError :
73- return None
74- return self ._decode_payload (channel , payload )
75- finally :
76- with contextlib .suppress (Exception ):
77- connection .remove_listener (channel , _callback )
80+ await listener .add_callback (channel = channel , callback = _callback )
81+
82+ if not self ._listener_started :
83+ listener .listen ()
84+ self ._listener_started = True
85+ await asyncio .sleep (0.05 )
86+
87+ with contextlib .suppress (asyncio .TimeoutError ):
88+ await asyncio .wait_for (event .wait (), timeout = poll_interval )
89+
90+ await listener .clear_channel_callbacks (channel = channel )
91+
92+ if received_payload is not None :
93+ return self ._decode_payload (channel , received_payload )
94+ return None
7895
7996 def dequeue (self , * _ : Any , ** __ : Any ) -> EventMessage | None :
8097 msg = "dequeue is not supported for sync Psqlpy backends"
@@ -87,15 +104,21 @@ def ack(self, _event_id: str) -> None:
87104 msg = "ack is not supported for sync Psqlpy backends"
88105 raise ImproperConfigurationError (msg )
89106
90- async def _ensure_async_listener (self , channel : str ) -> Any :
91- if self ._listen_connection_async is None :
92- self ._listen_connection_async_cm = self ._config .provide_connection ()
93- self ._listen_connection_async = await self ._listen_connection_async_cm .__aenter__ ()
94- try :
95- await self ._listen_connection_async .set_autocommit (True ) # type: ignore[attr-defined]
96- except Exception :
97- pass
98- return self ._listen_connection_async
107+ async def _ensure_listener (self , channel : str ) -> "Listener" :
108+ if self ._listener is None :
109+ pool = await self ._config .provide_pool ()
110+ self ._listener = pool .listener ()
111+ await self ._listener .startup ()
112+ return self ._listener
113+
114+ async def shutdown (self ) -> None :
115+ """Shutdown the listener and release resources."""
116+ if self ._listener is not None :
117+ if self ._listener_started :
118+ self ._listener .abort_listen ()
119+ self ._listener_started = False
120+ await self ._listener .shutdown ()
121+ self ._listener = None
99122
100123 @staticmethod
101124 def _encode_payload (event_id : str , payload : "dict[str, Any]" , metadata : "dict[str, Any] | None" ) -> str :
@@ -147,7 +170,11 @@ def _parse_timestamp(value: Any) -> datetime:
147170
148171
149172class PsqlpyHybridEventsBackend :
150- """Durable hybrid backend combining queue storage with LISTEN/NOTIFY wakeups."""
173+ """Durable hybrid backend combining queue storage with LISTEN/NOTIFY wakeups.
174+
175+ Uses psqlpy's Listener API for real-time notifications while persisting
176+ events to a durable queue table.
177+ """
151178
152179 supports_sync = False
153180 supports_async = True
@@ -160,6 +187,8 @@ def __init__(self, config: "PsqlpyConfig", queue_backend: QueueEventBackend) ->
160187 self ._config = config
161188 self ._queue = queue_backend
162189 self ._runtime = config .get_observability_runtime ()
190+ self ._listener : Any | None = None
191+ self ._listener_started : bool = False
163192
164193 async def publish_async (self , channel : str , payload : "dict[str, Any]" , metadata : "dict[str, Any] | None" = None ) -> str :
165194 event_id = uuid .uuid4 ().hex
@@ -172,31 +201,30 @@ def publish(self, *_: Any, **__: Any) -> str:
172201 raise ImproperConfigurationError (msg )
173202
174203 async def dequeue_async (self , channel : str , poll_interval : float ) -> EventMessage | None :
175- connection_cm = self ._config .provide_connection ()
176- connection = await connection_cm .__aenter__ ()
177- try :
178- listener = getattr (connection , "add_listener" , None )
179- if listener is None :
180- return await self ._queue .dequeue_async (channel )
181- future : asyncio .Future [str ] = asyncio .get_running_loop ().create_future ()
182-
183- def _callback (_conn : Any , _pid : int , notified_channel : str , payload : str ) -> None :
184- if notified_channel != channel or future .done ():
185- return
186- future .set_result (payload )
187-
188- listener (channel , _callback )
189- try :
190- await asyncio .wait_for (future , timeout = poll_interval )
191- except asyncio .TimeoutError :
192- return await self ._queue .dequeue_async (channel )
193- return await self ._queue .dequeue_async (channel )
194- finally :
195- with contextlib .suppress (Exception ):
196- remove = getattr (connection , "remove_listener" , None )
197- if remove :
198- remove (channel , _callback )
199- await connection_cm .__aexit__ (None , None , None )
204+ listener = await self ._ensure_listener (channel )
205+ event = asyncio .Event ()
206+
207+ async def _callback (
208+ _connection : Any ,
209+ _payload : str ,
210+ notified_channel : str ,
211+ _process_id : int ,
212+ ) -> None :
213+ if notified_channel == channel :
214+ event .set ()
215+
216+ await listener .add_callback (channel = channel , callback = _callback )
217+
218+ if not self ._listener_started :
219+ listener .listen ()
220+ self ._listener_started = True
221+ await asyncio .sleep (0.05 )
222+
223+ with contextlib .suppress (asyncio .TimeoutError ):
224+ await asyncio .wait_for (event .wait (), timeout = poll_interval )
225+
226+ await listener .clear_channel_callbacks (channel = channel )
227+ return await self ._queue .dequeue_async (channel )
200228
201229 async def ack_async (self , event_id : str ) -> None :
202230 await self ._queue .ack_async (event_id )
@@ -206,6 +234,22 @@ def ack(self, _event_id: str) -> None:
206234 msg = "ack is not supported for sync Psqlpy backends"
207235 raise ImproperConfigurationError (msg )
208236
237+ async def _ensure_listener (self , channel : str ) -> "Listener" :
238+ if self ._listener is None :
239+ pool = await self ._config .provide_pool ()
240+ self ._listener = pool .listener ()
241+ await self ._listener .startup ()
242+ return self ._listener
243+
244+ async def shutdown (self ) -> None :
245+ """Shutdown the listener and release resources."""
246+ if self ._listener is not None :
247+ if self ._listener_started :
248+ self ._listener .abort_listen ()
249+ self ._listener_started = False
250+ await self ._listener .shutdown ()
251+ self ._listener = None
252+
209253 async def _publish_durable_async (
210254 self , channel : str , event_id : str , payload : "dict[str, Any]" , metadata : "dict[str, Any] | None"
211255 ) -> None :
0 commit comments