|
1 | 1 | import logging |
| 2 | +from typing import Any, Iterable, Optional |
2 | 3 |
|
3 | 4 | import psycopg |
4 | 5 |
|
| 6 | +from dispatcher.brokers.base import BaseBroker |
| 7 | +from dispatcher.utils import resolve_callable |
| 8 | + |
5 | 9 | logger = logging.getLogger(__name__) |
6 | 10 |
|
7 | 11 |
|
|
13 | 17 | """ |
14 | 18 |
|
15 | 19 |
|
16 | | -# TODO: get database data from settings |
17 | | -# # As Django settings, may not use |
18 | | -# DATABASES = { |
19 | | -# "default": { |
20 | | -# "ENGINE": "django.db.backends.postgresql", |
21 | | -# "HOST": os.getenv("DB_HOST", "127.0.0.1"), |
22 | | -# "PORT": os.getenv("DB_PORT", 55777), |
23 | | -# "USER": os.getenv("DB_USER", "dispatch"), |
24 | | -# "PASSWORD": os.getenv("DB_PASSWORD", "dispatching"), |
25 | | -# "NAME": os.getenv("DB_NAME", "dispatch_db"), |
26 | | -# } |
27 | | -# } |
28 | | - |
29 | | - |
30 | | -async def aget_connection(config): |
31 | | - return await psycopg.AsyncConnection.connect(**config, autocommit=True) |
32 | | - |
33 | | - |
34 | | -def get_connection(config): |
35 | | - return psycopg.Connection.connect(**config, autocommit=True) |
36 | | - |
37 | | - |
38 | | -async def aprocess_notify(connection, channels, connected_callback=None): |
39 | | - async with connection.cursor() as cur: |
40 | | - for channel in channels: |
41 | | - await cur.execute(f"LISTEN {channel};") |
42 | | - logger.info(f"Set up pg_notify listening on channel '{channel}'") |
43 | | - |
44 | | - if connected_callback: |
45 | | - await connected_callback() |
46 | | - |
47 | | - while True: |
48 | | - logger.debug('Starting listening for pg_notify notifications') |
49 | | - async for notify in connection.notifies(): |
50 | | - yield notify.channel, notify.payload |
51 | | - |
52 | | - |
53 | | -async def apublish_message(connection, channel, payload=None): |
54 | | - async with connection.cursor() as cur: |
55 | | - if not payload: |
56 | | - await cur.execute(f'NOTIFY {channel};') |
| 20 | +class PGNotifyBase(BaseBroker): |
| 21 | + |
| 22 | + def __init__( |
| 23 | + self, |
| 24 | + channels: Iterable[str] = ('dispatcher_default',), |
| 25 | + default_publish_channel: str = 'dispatcher_default', |
| 26 | + ) -> None: |
| 27 | + self.channels = channels |
| 28 | + self.default_publish_channel = default_publish_channel |
| 29 | + |
| 30 | + |
| 31 | +class AsyncBroker(PGNotifyBase): |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + config: Optional[dict] = None, |
| 35 | + async_connection_factory: Optional[str] = None, |
| 36 | + sync_connection_factory: Optional[str] = None, # noqa |
| 37 | + connection: Optional[psycopg.AsyncConnection] = None, |
| 38 | + **kwargs, |
| 39 | + ) -> None: |
| 40 | + if not (config or async_connection_factory or connection): |
| 41 | + raise RuntimeError('Must specify either config or async_connection_factory') |
| 42 | + |
| 43 | + if config: |
| 44 | + self._config: Optional[dict] = config.copy() |
| 45 | + self._config['autocommit'] = True |
| 46 | + else: |
| 47 | + self._config = None |
| 48 | + |
| 49 | + self._async_connection_factory = async_connection_factory |
| 50 | + self._connection: Optional[Any] = connection |
| 51 | + |
| 52 | + super().__init__(**kwargs) |
| 53 | + |
| 54 | + async def get_connection(self) -> psycopg.AsyncConnection: |
| 55 | + if not self._connection: |
| 56 | + if self._async_connection_factory: |
| 57 | + factory = resolve_callable(self._async_connection_factory) |
| 58 | + if not factory: |
| 59 | + raise RuntimeError(f'Could not import connection factory {self._async_connection_factory}') |
| 60 | + if self._config: |
| 61 | + self._connection = await factory(**self._config) |
| 62 | + else: |
| 63 | + self._connection = await factory() |
| 64 | + elif self._config: |
| 65 | + self._connection = await AsyncBroker.create_connection(self._config) |
| 66 | + else: |
| 67 | + raise RuntimeError('Could not construct async connection for lack of config or factory') |
| 68 | + return self._connection |
| 69 | + |
| 70 | + @staticmethod |
| 71 | + async def create_connection(config) -> psycopg.AsyncConnection: |
| 72 | + return await psycopg.AsyncConnection.connect(**config) |
| 73 | + |
| 74 | + async def aprocess_notify(self, connected_callback=None): |
| 75 | + connection = await self.get_connection() |
| 76 | + async with connection.cursor() as cur: |
| 77 | + for channel in self.channels: |
| 78 | + await cur.execute(f"LISTEN {channel};") |
| 79 | + logger.info(f"Set up pg_notify listening on channel '{channel}'") |
| 80 | + |
| 81 | + if connected_callback: |
| 82 | + await connected_callback() |
| 83 | + |
| 84 | + while True: |
| 85 | + logger.debug('Starting listening for pg_notify notifications') |
| 86 | + async for notify in connection.notifies(): |
| 87 | + yield notify.channel, notify.payload |
| 88 | + |
| 89 | + async def apublish_message(self, channel: Optional[str] = None, payload=None) -> None: |
| 90 | + connection = await self.get_connection() |
| 91 | + if not channel: |
| 92 | + channel = self.default_publish_channel |
| 93 | + async with connection.cursor() as cur: |
| 94 | + if not payload: |
| 95 | + await cur.execute(f'NOTIFY {channel};') |
| 96 | + else: |
| 97 | + await cur.execute(f"NOTIFY {channel}, '{payload}';") |
| 98 | + |
| 99 | + async def aclose(self) -> None: |
| 100 | + if self._connection: |
| 101 | + await self._connection.close() |
| 102 | + self._connection = None |
| 103 | + |
| 104 | + |
| 105 | +connection_save = object() |
| 106 | + |
| 107 | + |
| 108 | +def connection_saver(**config): |
| 109 | + """ |
| 110 | + This mimics the behavior of Django for tests and demos |
| 111 | + Philosophically, this is used by an application that uses an ORM, |
| 112 | + or otherwise has its own connection management logic. |
| 113 | + Dispatcher does not manage connections, so this a simulation of that. |
| 114 | + """ |
| 115 | + if not hasattr(connection_save, '_connection'): |
| 116 | + config['autocommit'] = True |
| 117 | + connection_save._connection = SyncBroker.connect(**config) |
| 118 | + return connection_save._connection |
| 119 | + |
| 120 | + |
| 121 | +class SyncBroker(PGNotifyBase): |
| 122 | + def __init__( |
| 123 | + self, |
| 124 | + config: Optional[dict] = None, |
| 125 | + async_connection_factory: Optional[str] = None, # noqa |
| 126 | + sync_connection_factory: Optional[str] = None, |
| 127 | + connection: Optional[psycopg.Connection] = None, |
| 128 | + **kwargs, |
| 129 | + ) -> None: |
| 130 | + if not (config or sync_connection_factory or connection): |
| 131 | + raise RuntimeError('Must specify either config or async_connection_factory') |
| 132 | + |
| 133 | + if config: |
| 134 | + self._config: Optional[dict] = config.copy() |
| 135 | + self._config['autocommit'] = True |
57 | 136 | else: |
58 | | - await cur.execute(f"NOTIFY {channel}, '{payload}';") |
59 | | - |
60 | | - |
61 | | -def get_django_connection(): |
62 | | - try: |
63 | | - from django.conf import ImproperlyConfigured |
64 | | - from django.db import connection as pg_connection |
65 | | - except ImportError: |
66 | | - return None |
67 | | - else: |
68 | | - try: |
69 | | - if pg_connection.connection is None: |
70 | | - pg_connection.connect() |
71 | | - if pg_connection.connection is None: |
72 | | - raise RuntimeError('Unexpectedly could not connect to postgres for pg_notify actions') |
73 | | - return pg_connection.connection |
74 | | - except ImproperlyConfigured: |
75 | | - return None |
76 | | - |
77 | | - |
78 | | -def publish_message(queue, message, config=None, connection=None, new_connection=False): |
79 | | - conn = None |
80 | | - if connection: |
81 | | - conn = connection |
82 | | - |
83 | | - if (not conn) and (not new_connection): |
84 | | - conn = get_django_connection() |
85 | | - |
86 | | - created_new_conn = False |
87 | | - if not conn: |
88 | | - if config is None: |
89 | | - raise RuntimeError('Could not use Django connection, and no postgres config supplied') |
90 | | - conn = get_connection(config) |
91 | | - created_new_conn = True |
92 | | - |
93 | | - with conn.cursor() as cur: |
94 | | - cur.execute('SELECT pg_notify(%s, %s);', (queue, message)) |
95 | | - |
96 | | - logger.debug(f'Sent pg_notify message to {queue}') |
97 | | - |
98 | | - if created_new_conn: |
99 | | - conn.close() |
| 137 | + self._config = None |
| 138 | + |
| 139 | + self._sync_connection_factory = sync_connection_factory |
| 140 | + self._connection: Optional[Any] = connection |
| 141 | + super().__init__(**kwargs) |
| 142 | + |
| 143 | + def get_connection(self) -> psycopg.Connection: |
| 144 | + if not self._connection: |
| 145 | + if self._sync_connection_factory: |
| 146 | + factory = resolve_callable(self._sync_connection_factory) |
| 147 | + if not factory: |
| 148 | + raise RuntimeError(f'Could not import connection factory {self._sync_connection_factory}') |
| 149 | + if self._config: |
| 150 | + self._connection = factory(**self._config) |
| 151 | + else: |
| 152 | + self._connection = factory() |
| 153 | + elif self._config: |
| 154 | + self._connection = SyncBroker.create_connection(self._config) |
| 155 | + else: |
| 156 | + raise RuntimeError('Cound not construct synchronous connection for lack of config or factory') |
| 157 | + return self._connection |
| 158 | + |
| 159 | + @staticmethod |
| 160 | + def create_connection(config) -> psycopg.Connection: |
| 161 | + return psycopg.Connection.connect(**config) |
| 162 | + |
| 163 | + def publish_message(self, channel: Optional[str], message: dict) -> None: |
| 164 | + connection = self.get_connection() |
| 165 | + if not channel: |
| 166 | + channel = self.default_publish_channel |
| 167 | + |
| 168 | + with connection.cursor() as cur: |
| 169 | + cur.execute('SELECT pg_notify(%s, %s);', (channel, message)) |
| 170 | + |
| 171 | + logger.debug(f'Sent pg_notify message to {channel}') |
| 172 | + |
| 173 | + def close(self) -> None: |
| 174 | + if self._connection: |
| 175 | + self._connection.close() |
| 176 | + self._connection = None |
0 commit comments