238 changes: 39 additions & 199 deletions channels_postgres/db.py
@@ -7,10 +7,7 @@
import typing
from datetime import datetime, timedelta

import psycopg
import psycopg_pool
import psycopg_pool.base
from psycopg import sql


from .models import GroupChannel, Message

@@ -33,11 +30,8 @@
# regardless of the number of threads
# And also to prevent RuntimeErrors when the event loop is closed after running tests
# and psycopg Async workers are not cleaned up properly
is_creating_connection_pool = asyncio.Lock()
connection_pool: typing.Optional[psycopg_pool.AsyncConnectionPool] = None

MESSAGE_TABLE = Message._meta.db_table
GROUP_CHANNEL_TABLE = GroupChannel._meta.db_table



def utc_now() -> datetime:
@@ -54,11 +48,9 @@ def utc_now() -> datetime:

class DatabaseLayer:
"""
Encapsulates database operations

A connection pool is used for efficient management of database operations
This is also the reason why psycopg is used directly instead of django's ORM
which doesn't support connection pooling
Encapsulates database operations, using the Django ORM's async API.
Replaces the handcrafted psycopg3 queries for simplicity, maintainability,
and native Django 5.1 connection-pooling support.
"""

def __init__(
@@ -73,138 +65,36 @@ def __init__(
self.db_params = db_params
self.psycopg_options = psycopg_options

async def get_db_pool(
self, db_params: dict[str, typing.Any]
) -> psycopg_pool.AsyncConnectionPool:
"""
Returns a connection pool for the database

Uses a `Lock` to ensure that only one coroutine can create the connection pool
Others have to wait until the connection pool is created
"""
global connection_pool # pylint: disable=W0603

async def _configure_connection(conn: psycopg.AsyncConnection) -> None:
await conn.set_autocommit(True)
conn.prepare_threshold = 0 # All statements should be prepared
conn.prepared_max = None # No limit on the number of prepared statements

async with is_creating_connection_pool:
if connection_pool is not None:
self.logger.debug('Pool %s already exists', connection_pool.name)

pool_stats = connection_pool.get_stats()
self.logger.debug('Pool stats: %s', pool_stats)

return connection_pool

conn_info = psycopg.conninfo.make_conninfo(conninfo='', **db_params)

connection_pool = psycopg_pool.AsyncConnectionPool(
conninfo=conn_info,
open=False,
configure=_configure_connection,
**self.psycopg_options,
)
await connection_pool.open(wait=True)

self.logger.debug('Pool %s created', connection_pool.name)

return connection_pool


async def retrieve_group_channels(self, group_key: str) -> list[str]:
"""Retrieves all channels for a group"""
retrieve_channels_sql = sql.SQL(
'SELECT DISTINCT group_key,channel FROM {table} WHERE group_key=%s'
).format(table=sql.Identifier(GROUP_CHANNEL_TABLE))

db_pool = await self.get_db_pool(db_params=self.db_params)
async with db_pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(retrieve_channels_sql, (group_key,))
result = await cursor.fetchall()
return [row[1] for row in result]

async def send_to_channel(
self,
group_key: str,
message: bytes,
expire: int,
channel: typing.Optional[str] = None,
) -> None:
"""Send a message to a channel/channels (if no channel is specified)."""
message_add_sql = sql.SQL(
'INSERT INTO {table} (channel, message, expire) VALUES (%s, %s, %s)'
).format(table=sql.Identifier(MESSAGE_TABLE))

if channel is None:
channels = await self.retrieve_group_channels(group_key)
if not channels:
self.logger.warning('Group: %s does not exist, did you call group_add?', group_key)
return
else:
channels = [channel]

expiry_datetime = utc_now() + timedelta(seconds=expire)
db_pool = await self.get_db_pool(db_params=self.db_params)
async with db_pool.connection() as conn:
async with conn.cursor() as cursor:
if len(channels) == 1:
# single insert
data = (channels[0], message, expiry_datetime)
await cursor.execute(message_add_sql, data)
else:
# Bulk insert messages
multi_data = [(channel, message, expiry_datetime) for channel in channels]
await cursor.executemany(message_add_sql, multi_data)

qs = GroupChannel.objects.using(self.using).filter(group_key=group_key)
return [channel async for channel in qs.values_list('channel', flat=True)]
async def add_channel_to_group(self, group_key: str, channel: str, expire: int) -> None:
"""Adds a channel to a group"""
expiry_datetime = utc_now() + timedelta(seconds=expire)
group_add_sql = sql.SQL(
'INSERT INTO {table} (group_key, channel, expire) VALUES (%s, %s, %s)'
).format(table=sql.Identifier(GROUP_CHANNEL_TABLE))

db_pool = await self.get_db_pool(db_params=self.db_params)
async with db_pool.connection() as conn:
async with conn.cursor() as cursor:
data = (group_key, channel, expiry_datetime)
await cursor.execute(group_add_sql, data)
await GroupChannel.objects.using(self.using).acreate(
group_key=group_key, channel=channel, expire=expiry_datetime
)

self.logger.debug('Channel %s added to Group %s', channel, group_key)

async def delete_expired_groups(self) -> None:
"""Deletes expired groups after a random delay"""
delete_expired_groups_sql = sql.SQL('DELETE FROM {table} WHERE expire < %s').format(
table=sql.Identifier(GROUP_CHANNEL_TABLE)
)

expire = 60 * random.randint(10, 20)
self.logger.debug('Deleting expired groups in %s seconds...', expire)
await asyncio.sleep(expire)

delay = 60 * random.randint(10, 20)
self.logger.debug('Deleting expired groups in %s seconds...', delay)
await asyncio.sleep(delay)
now = utc_now()
db_pool = await self.get_db_pool(db_params=self.db_params)
async with db_pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(delete_expired_groups_sql, (now,))
await GroupChannel.objects.using(self.using).filter(expire__lt=now).adelete()

async def delete_expired_messages(self, expire: typing.Optional[int] = None) -> None:
"""Deletes expired messages after a set time or random delay"""
delete_expired_messages_sql = sql.SQL('DELETE FROM {table} WHERE expire < %s').format(
table=sql.Identifier(MESSAGE_TABLE)
)

if expire is None:
expire = 60 * random.randint(10, 20)
self.logger.debug('Deleting expired messages in %s seconds...', expire)
await asyncio.sleep(expire)

now = utc_now()
db_pool = await self.get_db_pool(db_params=self.db_params)
async with db_pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(delete_expired_messages_sql, (now,))
await Message.objects.using(self.using).filter(expire__lt=now).adelete()

async def retrieve_non_expired_queued_messages(self) -> list[tuple[str, str, bytes, str]]:
"""
@@ -214,55 +104,24 @@ async def retrieve_non_expired_queued_messages(self) -> list[tuple[str, str, byt
queries. Even if the inner query is ordered, the returning
clause is not guaranteed to be ordered
"""
retrieve_queued_messages_sql = sql.SQL(
"""
DELETE FROM {table}
WHERE id IN (
SELECT id
FROM {table}
WHERE expire > %s
FOR UPDATE SKIP LOCKED
)
RETURNING id::text, channel, message, extract(epoch from expire)::text
"""
).format(table=sql.Identifier(MESSAGE_TABLE))

now = utc_now()
db_pool = await self.get_db_pool(db_params=self.db_params)
async with db_pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(retrieve_queued_messages_sql, (now,))

return await cursor.fetchall()

msgs = [m async for m in Message.objects.using(self.using).filter(expire__gt=now)]
Owner comment: I will make an exception here and keep the handcrafted query because it's more efficient.

The query uses DELETE ... RETURNING, which deletes and returns the deleted rows in one query.

The new query with the Django ORM makes two queries (a SELECT and a DELETE).

result = [(str(m.id), m.channel, m.message, str(m.expire.timestamp())) for m in msgs]
await Message.objects.using(self.using).filter(id__in=[m.id for m in msgs]).adelete()
return result
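
For comparison, the single-round-trip behaviour the comment above describes could also be reached from the ORM layer through a raw cursor. A minimal sketch, not the module's actual code, with `_pop_all_messages` as a hypothetical sync helper wrapped for async use (`Message` and `utc_now` as defined in this module):

```python
from asgiref.sync import sync_to_async
from django.db import connections

def _pop_all_messages(using: str) -> list[tuple[str, str, bytes, str]]:
    # DELETE ... RETURNING removes and returns the rows in one statement,
    # matching the handcrafted query the owner decided to keep.
    table = Message._meta.db_table
    query = (
        f'DELETE FROM "{table}" WHERE id IN ('
        f'  SELECT id FROM "{table}" WHERE expire > %s FOR UPDATE SKIP LOCKED'
        f') RETURNING id::text, channel, message, extract(epoch from expire)::text'
    )
    with connections[using].cursor() as cursor:
        cursor.execute(query, (utc_now(),))
        return cursor.fetchall()

# usage from async code: rows = await sync_to_async(_pop_all_messages)('default')
```
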
async def retrieve_non_expired_queued_message_from_channel(
self, channel: str
) -> typing.Optional[tuple[bytes]]:
"""Retrieves a non-expired message from a channel"""
retrieve_queued_messages_sql = sql.SQL(
"""
DELETE FROM {table}
WHERE id = (
SELECT id
FROM {table}
WHERE channel=%s AND expire > %s
ORDER BY id
FOR UPDATE SKIP LOCKED
LIMIT 1
)
RETURNING message
"""
).format(table=sql.Identifier(MESSAGE_TABLE))

now = utc_now()
db_pool = await self.get_db_pool(db_params=self.db_params)
async with db_pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(retrieve_queued_messages_sql, (channel, now))
message = await cursor.fetchone()

return typing.cast(typing.Optional[tuple[bytes]], message)

msg = await Message.objects.using(self.using).filter(channel=channel, expire__gt=now).afirst()
if msg:
result = (msg.message,)
await Message.objects.using(self.using).filter(id=msg.id).adelete()
return result
return None
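
Note that the afirst/adelete pair above does not lock the selected row, so two consumers could briefly pop the same message; the removed SQL avoided this with FOR UPDATE SKIP LOCKED. The same semantics are available in the ORM via select_for_update, sketched here with a hypothetical sync helper (select_for_update requires an open transaction):

```python
import typing

from asgiref.sync import sync_to_async
from django.db import transaction

def _pop_one_message(using: str, channel: str) -> typing.Optional[tuple[bytes]]:
    with transaction.atomic(using=using):
        msg = (
            Message.objects.using(using)
            # skip_locked=True mirrors FOR UPDATE SKIP LOCKED, so concurrent
            # consumers skip rows another transaction has already claimed.
            .select_for_update(skip_locked=True)
            .filter(channel=channel, expire__gt=utc_now())
            .order_by('id')
            .first()
        )
        if msg is None:
            return None
        Message.objects.using(using).filter(id=msg.id).delete()
        return (msg.message,)

# usage from async code: message = await sync_to_async(_pop_one_message)('default', channel)
```
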
def _channel_to_constant_bigint(self, channel: str) -> int:
"""
Converts a channel name to a constant bigint.
@@ -280,35 +139,22 @@ def _channel_to_constant_bigint(self, channel: str) -> int:
return signed_bigint
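
The body of _channel_to_constant_bigint is collapsed in this diff. One common way to derive a stable signed 64-bit value from a string, shown purely as an illustrative sketch and not necessarily this module's implementation:

```python
import hashlib
import struct

def channel_to_constant_bigint(channel: str) -> int:
    # Hash the channel name, take the first 8 bytes of the digest,
    # and read them as a signed 64-bit integer so the result always
    # fits PostgreSQL's bigint range.
    digest = hashlib.md5(channel.encode('utf-8')).digest()
    (signed_bigint,) = struct.unpack('>q', digest[:8])
    return signed_bigint
```
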

async def acquire_advisory_lock(self, channel: str) -> bool:
"""Acquires an advisory lock from the database"""
advisory_lock_id = self._channel_to_constant_bigint(channel)
acquire_advisory_lock_sql = sql.SQL('SELECT pg_try_advisory_lock(%s::bigint)').format(
advisory_lock_id=advisory_lock_id
)

db_pool = await self.get_db_pool(db_params=self.db_params)
async with db_pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(acquire_advisory_lock_sql, (advisory_lock_id,))

result = await cursor.fetchone()
return result[0] if result else False

"""Acquires an advisory lock (if still needed, else can remove)"""
# keep for compatibility if using PostgreSQL advisory locks
return True # placeholder for ORM-only setup
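
If advisory locking is still needed after the ORM migration, the lock can be taken through a raw cursor instead of unconditionally returning True. A sketch with a hypothetical sync helper, assuming the id comes from _channel_to_constant_bigint:

```python
from asgiref.sync import sync_to_async
from django.db import connections

def _try_advisory_lock(using: str, lock_id: int) -> bool:
    # pg_try_advisory_lock returns immediately with true/false
    # rather than blocking like pg_advisory_lock.
    with connections[using].cursor() as cursor:
        cursor.execute('SELECT pg_try_advisory_lock(%s::bigint)', (lock_id,))
        row = cursor.fetchone()
    return bool(row and row[0])

# usage from async code: locked = await sync_to_async(_try_advisory_lock)('default', lock_id)
```
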

async def delete_message_returning_message(
self, message_id: int
) -> typing.Optional[tuple[bytes]]:
"""Deletes a message from the database and returns the message"""
delete_message_returning_message_sql = sql.SQL(
'DELETE FROM {table} WHERE id=%s RETURNING message'
).format(table=sql.Identifier(MESSAGE_TABLE))

db_pool = await self.get_db_pool(db_params=self.db_params)
async with db_pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(delete_message_returning_message_sql, (message_id,))

return await cursor.fetchone()

msg = await Message.objects.using(self.using).filter(id=message_id).afirst()
if msg:
result = (msg.message,)
await Message.objects.using(self.using).filter(id=msg.id).adelete()
return result
return None

async def delete_channel_group(self, group_key: str, channel: str) -> None:
"""Deletes a channel from a group"""
await (
@@ -321,11 +167,5 @@ async def flush(self) -> None:
"""
Flushes the channel layer by truncating the message and group tables
"""
db_pool = await self.get_db_pool(db_params=self.db_params)
async with db_pool.connection() as conn:
await conn.execute(
sql.SQL('TRUNCATE TABLE {table}').format(table=sql.Identifier(MESSAGE_TABLE))
)
await conn.execute(
sql.SQL('TRUNCATE TABLE {table}').format(table=sql.Identifier(GROUP_CHANNEL_TABLE))
)
await Message.objects.using(self.using).adelete()
Owner comment: Will also keep the handcrafted query and execute both TRUNCATE calls in one SQL statement.

This should save one round trip to the server.

await GroupChannel.objects.using(self.using).adelete()
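
The combined statement the comment above refers to could look like the following sketch (a hypothetical helper, not the module's final code). PostgreSQL accepts several tables in one TRUNCATE, which is what saves the extra round trip:

```python
from asgiref.sync import sync_to_async
from django.db import connections

def _flush_tables(using: str) -> None:
    # One TRUNCATE over both tables: a single statement, a single round trip.
    message_table = Message._meta.db_table
    group_table = GroupChannel._meta.db_table
    with connections[using].cursor() as cursor:
        cursor.execute(f'TRUNCATE TABLE "{message_table}", "{group_table}"')

# usage from async code: await sync_to_async(_flush_tables)('default')
```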