Skip to content

Commit bbaa9f7

Browse files
committed
Review comment to consolidate factory handling
1 parent edafd4d commit bbaa9f7

File tree

1 file changed

+33
-36
lines changed

1 file changed

+33
-36
lines changed

dispatcher/brokers/pg_notify.py

Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Iterable, Optional
2+
from typing import Any, Iterable, Optional, Callable
33

44
import psycopg
55

@@ -27,6 +27,26 @@ def __init__(
2727
self.channels = channels
2828
self.default_publish_channel = default_publish_channel
2929

30+
def get_publish_channel(self, channel: Optional[str] = None):
31+
"Handle default for the publishing channel for calls to publish_message, shared sync and async"
32+
if channel is not None:
33+
return channel
34+
if self.default_publish_channel is None:
35+
raise ValueError('Could not determine a channel to use publish to from settings or PGNotify config')
36+
return self.default_publish_channel
37+
38+
def get_connection_method(self, factory_path: Optional[str] = None) -> Callable:
39+
"Handles settings, returns a method (async or sync) for getting a new connection"
40+
if factory_path:
41+
factory = resolve_callable(factory_path)
42+
if not factory:
43+
raise RuntimeError(f'Could not import connection factory {factory_path}')
44+
return factory
45+
elif self._config:
46+
return self.create_connection
47+
else:
48+
raise RuntimeError('Could not construct connection for lack of config or factory')
49+
3050

3151
class AsyncBroker(PGNotifyBase):
3252
def __init__(
@@ -41,10 +61,10 @@ def __init__(
4161
raise RuntimeError('Must specify either config or async_connection_factory')
4262

4363
if config:
44-
self._config: Optional[dict] = config.copy()
64+
self._config: dict = config.copy()
4565
self._config['autocommit'] = True
4666
else:
47-
self._config = None
67+
self._config = {}
4868

4969
self._async_connection_factory = async_connection_factory
5070
self._connection: Optional[Any] = connection
@@ -53,22 +73,12 @@ def __init__(
5373

5474
async def get_connection(self) -> psycopg.AsyncConnection:
5575
if not self._connection:
56-
if self._async_connection_factory:
57-
factory = resolve_callable(self._async_connection_factory)
58-
if not factory:
59-
raise RuntimeError(f'Could not import connection factory {self._async_connection_factory}')
60-
if self._config:
61-
self._connection = await factory(**self._config)
62-
else:
63-
self._connection = await factory()
64-
elif self._config:
65-
self._connection = await AsyncBroker.create_connection(self._config)
66-
else:
67-
raise RuntimeError('Could not construct async connection for lack of config or factory')
76+
factory = self.get_connection_method(factory_path=self._async_connection_factory)
77+
self._connection = await factory(**self._config)
6878
return self._connection
6979

7080
@staticmethod
71-
async def create_connection(config) -> psycopg.AsyncConnection:
81+
async def create_connection(**config) -> psycopg.AsyncConnection:
7282
return await psycopg.AsyncConnection.connect(**config)
7383

7484
async def aprocess_notify(self, connected_callback=None):
@@ -88,8 +98,8 @@ async def aprocess_notify(self, connected_callback=None):
8898

8999
async def apublish_message(self, channel: Optional[str] = None, payload=None) -> None:
90100
connection = await self.get_connection()
91-
if not channel:
92-
channel = self.default_publish_channel
101+
channel = self.get_publish_channel(channel)
102+
93103
async with connection.cursor() as cur:
94104
if not payload:
95105
await cur.execute(f'NOTIFY {channel};')
@@ -119,7 +129,7 @@ def connection_saver(**config):
119129
"""
120130
if connection_save._connection is None:
121131
config['autocommit'] = True
122-
connection_save._connection = SyncBroker.create_connection(config)
132+
connection_save._connection = SyncBroker.create_connection(**config)
123133
return connection_save._connection
124134

125135

@@ -147,30 +157,17 @@ def __init__(
147157

148158
def get_connection(self) -> psycopg.Connection:
149159
if not self._connection:
150-
if self._sync_connection_factory:
151-
factory = resolve_callable(self._sync_connection_factory)
152-
if not factory:
153-
raise RuntimeError(f'Could not import connection factory {self._sync_connection_factory}')
154-
if self._config:
155-
self._connection = factory(**self._config)
156-
else:
157-
self._connection = factory()
158-
elif self._config:
159-
self._connection = SyncBroker.create_connection(self._config)
160-
else:
161-
raise RuntimeError('Cound not construct synchronous connection for lack of config or factory')
160+
factory = self.get_connection_method(factory_path=self._sync_connection_factory)
161+
self._connection = factory(**self._config)
162162
return self._connection
163163

164164
@staticmethod
165-
def create_connection(config) -> psycopg.Connection:
165+
def create_connection(**config) -> psycopg.Connection:
166166
return psycopg.Connection.connect(**config)
167167

168168
def publish_message(self, channel: Optional[str] = None, message: str = '') -> None:
169169
connection = self.get_connection()
170-
if channel is None:
171-
if self.default_publish_channel is None:
172-
raise ValueError('Could not determine a channel to use publish to from settings or PGNotify config')
173-
channel = self.default_publish_channel
170+
channel = self.get_publish_channel(channel)
174171

175172
with connection.cursor() as cur:
176173
cur.execute('SELECT pg_notify(%s, %s);', (channel, message))

0 commit comments

Comments
 (0)