|
1 | 1 | import logging |
| 2 | +from typing import Callable, Iterable, Optional |
2 | 3 |
|
3 | 4 | import psycopg |
4 | 5 |
|
| 6 | +from dispatcher.brokers.base import BaseBroker |
| 7 | +from dispatcher.utils import resolve_callable |
| 8 | + |
5 | 9 | logger = logging.getLogger(__name__) |
6 | 10 |
|
7 | 11 |
|
|
13 | 17 | """ |
14 | 18 |
|
15 | 19 |
|
16 | | -# TODO: get database data from settings |
17 | | -# # As Django settings, may not use |
18 | | -# DATABASES = { |
19 | | -# "default": { |
20 | | -# "ENGINE": "django.db.backends.postgresql", |
21 | | -# "HOST": os.getenv("DB_HOST", "127.0.0.1"), |
22 | | -# "PORT": os.getenv("DB_PORT", 55777), |
23 | | -# "USER": os.getenv("DB_USER", "dispatch"), |
24 | | -# "PASSWORD": os.getenv("DB_PASSWORD", "dispatching"), |
25 | | -# "NAME": os.getenv("DB_NAME", "dispatch_db"), |
26 | | -# } |
27 | | -# } |
28 | | - |
29 | | - |
30 | | -async def aget_connection(config): |
31 | | - return await psycopg.AsyncConnection.connect(**config, autocommit=True) |
32 | | - |
33 | | - |
34 | | -def get_connection(config): |
35 | | - return psycopg.Connection.connect(**config, autocommit=True) |
36 | | - |
37 | | - |
38 | | -async def aprocess_notify(connection, channels, connected_callback=None): |
39 | | - async with connection.cursor() as cur: |
40 | | - for channel in channels: |
41 | | - await cur.execute(f"LISTEN {channel};") |
42 | | - logger.info(f"Set up pg_notify listening on channel '{channel}'") |
43 | | - |
44 | | - if connected_callback: |
45 | | - await connected_callback() |
46 | | - |
47 | | - while True: |
48 | | - logger.debug('Starting listening for pg_notify notifications') |
49 | | - async for notify in connection.notifies(): |
50 | | - yield notify.channel, notify.payload |
51 | | - |
52 | | - |
53 | | -async def apublish_message(connection, channel, payload=None): |
54 | | - async with connection.cursor() as cur: |
55 | | - if not payload: |
56 | | - await cur.execute(f'NOTIFY {channel};') |
| 20 | +class PGNotifyBase(BaseBroker): |
| 21 | + |
| 22 | + def __init__( |
| 23 | + self, |
| 24 | + config: Optional[dict] = None, |
| 25 | + channels: Iterable[str] = ('dispatcher_default',), |
| 26 | + default_publish_channel: Optional[str] = None, |
| 27 | + ) -> None: |
| 28 | + """ |
| 29 | + channels - listening channels for the service and used for control-and-reply |
| 30 | + default_publish_channel - if not specified on task level or in the submission |
| 31 | + by default messages will be sent to this channel. |
| 32 | + this should be one of the listening channels for messages to be received. |
| 33 | + """ |
| 34 | + if config: |
| 35 | + self._config: dict = config.copy() |
| 36 | + self._config['autocommit'] = True |
| 37 | + else: |
| 38 | + self._config = {} |
| 39 | + |
| 40 | + self.channels = channels |
| 41 | + self.default_publish_channel = default_publish_channel |
| 42 | + |
| 43 | + def get_publish_channel(self, channel: Optional[str] = None): |
| 44 | + "Handle default for the publishing channel for calls to publish_message, shared sync and async" |
| 45 | + if channel is not None: |
| 46 | + return channel |
| 47 | + if self.default_publish_channel is None: |
| 48 | + raise ValueError('Could not determine a channel to use publish to from settings or PGNotify config') |
| 49 | + return self.default_publish_channel |
| 50 | + |
| 51 | + def get_connection_method(self, factory_path: Optional[str] = None) -> Callable: |
| 52 | + "Handles settings, returns a method (async or sync) for getting a new connection" |
| 53 | + if factory_path: |
| 54 | + factory = resolve_callable(factory_path) |
| 55 | + if not factory: |
| 56 | + raise RuntimeError(f'Could not import connection factory {factory_path}') |
| 57 | + return factory |
| 58 | + elif self._config: |
| 59 | + return self.create_connection |
57 | 60 | else: |
58 | | - await cur.execute(f"NOTIFY {channel}, '{payload}';") |
59 | | - |
60 | | - |
61 | | -def get_django_connection(): |
62 | | - try: |
63 | | - from django.conf import ImproperlyConfigured |
64 | | - from django.db import connection as pg_connection |
65 | | - except ImportError: |
66 | | - return None |
67 | | - else: |
68 | | - try: |
69 | | - if pg_connection.connection is None: |
70 | | - pg_connection.connect() |
71 | | - if pg_connection.connection is None: |
72 | | - raise RuntimeError('Unexpectedly could not connect to postgres for pg_notify actions') |
73 | | - return pg_connection.connection |
74 | | - except ImproperlyConfigured: |
75 | | - return None |
76 | | - |
77 | | - |
78 | | -def publish_message(queue, message, config=None, connection=None, new_connection=False): |
79 | | - conn = None |
80 | | - if connection: |
81 | | - conn = connection |
82 | | - |
83 | | - if (not conn) and (not new_connection): |
84 | | - conn = get_django_connection() |
85 | | - |
86 | | - created_new_conn = False |
87 | | - if not conn: |
88 | | - if config is None: |
89 | | - raise RuntimeError('Could not use Django connection, and no postgres config supplied') |
90 | | - conn = get_connection(config) |
91 | | - created_new_conn = True |
92 | | - |
93 | | - with conn.cursor() as cur: |
94 | | - cur.execute('SELECT pg_notify(%s, %s);', (queue, message)) |
95 | | - |
96 | | - logger.debug(f'Sent pg_notify message to {queue}') |
97 | | - |
98 | | - if created_new_conn: |
99 | | - conn.close() |
| 61 | + raise RuntimeError('Could not construct connection for lack of config or factory') |
| 62 | + |
| 63 | + def create_connection(self): ... |
| 64 | + |
| 65 | + |
| 66 | +class AsyncBroker(PGNotifyBase): |
| 67 | + def __init__( |
| 68 | + self, |
| 69 | + config: Optional[dict] = None, |
| 70 | + async_connection_factory: Optional[str] = None, |
| 71 | + sync_connection_factory: Optional[str] = None, # noqa |
| 72 | + connection: Optional[psycopg.AsyncConnection] = None, |
| 73 | + **kwargs, |
| 74 | + ) -> None: |
| 75 | + if not (config or async_connection_factory or connection): |
| 76 | + raise RuntimeError('Must specify either config or async_connection_factory') |
| 77 | + |
| 78 | + self._async_connection_factory = async_connection_factory |
| 79 | + self._connection = connection |
| 80 | + |
| 81 | + super().__init__(config=config, **kwargs) |
| 82 | + |
| 83 | + async def get_connection(self) -> psycopg.AsyncConnection: |
| 84 | + if not self._connection: |
| 85 | + factory = self.get_connection_method(factory_path=self._async_connection_factory) |
| 86 | + connection = await factory(**self._config) |
| 87 | + self._connection = connection |
| 88 | + return connection # slightly weird due to MyPY |
| 89 | + return self._connection |
| 90 | + |
| 91 | + @staticmethod |
| 92 | + async def create_connection(**config) -> psycopg.AsyncConnection: |
| 93 | + return await psycopg.AsyncConnection.connect(**config) |
| 94 | + |
| 95 | + async def aprocess_notify(self, connected_callback=None): |
| 96 | + connection = await self.get_connection() |
| 97 | + async with connection.cursor() as cur: |
| 98 | + for channel in self.channels: |
| 99 | + await cur.execute(f"LISTEN {channel};") |
| 100 | + logger.info(f"Set up pg_notify listening on channel '{channel}'") |
| 101 | + |
| 102 | + if connected_callback: |
| 103 | + await connected_callback() |
| 104 | + |
| 105 | + while True: |
| 106 | + logger.debug('Starting listening for pg_notify notifications') |
| 107 | + async for notify in connection.notifies(): |
| 108 | + yield notify.channel, notify.payload |
| 109 | + |
| 110 | + async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: |
| 111 | + connection = await self.get_connection() |
| 112 | + channel = self.get_publish_channel(channel) |
| 113 | + |
| 114 | + async with connection.cursor() as cur: |
| 115 | + if not message: |
| 116 | + await cur.execute(f'NOTIFY {channel};') |
| 117 | + else: |
| 118 | + await cur.execute(f"NOTIFY {channel}, '{message}';") |
| 119 | + |
| 120 | + logger.debug(f'Sent pg_notify message of {len(message)} chars to {channel}') |
| 121 | + |
| 122 | + async def aclose(self) -> None: |
| 123 | + if self._connection: |
| 124 | + await self._connection.close() |
| 125 | + self._connection = None |
| 126 | + |
| 127 | + |
| 128 | +class SyncBroker(PGNotifyBase): |
| 129 | + def __init__( |
| 130 | + self, |
| 131 | + config: Optional[dict] = None, |
| 132 | + async_connection_factory: Optional[str] = None, # noqa |
| 133 | + sync_connection_factory: Optional[str] = None, |
| 134 | + connection: Optional[psycopg.Connection] = None, |
| 135 | + **kwargs, |
| 136 | + ) -> None: |
| 137 | + if not (config or sync_connection_factory or connection): |
| 138 | + raise RuntimeError('Must specify either config or async_connection_factory') |
| 139 | + |
| 140 | + self._sync_connection_factory = sync_connection_factory |
| 141 | + self._connection = connection |
| 142 | + super().__init__(config=config, **kwargs) |
| 143 | + |
| 144 | + def get_connection(self) -> psycopg.Connection: |
| 145 | + if not self._connection: |
| 146 | + factory = self.get_connection_method(factory_path=self._sync_connection_factory) |
| 147 | + connection = factory(**self._config) |
| 148 | + self._connection = connection |
| 149 | + return connection |
| 150 | + return self._connection |
| 151 | + |
| 152 | + @staticmethod |
| 153 | + def create_connection(**config) -> psycopg.Connection: |
| 154 | + return psycopg.Connection.connect(**config) |
| 155 | + |
| 156 | + def publish_message(self, channel: Optional[str] = None, message: str = '') -> None: |
| 157 | + connection = self.get_connection() |
| 158 | + channel = self.get_publish_channel(channel) |
| 159 | + |
| 160 | + with connection.cursor() as cur: |
| 161 | + if message: |
| 162 | + cur.execute('SELECT pg_notify(%s, %s);', (channel, message)) |
| 163 | + else: |
| 164 | + cur.execute(f'NOTIFY {channel};') |
| 165 | + |
| 166 | + logger.debug(f'Sent pg_notify message of {len(message)} chars to {channel}') |
| 167 | + |
| 168 | + def close(self) -> None: |
| 169 | + if self._connection: |
| 170 | + self._connection.close() |
| 171 | + self._connection = None |
| 172 | + |
| 173 | + |
| 174 | +class ConnectionSaver: |
| 175 | + def __init__(self) -> None: |
| 176 | + self._connection: Optional[psycopg.Connection] = None |
| 177 | + self._async_connection: Optional[psycopg.AsyncConnection] = None |
| 178 | + |
| 179 | + |
| 180 | +connection_save = ConnectionSaver() |
| 181 | + |
| 182 | + |
| 183 | +def connection_saver(**config) -> psycopg.Connection: |
| 184 | + """ |
| 185 | + This mimics the behavior of Django for tests and demos |
| 186 | + Philosophically, this is used by an application that uses an ORM, |
| 187 | + or otherwise has its own connection management logic. |
| 188 | + Dispatcher does not manage connections, so this a simulation of that. |
| 189 | + """ |
| 190 | + if connection_save._connection is None: |
| 191 | + config['autocommit'] = True |
| 192 | + connection_save._connection = SyncBroker.create_connection(**config) |
| 193 | + return connection_save._connection |
| 194 | + |
| 195 | + |
| 196 | +async def async_connection_saver(**config) -> psycopg.AsyncConnection: |
| 197 | + """ |
| 198 | + This mimics the behavior of Django for tests and demos |
| 199 | + Philosophically, this is used by an application that uses an ORM, |
| 200 | + or otherwise has its own connection management logic. |
| 201 | + Dispatcher does not manage connections, so this a simulation of that. |
| 202 | + """ |
| 203 | + if connection_save._async_connection is None: |
| 204 | + config['autocommit'] = True |
| 205 | + connection_save._async_connection = await AsyncBroker.create_connection(**config) |
| 206 | + return connection_save._async_connection |
0 commit comments