11import asyncio
2- import json
3- import zlib
42
53from contextlib import suppress
6- from psycopg import sql , Connection , OperationalError , AsyncConnection , InternalError
4+ from psycopg import sql , Connection , OperationalError , AsyncConnection
75from typing import Callable
86
97from core .data .impl .nodeimpl import NodeImpl
108
119
1210class PubSub :
13- _SENTINEL = object ()
1411
1512 def __init__ (self , node : NodeImpl , name : str , url : str , handler : Callable ):
1613 self .node = node
@@ -92,28 +89,27 @@ async def _process_write(self):
9289 async with await AsyncConnection .connect (self .url , autocommit = True ) as conn :
9390 while not self ._stop_event .is_set ():
9491 message = await self .write_queue .get ()
95- if message == self ._SENTINEL :
92+ if not message :
93+ return
94+ try :
95+ query = sql .SQL ("""
96+ INSERT INTO {table} (guild_id, node, data)
97+ VALUES (%(guild_id)s, %(node)s, %(data)s)
98+ """ ).format (table = sql .Identifier (self .name ))
99+ await conn .execute (query , message )
100+ finally :
96101 # Notify the queue that the message has been processed.
97102 self .write_queue .task_done ()
98- return
99- query = sql .SQL ("""
100- INSERT INTO {table} (guild_id, node, data)
101- VALUES (%(guild_id)s, %(node)s, %(data)s)
102- """ ).format (table = sql .Identifier (self .name ))
103- await conn .execute (query , message )
104- # Notify the queue that the message has been processed.
105- self .write_queue .task_done ()
106103
107104 async def _process_read (self ):
108- async def do_read (id : int ):
105+ async def do_read ():
109106 ids_to_delete = []
110107 query = sql .SQL ("""
111108 SELECT id, data FROM {table}
112- WHERE id <= %(id)s AND guild_id = %(guild_id)s AND node = %(node)s
109+ WHERE guild_id = %(guild_id)s AND node = %(node)s
113110 ORDER BY id
114111 """ ).format (table = sql .Identifier (self .name ))
115112 cursor = await conn .execute (query , {
116- 'id' : id ,
117113 'guild_id' : self .node .guild_id ,
118114 'node' : "Master" if self .node .master else self .node .name
119115 })
@@ -134,46 +130,34 @@ async def do_read(id: int):
134130 with suppress (OperationalError ):
135131 async with await AsyncConnection .connect (self .url , autocommit = True ) as conn :
136132 while not self ._stop_event .is_set ():
137- row_id = await self .read_queue .get ()
138- if row_id == self ._SENTINEL :
139- # Notify the queue that the message has been processed.
140- self .read_queue .task_done ()
141- return
142- await do_read (row_id )
143- # Notify the queue that the message has been processed.
144- self .read_queue .task_done ()
133+ try :
134+ # we will read every 5s independent if there is data in the queue or not
135+ if not await asyncio .wait_for (self .read_queue .get (), timeout = 5.0 ):
136+ return
137+ try :
138+ await do_read ()
139+ finally :
140+ # Notify the queue that the message has been processed.
141+ self .read_queue .task_done ()
142+ except (TimeoutError , asyncio .TimeoutError ):
143+ await do_read ()
145144
146145 async def subscribe (self ):
147146 while not self ._stop_event .is_set ():
148147 with suppress (OperationalError ):
149148 async with await AsyncConnection .connect (self .url , autocommit = True ) as conn :
150149 async with conn .cursor () as cursor :
151150 # preprocess all rows that might be there
152- async for row in await cursor .execute (sql .SQL ("""
153- WITH to_delete AS (
154- SELECT id
155- FROM {table}
156- WHERE guild_id = %s
157- AND node = %s
158- ORDER BY id
159- )
160- DELETE FROM {table}
161- WHERE id IN (SELECT id FROM to_delete)
162- RETURNING id
163- """ ).format (table = sql .Identifier (self .name )), (self .node .guild_id , self .node .name )):
164- self .read_queue .put_nowait (row [0 ])
165151 await cursor .execute (sql .SQL ("LISTEN {table}" ).format (table = sql .Identifier (self .name )))
166152 gen = conn .notifies ()
167153 async for n in gen :
168154 if self ._stop_event .is_set ():
169155 self .log .debug (f'- { self .name .title ()} stopped.' )
170156 await gen .aclose ()
171157 return
172- data = json .loads (n .payload )
173- if data ['guild_id' ] == self .node .guild_id and data ['node' ] == self .node .name or (
174- self .node .master and data ['node' ] == 'Master'
175- ):
176- self .read_queue .put_nowait (data ['row_id' ])
158+ node = n .payload
159+ if node == self .node .name or (self .node .master and node == 'Master' ):
160+ self .read_queue .put_nowait (n .payload )
177161 await asyncio .sleep (1 )
178162
179163 async def publish (self , data : dict ) -> None :
@@ -200,7 +184,7 @@ async def clear(self):
200184
201185 async def close (self ):
202186 self ._stop_event .set ()
203- self .write_queue .put_nowait (self . _SENTINEL )
187+ self .write_queue .put_nowait (None )
204188 await self .write_worker
205- self .read_queue .put_nowait (self . _SENTINEL )
189+ self .read_queue .put_nowait (None )
206190 await self .read_worker
0 commit comments