Refactor database layer: replace psycopg3 queries with Django ORM asy… #94
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Status: Open. Demagalawrence wants to merge 1 commit into danidee10:main from Demagalawrence:replace-psycopg3-with-orm.
```diff
@@ -7,10 +7,7 @@
 import typing
 from datetime import datetime, timedelta

-import psycopg
-import psycopg_pool
-import psycopg_pool.base
-from psycopg import sql

 from .models import GroupChannel, Message
```
```diff
@@ -33,11 +30,8 @@
 # regardless of the amount of threads
 # And also to prevent RuntimeErrors when the event loop is closed after running tests
 # and psycopg Async workers are not cleaned up properly
-is_creating_connection_pool = asyncio.Lock()
-connection_pool: typing.Optional[psycopg_pool.AsyncConnectionPool] = None

 MESSAGE_TABLE = Message._meta.db_table
 GROUP_CHANNEL_TABLE = GroupChannel._meta.db_table


 def utc_now() -> datetime:
```
```diff
@@ -54,11 +48,9 @@ def utc_now() -> datetime:
 class DatabaseLayer:
     """
-    Encapsulates database operations
-
-    A connection pool is used for efficient management of database operations
-    This is also the reason why psycopg is used directly instead of django's ORM
-    which doesn't support connection pooling
+    Encapsulates database operations using Django ORM for async operations.
+    Replaces handcrafted psycopg3 queries for simplicity, maintainability,
+    and native Django 5.1 connection pooling support.
     """

     def __init__(
```
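For reference, the native pooling that the new docstring leans on is configured per-database in Django 5.1+ settings. A minimal sketch, assuming PostgreSQL with psycopg 3; the database name and pool sizes below are illustrative and do not come from this PR:

```python
# settings.py (sketch) -- requires Django 5.1+ and psycopg[pool]
DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.postgresql',
        'NAME': 'channels_postgres',  # illustrative database name
        'OPTIONS': {
            # True enables pooling with defaults; a dict is forwarded to
            # psycopg_pool.ConnectionPool (min_size/max_size shown as examples)
            'pool': {'min_size': 2, 'max_size': 10},
        },
    }
}
```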
```diff
@@ -73,138 +65,36 @@ def __init__(
         self.db_params = db_params
         self.psycopg_options = psycopg_options

-    async def get_db_pool(
-        self, db_params: dict[str, typing.Any]
-    ) -> psycopg_pool.AsyncConnectionPool:
-        """
-        Returns a connection pool for the database
-
-        Uses a `Lock` to ensure that only one coroutine can create the connection pool
-        Others have to wait until the connection pool is created
-        """
-        global connection_pool  # pylint: disable=W0603
-
-        async def _configure_connection(conn: psycopg.AsyncConnection) -> None:
-            await conn.set_autocommit(True)
-            conn.prepare_threshold = 0  # All statements should be prepared
-            conn.prepared_max = None  # No limit on the number of prepared statements
-
-        async with is_creating_connection_pool:
-            if connection_pool is not None:
-                self.logger.debug('Pool %s already exists', connection_pool.name)
-
-                pool_stats = connection_pool.get_stats()
-                self.logger.debug('Pool stats: %s', pool_stats)
-
-                return connection_pool
-
-            conn_info = psycopg.conninfo.make_conninfo(conninfo='', **db_params)
-
-            connection_pool = psycopg_pool.AsyncConnectionPool(
-                conninfo=conn_info,
-                open=False,
-                configure=_configure_connection,
-                **self.psycopg_options,
-            )
-            await connection_pool.open(wait=True)
-
-            self.logger.debug('Pool %s created', connection_pool.name)
-
-            return connection_pool
-
     async def retrieve_group_channels(self, group_key: str) -> list[str]:
         """Retrieves all channels for a group"""
-        retrieve_channels_sql = sql.SQL(
-            'SELECT DISTINCT group_key,channel FROM {table} WHERE group_key=%s'
-        ).format(table=sql.Identifier(GROUP_CHANNEL_TABLE))
-
-        db_pool = await self.get_db_pool(db_params=self.db_params)
-        async with db_pool.connection() as conn:
-            async with conn.cursor() as cursor:
-                await cursor.execute(retrieve_channels_sql, (group_key,))
-                result = await cursor.fetchall()
-                return [row[1] for row in result]
-
-    async def send_to_channel(
-        self,
-        group_key: str,
-        message: bytes,
-        expire: int,
-        channel: typing.Optional[str] = None,
-    ) -> None:
-        """Send a message to a channel/channels (if no channel is specified)."""
-        message_add_sql = sql.SQL(
-            'INSERT INTO {table} (channel, message, expire) VALUES (%s, %s, %s)'
-        ).format(table=sql.Identifier(MESSAGE_TABLE))
-
-        if channel is None:
-            channels = await self.retrieve_group_channels(group_key)
-            if not channels:
-                self.logger.warning('Group: %s does not exist, did you call group_add?', group_key)
-                return
-        else:
-            channels = [channel]
-
-        expiry_datetime = utc_now() + timedelta(seconds=expire)
-        db_pool = await self.get_db_pool(db_params=self.db_params)
-        async with db_pool.connection() as conn:
-            async with conn.cursor() as cursor:
-                if len(channels) == 1:
-                    # single insert
-                    data = (channels[0], message, expiry_datetime)
-                    await cursor.execute(message_add_sql, data)
-                else:
-                    # Bulk insert messages
-                    multi_data = [(channel, message, expiry_datetime) for channel in channels]
-                    await cursor.executemany(message_add_sql, multi_data)
+        qs = GroupChannel.objects.using(self.using).filter(group_key=group_key)
+        return list(await qs.values_list('channel', flat=True))

     async def add_channel_to_group(self, group_key: str, channel: str, expire: int) -> None:
         """Adds a channel to a group"""
         expiry_datetime = utc_now() + timedelta(seconds=expire)
-        group_add_sql = sql.SQL(
-            'INSERT INTO {table} (group_key, channel, expire) VALUES (%s, %s, %s)'
-        ).format(table=sql.Identifier(GROUP_CHANNEL_TABLE))
-
-        db_pool = await self.get_db_pool(db_params=self.db_params)
-        async with db_pool.connection() as conn:
-            async with conn.cursor() as cursor:
-                data = (group_key, channel, expiry_datetime)
-                await cursor.execute(group_add_sql, data)
+        await GroupChannel.objects.using(self.using).acreate(
+            group_key=group_key, channel=channel, expire=expiry_datetime
+        )

         self.logger.debug('Channel %s added to Group %s', channel, group_key)

     async def delete_expired_groups(self) -> None:
         """Deletes expired groups after a random delay"""
-        delete_expired_groups_sql = sql.SQL('DELETE FROM {table} WHERE expire < %s').format(
-            table=sql.Identifier(GROUP_CHANNEL_TABLE)
-        )
-
-        expire = 60 * random.randint(10, 20)
-        self.logger.debug('Deleting expired groups in %s seconds...', expire)
-        await asyncio.sleep(expire)
-
+        delay = 60 * random.randint(10, 20)
+        self.logger.debug('Deleting expired groups in %s seconds...', delay)
+        await asyncio.sleep(delay)
         now = utc_now()
-        db_pool = await self.get_db_pool(db_params=self.db_params)
-        async with db_pool.connection() as conn:
-            async with conn.cursor() as cursor:
-                await cursor.execute(delete_expired_groups_sql, (now,))
+        await GroupChannel.objects.using(self.using).filter(expire__lt=now).adelete()

     async def delete_expired_messages(self, expire: typing.Optional[int] = None) -> None:
         """Deletes expired messages after a set time or random delay"""
-        delete_expired_messages_sql = sql.SQL('DELETE FROM {table} WHERE expire < %s').format(
-            table=sql.Identifier(MESSAGE_TABLE)
-        )
-
         if expire is None:
             expire = 60 * random.randint(10, 20)
         self.logger.debug('Deleting expired messages in %s seconds...', expire)
         await asyncio.sleep(expire)

         now = utc_now()
-        db_pool = await self.get_db_pool(db_params=self.db_params)
-        async with db_pool.connection() as conn:
-            async with conn.cursor() as cursor:
-                await cursor.execute(delete_expired_messages_sql, (now,))
+        await Message.objects.using(self.using).filter(expire__lt=now).adelete()

     async def retrieve_non_expired_queued_messages(self) -> list[tuple[str, str, bytes, str]]:
         """
```
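Note that this hunk removes `send_to_channel`, including its `executemany` bulk path, without a visible ORM replacement. If an equivalent is added elsewhere, one hedged option is `abulk_create`, which issues a single multi-row INSERT; the snippet below is a sketch of that idea, not code from this PR:

```python
# Sketch of a possible ORM bulk path for the removed send_to_channel
# (hypothetical; assumes the same Message model fields and self.using alias).
expiry_datetime = utc_now() + timedelta(seconds=expire)
await Message.objects.using(self.using).abulk_create(
    # One Message row per target channel, mirroring the old executemany()
    [Message(channel=ch, message=message, expire=expiry_datetime) for ch in channels]
)
```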
```diff
@@ -214,55 +104,24 @@ async def retrieve_non_expired_queued_messages(self) -> list[tuple[str, str, bytes, str]]:
         queries. Even if the inner query is ordered, the returning
         clause is not guaranteed to be ordered
         """
-        retrieve_queued_messages_sql = sql.SQL(
-            """
-            DELETE FROM {table}
-            WHERE id IN (
-                SELECT id
-                FROM {table}
-                WHERE expire > %s
-                FOR UPDATE SKIP LOCKED
-            )
-            RETURNING id::text, channel, message, extract(epoch from expire)::text
-            """
-        ).format(table=sql.Identifier(MESSAGE_TABLE))
-
         now = utc_now()
-        db_pool = await self.get_db_pool(db_params=self.db_params)
-        async with db_pool.connection() as conn:
-            async with conn.cursor() as cursor:
-                await cursor.execute(retrieve_queued_messages_sql, (now,))
-
-                return await cursor.fetchall()
+        msgs = await Message.objects.using(self.using).filter(expire__gt=now).all()
+        result = [(str(m.id), m.channel, m.message, str(m.expire.timestamp())) for m in msgs]
+        await Message.objects.using(self.using).filter(id__in=[m.id for m in msgs]).adelete()
+        return result

     async def retrieve_non_expired_queued_message_from_channel(
         self, channel: str
     ) -> typing.Optional[tuple[bytes]]:
         """Retrieves a non-expired message from a channel"""
-        retrieve_queued_messages_sql = sql.SQL(
-            """
-            DELETE FROM {table}
-            WHERE id = (
-                SELECT id
-                FROM {table}
-                WHERE channel=%s AND expire > %s
-                ORDER BY id
-                FOR UPDATE SKIP LOCKED
-                LIMIT 1
-            )
-            RETURNING message
-            """
-        ).format(table=sql.Identifier(MESSAGE_TABLE))
-
         now = utc_now()
-        db_pool = await self.get_db_pool(db_params=self.db_params)
-        async with db_pool.connection() as conn:
-            async with conn.cursor() as cursor:
-                await cursor.execute(retrieve_queued_messages_sql, (channel, now))
-                message = await cursor.fetchone()
-
-                return typing.cast(typing.Optional[tuple[bytes]], message)
+        msg = await Message.objects.using(self.using).filter(channel=channel, expire__gt=now).afirst()
+        if msg:
+            result = (msg.message,)
+            await Message.objects.using(self.using).filter(id=msg.id).adelete()
+            return result
+        return None

     def _channel_to_constant_bigint(self, channel: str) -> int:
         """
         Converts a channel name to a constant bigint.
```
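Worth flagging: the old SQL used `FOR UPDATE SKIP LOCKED` inside a single `DELETE ... RETURNING`, so concurrent workers never received the same message twice. The ORM replacement selects and then deletes in two separate statements with no lock, which can double-deliver under concurrency. A hedged sketch of keeping the lock with the ORM follows; the helper name and the `sync_to_async` bridging are illustrative, and the critical section runs in a thread because transactions are not available on the async ORM API:

```python
from asgiref.sync import sync_to_async
from django.db import transaction


def _pop_non_expired_messages(using: str) -> list[tuple[str, str, bytes, str]]:
    # Lock matching rows (skipping ones other workers already hold), then
    # delete them inside the same transaction -- mirrors FOR UPDATE SKIP LOCKED.
    with transaction.atomic(using=using):
        msgs = list(
            Message.objects.using(using)
            .select_for_update(skip_locked=True)
            .filter(expire__gt=utc_now())
        )
        Message.objects.using(using).filter(id__in=[m.id for m in msgs]).delete()
    return [(str(m.id), m.channel, m.message, str(m.expire.timestamp())) for m in msgs]


# Usage from the async layer:
#     result = await sync_to_async(_pop_non_expired_messages)(self.using)
```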
```diff
@@ -280,35 +139,22 @@ def _channel_to_constant_bigint(self, channel: str) -> int:
         return signed_bigint

     async def acquire_advisory_lock(self, channel: str) -> bool:
-        """Acquires an advisory lock from the database"""
-        advisory_lock_id = self._channel_to_constant_bigint(channel)
-        acquire_advisory_lock_sql = sql.SQL('SELECT pg_try_advisory_lock(%s::bigint)').format(
-            advisory_lock_id=advisory_lock_id
-        )
-
-        db_pool = await self.get_db_pool(db_params=self.db_params)
-        async with db_pool.connection() as conn:
-            async with conn.cursor() as cursor:
-                await cursor.execute(acquire_advisory_lock_sql, (advisory_lock_id,))
-
-                result = await cursor.fetchone()
-                return result[0] if result else False
+        """Acquires an advisory lock (if still needed, else can remove)"""
+        # keep for compatibility if using PostgreSQL advisory locks
+        return True  # placeholder for ORM-only setup

     async def delete_message_returning_message(
         self, message_id: int
     ) -> typing.Optional[tuple[bytes]]:
         """Deletes a message from the database and returns the message"""
-        delete_message_returning_message_sql = sql.SQL(
-            'DELETE FROM {table} WHERE id=%s RETURNING message'
-        ).format(table=sql.Identifier(MESSAGE_TABLE))
-
-        db_pool = await self.get_db_pool(db_params=self.db_params)
-        async with db_pool.connection() as conn:
-            async with conn.cursor() as cursor:
-                await cursor.execute(delete_message_returning_message_sql, (message_id,))
-
-                return await cursor.fetchone()
+        msg = await Message.objects.using(self.using).filter(id=message_id).afirst()
+        if msg:
+            result = (msg.message,)
+            await Message.objects.using(self.using).filter(id=msg.id).adelete()
+            return result
+        return None

     async def delete_channel_group(self, group_key: str, channel: str) -> None:
         """Deletes a channel from a group"""
         await (
```
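The new `acquire_advisory_lock` placeholder always returns `True`, which silently changes behavior for any caller relying on mutual exclusion. If advisory locks are still wanted, one option is keeping `pg_try_advisory_lock` behind Django's connection handling. The sketch below is an assumption, not code from this PR: the helper name and `sync_to_async` usage are illustrative, it is PostgreSQL-only, and session-level advisory locks interact with pooled connections, so it needs care:

```python
from asgiref.sync import sync_to_async
from django.db import connections


def _try_advisory_lock(using: str, lock_id: int) -> bool:
    # pg_try_advisory_lock returns immediately with true/false
    # instead of blocking until the lock becomes free.
    with connections[using].cursor() as cursor:
        cursor.execute('SELECT pg_try_advisory_lock(%s::bigint)', (lock_id,))
        row = cursor.fetchone()
    return bool(row and row[0])


# Usage: acquired = await sync_to_async(_try_advisory_lock)(self.using, lock_id)
```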
```diff
@@ -321,11 +167,5 @@ async def flush(self) -> None:
         """
         Flushes the channel layer by truncating the message and group tables
         """
-        db_pool = await self.get_db_pool(db_params=self.db_params)
-        async with db_pool.connection() as conn:
-            await conn.execute(
-                sql.SQL('TRUNCATE TABLE {table}').format(table=sql.Identifier(MESSAGE_TABLE))
-            )
-            await conn.execute(
-                sql.SQL('TRUNCATE TABLE {table}').format(table=sql.Identifier(GROUP_CHANNEL_TABLE))
-            )
+        await Message.objects.using(self.using).adelete()
+        await GroupChannel.objects.using(self.using).adelete()
```

Owner (review comment on `flush`): Will also keep the handcrafted query and execute both TRUNCATE calls in one SQL statement. This should save one roundtrip to the server.
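A sketch of that suggestion, since PostgreSQL's `TRUNCATE` accepts multiple tables in a single statement; the helper name and the sync bridging are illustrative, not from this PR:

```python
from asgiref.sync import sync_to_async
from django.db import connections


def _flush_tables(using: str) -> None:
    # Both tables in one TRUNCATE -- a single roundtrip, as suggested above.
    # Table names come from the models and are quoted by the backend.
    conn = connections[using]
    message_table = conn.ops.quote_name(Message._meta.db_table)
    group_table = conn.ops.quote_name(GroupChannel._meta.db_table)
    with conn.cursor() as cursor:
        cursor.execute(f'TRUNCATE TABLE {message_table}, {group_table}')


# Usage: await sync_to_async(_flush_tables)(self.using)
```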
Owner (review comment): I will make an exception here and keep the handcrafted query because it's more efficient. The query uses `DELETE ... RETURNING`, which deletes and returns the deleted result in one query. The new query with the Django ORM makes two queries (`SELECT` and `DELETE`).
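For illustration, that single-statement version could be kept behind Django's connection rather than a separate psycopg pool. This sketch is an assumption about how that might look, not code from the PR; the helper name and `sync_to_async` bridging are hypothetical:

```python
import typing

from asgiref.sync import sync_to_async
from django.db import connections


def _delete_message_returning(using: str, message_id: int) -> typing.Optional[tuple[bytes]]:
    # DELETE ... RETURNING removes the row and hands back its payload in one
    # statement (PostgreSQL-specific), avoiding the SELECT + DELETE pair.
    conn = connections[using]
    table = conn.ops.quote_name(Message._meta.db_table)
    with conn.cursor() as cursor:
        cursor.execute(f'DELETE FROM {table} WHERE id = %s RETURNING message', [message_id])
        return cursor.fetchone()


# Usage: row = await sync_to_async(_delete_message_returning)(self.using, message_id)
```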