11import logging
2- from typing import Any , Iterable , Optional
2+ from typing import Any , Iterable , Optional , Callable
33
44import psycopg
55
@@ -27,6 +27,26 @@ def __init__(
2727 self .channels = channels
2828 self .default_publish_channel = default_publish_channel
2929
30+ def get_publish_channel (self , channel : Optional [str ] = None ):
31+ "Handle default for the publishing channel for calls to publish_message, shared sync and async"
32+ if channel is not None :
33+ return channel
34+ if self .default_publish_channel is None :
35+ raise ValueError ('Could not determine a channel to use publish to from settings or PGNotify config' )
36+ return self .default_publish_channel
37+
38+ def get_connection_method (self , factory_path : Optional [str ] = None ) -> Callable :
39+ "Handles settings, returns a method (async or sync) for getting a new connection"
40+ if factory_path :
41+ factory = resolve_callable (factory_path )
42+ if not factory :
43+ raise RuntimeError (f'Could not import connection factory { factory_path } ' )
44+ return factory
45+ elif self ._config :
46+ return self .create_connection
47+ else :
48+ raise RuntimeError ('Could not construct connection for lack of config or factory' )
49+
3050
3151class AsyncBroker (PGNotifyBase ):
3252 def __init__ (
@@ -41,10 +61,10 @@ def __init__(
4161 raise RuntimeError ('Must specify either config or async_connection_factory' )
4262
4363 if config :
44- self ._config : Optional [ dict ] = config .copy ()
64+ self ._config : dict = config .copy ()
4565 self ._config ['autocommit' ] = True
4666 else :
47- self ._config = None
67+ self ._config = {}
4868
4969 self ._async_connection_factory = async_connection_factory
5070 self ._connection : Optional [Any ] = connection
@@ -53,22 +73,12 @@ def __init__(
5373
5474 async def get_connection (self ) -> psycopg .AsyncConnection :
5575 if not self ._connection :
56- if self ._async_connection_factory :
57- factory = resolve_callable (self ._async_connection_factory )
58- if not factory :
59- raise RuntimeError (f'Could not import connection factory { self ._async_connection_factory } ' )
60- if self ._config :
61- self ._connection = await factory (** self ._config )
62- else :
63- self ._connection = await factory ()
64- elif self ._config :
65- self ._connection = await AsyncBroker .create_connection (self ._config )
66- else :
67- raise RuntimeError ('Could not construct async connection for lack of config or factory' )
76+ factory = self .get_connection_method (factory_path = self ._async_connection_factory )
77+ self ._connection = await factory (** self ._config )
6878 return self ._connection
6979
7080 @staticmethod
71- async def create_connection (config ) -> psycopg .AsyncConnection :
81+ async def create_connection (** config ) -> psycopg .AsyncConnection :
7282 return await psycopg .AsyncConnection .connect (** config )
7383
7484 async def aprocess_notify (self , connected_callback = None ):
@@ -88,8 +98,8 @@ async def aprocess_notify(self, connected_callback=None):
8898
8999 async def apublish_message (self , channel : Optional [str ] = None , payload = None ) -> None :
90100 connection = await self .get_connection ()
91- if not channel :
92- channel = self . default_publish_channel
101+ channel = self . get_publish_channel ( channel )
102+
93103 async with connection .cursor () as cur :
94104 if not payload :
95105 await cur .execute (f'NOTIFY { channel } ;' )
@@ -119,7 +129,7 @@ def connection_saver(**config):
119129 """
120130 if connection_save ._connection is None :
121131 config ['autocommit' ] = True
122- connection_save ._connection = SyncBroker .create_connection (config )
132+ connection_save ._connection = SyncBroker .create_connection (** config )
123133 return connection_save ._connection
124134
125135
@@ -147,30 +157,17 @@ def __init__(
147157
148158 def get_connection (self ) -> psycopg .Connection :
149159 if not self ._connection :
150- if self ._sync_connection_factory :
151- factory = resolve_callable (self ._sync_connection_factory )
152- if not factory :
153- raise RuntimeError (f'Could not import connection factory { self ._sync_connection_factory } ' )
154- if self ._config :
155- self ._connection = factory (** self ._config )
156- else :
157- self ._connection = factory ()
158- elif self ._config :
159- self ._connection = SyncBroker .create_connection (self ._config )
160- else :
161- raise RuntimeError ('Cound not construct synchronous connection for lack of config or factory' )
160+ factory = self .get_connection_method (factory_path = self ._sync_connection_factory )
161+ self ._connection = factory (** self ._config )
162162 return self ._connection
163163
164164 @staticmethod
165- def create_connection (config ) -> psycopg .Connection :
165+ def create_connection (** config ) -> psycopg .Connection :
166166 return psycopg .Connection .connect (** config )
167167
168168 def publish_message (self , channel : Optional [str ] = None , message : str = '' ) -> None :
169169 connection = self .get_connection ()
170- if channel is None :
171- if self .default_publish_channel is None :
172- raise ValueError ('Could not determine a channel to use publish to from settings or PGNotify config' )
173- channel = self .default_publish_channel
170+ channel = self .get_publish_channel (channel )
174171
175172 with connection .cursor () as cur :
176173 cur .execute ('SELECT pg_notify(%s, %s);' , (channel , message ))
0 commit comments