11import asyncio
2+ import json
3+ import zlib
24
35from contextlib import suppress
4- from psycopg import sql , Connection , OperationalError , AsyncConnection
6+ from psycopg import sql , Connection , OperationalError , AsyncConnection , InternalError
57from typing import Callable
68
79from core .data .impl .nodeimpl import NodeImpl
810
911
1012class PubSub :
13+ _SENTINEL = object ()
1114
1215 def __init__ (self , node : NodeImpl , name : str , url : str , handler : Callable ):
1316 self .node = node
@@ -23,58 +26,64 @@ def __init__(self, node: NodeImpl, name: str, url: str, handler: Callable):
2326 self .write_worker = asyncio .create_task (self ._process_write ())
2427
2528 def create_table (self ):
29+ lock_key = zlib .crc32 (f"PubSubDDL:{ self .name } " .encode ("utf-8" ))
30+
2631 with Connection .connect (self .url , autocommit = True ) as conn :
27- query = sql .SQL ("""
28- CREATE TABLE IF NOT EXISTS {table} (
29- id SERIAL PRIMARY KEY,
30- guild_id BIGINT NOT NULL,
31- node TEXT NOT NULL,
32- time TIMESTAMP NOT NULL DEFAULT (now() AT TIME ZONE 'utc'),
33- data JSON
34- )
35- """ ).format (table = sql .Identifier (self .name ))
36- conn .execute (query )
37- query = sql .SQL ("""
38- CREATE TABLE IF NOT EXISTS {table} (
39- id SERIAL PRIMARY KEY,
40- guild_id BIGINT NOT NULL,
41- node TEXT NOT NULL,
42- time TIMESTAMP NOT NULL DEFAULT (now() AT TIME ZONE 'utc'),
43- data JSON
44- )
45- """ ).format (table = sql .Identifier (self .name ))
46- conn .execute (query )
47- query = sql .SQL ("""
48- CREATE OR REPLACE FUNCTION {func}()
49- RETURNS trigger
50- AS $$
51- BEGIN
52- PERFORM pg_notify({name}, NEW.node);
53- RETURN NEW;
54- END;
55- $$ LANGUAGE plpgsql;
56- """ ).format (func = sql .Identifier (self .name + '_notify' ), name = sql .Literal (self .name ))
57- conn .execute (query )
58- query = sql .SQL ("""
59- DO $$
60- BEGIN
61- IF NOT EXISTS (
62- SELECT 1
63- FROM pg_trigger
64- WHERE tgname = {trigger_name}
65- AND tgrelid = {name}::regclass
66- ) THEN
67- CREATE TRIGGER {trigger}
68- AFTER INSERT OR UPDATE ON {table}
69- FOR EACH ROW
70- EXECUTE PROCEDURE {func}();
71- END IF;
72- END;
73- $$;
74- """ ).format (table = sql .Identifier (self .name ), trigger = sql .Identifier (self .name + '_trigger' ),
75- func = sql .Identifier (self .name + '_notify' ), name = sql .Literal (self .name ),
76- trigger_name = sql .Literal (self .name + '_trigger' ))
77- conn .execute (query )
32+ try :
33+ conn .execute ("SELECT pg_advisory_lock(%s)" , (lock_key ,))
34+
35+ query = sql .SQL ("""
36+ CREATE TABLE IF NOT EXISTS {table} (
37+ id SERIAL PRIMARY KEY,
38+ guild_id BIGINT NOT NULL,
39+ node TEXT NOT NULL,
40+ time TIMESTAMP NOT NULL DEFAULT (now() AT TIME ZONE 'utc'),
41+ data JSONB
42+ )
43+ """ ).format (table = sql .Identifier (self .name ))
44+ conn .execute (query )
45+ query = sql .SQL ("""
46+ CREATE OR REPLACE FUNCTION {func}()
47+ RETURNS trigger
48+ AS $$
49+ BEGIN
50+ PERFORM pg_notify({name}, json_build_object(
51+ 'row_id', NEW.id,
52+ 'guild_id', NEW.guild_id,
53+ 'node', NEW.node
54+ )::text);
55+ RETURN NEW;
56+ END;
57+ $$ LANGUAGE plpgsql;
58+ """ ).format (func = sql .Identifier (self .name + '_notify' ), name = sql .Literal (self .name ))
59+ conn .execute (query )
60+ query = sql .SQL ("""
61+ DO $$
62+ BEGIN
63+ IF NOT EXISTS (
64+ SELECT 1
65+ FROM pg_trigger
66+ WHERE tgname = {trigger_name}
67+ AND tgrelid = {name}::regclass
68+ ) THEN
69+ CREATE TRIGGER {trigger}
70+ AFTER INSERT OR UPDATE ON {table}
71+ FOR EACH ROW
72+ EXECUTE PROCEDURE {func}();
73+ END IF;
74+ END;
75+ $$;
76+ """ ).format (table = sql .Identifier (self .name ), trigger = sql .Identifier (self .name + '_trigger' ),
77+ func = sql .Identifier (self .name + '_notify' ), name = sql .Literal (self .name ),
78+ trigger_name = sql .Literal (self .name + '_trigger' ))
79+ conn .execute (query )
80+
81+ except InternalError as ex :
82+ self .log .exception (ex )
83+ raise
84+ finally :
85+ with suppress (Exception ):
86+ conn .execute ("SELECT pg_advisory_unlock(%s)" , (lock_key ,))
7887
7988 async def _process_write (self ):
8089 await asyncio .sleep (1 ) # Ensure the rest of __init__ has finished
@@ -83,27 +92,28 @@ async def _process_write(self):
8392 async with await AsyncConnection .connect (self .url , autocommit = True ) as conn :
8493 while not self ._stop_event .is_set ():
8594 message = await self .write_queue .get ()
86- if not message :
87- return
88- try :
89- query = sql .SQL ("""
90- INSERT INTO {table} (guild_id, node, data)
91- VALUES (%(guild_id)s, %(node)s, %(data)s)
92- """ ).format (table = sql .Identifier (self .name ))
93- await conn .execute (query , message )
94- finally :
95+ if message == self ._SENTINEL :
9596 # Notify the queue that the message has been processed.
9697 self .write_queue .task_done ()
98+ return
99+ query = sql .SQL ("""
100+ INSERT INTO {table} (guild_id, node, data)
101+ VALUES (%(guild_id)s, %(node)s, %(data)s)
102+ """ ).format (table = sql .Identifier (self .name ))
103+ await conn .execute (query , message )
104+ # Notify the queue that the message has been processed.
105+ self .write_queue .task_done ()
97106
98107 async def _process_read (self ):
99- async def do_read ():
108+ async def do_read (id : int ):
100109 ids_to_delete = []
101110 query = sql .SQL ("""
102111 SELECT id, data FROM {table}
103- WHERE guild_id = %(guild_id)s AND node = %(node)s
112+ WHERE id <= %(id)s AND guild_id = %(guild_id)s AND node = %(node)s
104113 ORDER BY id
105114 """ ).format (table = sql .Identifier (self .name ))
106115 cursor = await conn .execute (query , {
116+ 'id' : id ,
107117 'guild_id' : self .node .guild_id ,
108118 'node' : "Master" if self .node .master else self .node .name
109119 })
@@ -124,17 +134,14 @@ async def do_read():
124134 with suppress (OperationalError ):
125135 async with await AsyncConnection .connect (self .url , autocommit = True ) as conn :
126136 while not self ._stop_event .is_set ():
127- try :
128- # we will read every 5s independent if there is data in the queue or not
129- if not await asyncio .wait_for (self .read_queue .get (), timeout = 5.0 ):
130- return
131- try :
132- await do_read ()
133- finally :
134- # Notify the queue that the message has been processed.
135- self .read_queue .task_done ()
136- except (TimeoutError , asyncio .TimeoutError ):
137- await do_read ()
137+ row_id = await self .read_queue .get ()
138+ if row_id == self ._SENTINEL :
139+ # Notify the queue that the message has been processed.
140+ self .read_queue .task_done ()
141+ return
142+ await do_read (row_id )
143+ # Notify the queue that the message has been processed.
144+ self .read_queue .task_done ()
138145
139146 async def subscribe (self ):
140147 while not self ._stop_event .is_set ():
@@ -149,9 +156,11 @@ async def subscribe(self):
149156 self .log .debug (f'- { self .name .title ()} stopped.' )
150157 await gen .aclose ()
151158 return
152- node = n .payload
153- if node == self .node .name or (self .node .master and node == 'Master' ):
154- self .read_queue .put_nowait (n .payload )
159+ data = json .loads (n .payload )
160+ if data ['guild_id' ] == self .node .guild_id and data ['node' ] == self .node .name or (
161+ self .node .master and data ['node' ] == 'Master'
162+ ):
163+ self .read_queue .put_nowait (data ['row_id' ])
155164 await asyncio .sleep (1 )
156165
157166 async def publish (self , data : dict ) -> None :
@@ -178,7 +187,7 @@ async def clear(self):
178187
179188 async def close (self ):
180189 self ._stop_event .set ()
181- self .write_queue .put_nowait (None )
190+ self .write_queue .put_nowait (self . _SENTINEL )
182191 await self .write_worker
183- self .read_queue .put_nowait (None )
192+ self .read_queue .put_nowait (self . _SENTINEL )
184193 await self .read_worker
0 commit comments