Skip to content

Commit 575baec

Browse files
committed
Race condition with pre-existing rows.
1 parent c8db98a commit 575baec

File tree

1 file changed

+28
-44
lines changed

1 file changed

+28
-44
lines changed

core/pubsub.py

Lines changed: 28 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
import asyncio
2-
import json
3-
import zlib
42

53
from contextlib import suppress
6-
from psycopg import sql, Connection, OperationalError, AsyncConnection, InternalError
4+
from psycopg import sql, Connection, OperationalError, AsyncConnection
75
from typing import Callable
86

97
from core.data.impl.nodeimpl import NodeImpl
108

119

1210
class PubSub:
13-
_SENTINEL = object()
1411

1512
def __init__(self, node: NodeImpl, name: str, url: str, handler: Callable):
1613
self.node = node
@@ -92,28 +89,27 @@ async def _process_write(self):
9289
async with await AsyncConnection.connect(self.url, autocommit=True) as conn:
9390
while not self._stop_event.is_set():
9491
message = await self.write_queue.get()
95-
if message == self._SENTINEL:
92+
if not message:
93+
return
94+
try:
95+
query = sql.SQL("""
96+
INSERT INTO {table} (guild_id, node, data)
97+
VALUES (%(guild_id)s, %(node)s, %(data)s)
98+
""").format(table=sql.Identifier(self.name))
99+
await conn.execute(query, message)
100+
finally:
96101
# Notify the queue that the message has been processed.
97102
self.write_queue.task_done()
98-
return
99-
query = sql.SQL("""
100-
INSERT INTO {table} (guild_id, node, data)
101-
VALUES (%(guild_id)s, %(node)s, %(data)s)
102-
""").format(table=sql.Identifier(self.name))
103-
await conn.execute(query, message)
104-
# Notify the queue that the message has been processed.
105-
self.write_queue.task_done()
106103

107104
async def _process_read(self):
108-
async def do_read(id: int):
105+
async def do_read():
109106
ids_to_delete = []
110107
query = sql.SQL("""
111108
SELECT id, data FROM {table}
112-
WHERE id <= %(id)s AND guild_id = %(guild_id)s AND node = %(node)s
109+
WHERE guild_id = %(guild_id)s AND node = %(node)s
113110
ORDER BY id
114111
""").format(table=sql.Identifier(self.name))
115112
cursor = await conn.execute(query, {
116-
'id': id,
117113
'guild_id': self.node.guild_id,
118114
'node': "Master" if self.node.master else self.node.name
119115
})
@@ -134,46 +130,34 @@ async def do_read(id: int):
134130
with suppress(OperationalError):
135131
async with await AsyncConnection.connect(self.url, autocommit=True) as conn:
136132
while not self._stop_event.is_set():
137-
row_id = await self.read_queue.get()
138-
if row_id == self._SENTINEL:
139-
# Notify the queue that the message has been processed.
140-
self.read_queue.task_done()
141-
return
142-
await do_read(row_id)
143-
# Notify the queue that the message has been processed.
144-
self.read_queue.task_done()
133+
try:
134+
# we will read every 5s independent if there is data in the queue or not
135+
if not await asyncio.wait_for(self.read_queue.get(), timeout=5.0):
136+
return
137+
try:
138+
await do_read()
139+
finally:
140+
# Notify the queue that the message has been processed.
141+
self.read_queue.task_done()
142+
except (TimeoutError, asyncio.TimeoutError):
143+
await do_read()
145144

146145
async def subscribe(self):
147146
while not self._stop_event.is_set():
148147
with suppress(OperationalError):
149148
async with await AsyncConnection.connect(self.url, autocommit=True) as conn:
150149
async with conn.cursor() as cursor:
151150
# preprocess all rows that might be there
152-
async for row in await cursor.execute(sql.SQL("""
153-
WITH to_delete AS (
154-
SELECT id
155-
FROM {table}
156-
WHERE guild_id = %s
157-
AND node = %s
158-
ORDER BY id
159-
)
160-
DELETE FROM {table}
161-
WHERE id IN (SELECT id FROM to_delete)
162-
RETURNING id
163-
""").format(table=sql.Identifier(self.name)), (self.node.guild_id, self.node.name)):
164-
self.read_queue.put_nowait(row[0])
165151
await cursor.execute(sql.SQL("LISTEN {table}").format(table=sql.Identifier(self.name)))
166152
gen = conn.notifies()
167153
async for n in gen:
168154
if self._stop_event.is_set():
169155
self.log.debug(f'- {self.name.title()} stopped.')
170156
await gen.aclose()
171157
return
172-
data = json.loads(n.payload)
173-
if data['guild_id'] == self.node.guild_id and data['node'] == self.node.name or (
174-
self.node.master and data['node'] == 'Master'
175-
):
176-
self.read_queue.put_nowait(data['row_id'])
158+
node = n.payload
159+
if node == self.node.name or (self.node.master and node == 'Master'):
160+
self.read_queue.put_nowait(n.payload)
177161
await asyncio.sleep(1)
178162

179163
async def publish(self, data: dict) -> None:
@@ -200,7 +184,7 @@ async def clear(self):
200184

201185
async def close(self):
202186
self._stop_event.set()
203-
self.write_queue.put_nowait(self._SENTINEL)
187+
self.write_queue.put_nowait(None)
204188
await self.write_worker
205-
self.read_queue.put_nowait(self._SENTINEL)
189+
self.read_queue.put_nowait(None)
206190
await self.read_worker

0 commit comments

Comments
 (0)