11import asyncio
2+ import collections
23import contextlib
34import logging
45import ssl
@@ -17,8 +18,9 @@ def __init__(self, config):
1718 self ._config = config
1819 self ._mdns_resolver = enapter .mdns .Resolver ()
1920 self ._tls_context = self ._new_tls_context (config )
20- self ._publisher = None
21- self ._publisher_connected = asyncio .Event ()
21+ self ._client = None
22+ self ._client_ready = asyncio .Event ()
23+ self ._subscribers = collections .defaultdict (int )
2224
2325 @staticmethod
2426 def _new_logger (config ):
@@ -29,58 +31,89 @@ def config(self):
2931 return self ._config
3032
3133 async def publish (self , * args , ** kwargs ):
32- await self ._publisher_connected . wait ()
33- await self . _publisher .publish (* args , ** kwargs )
34+ client = await self ._wait_client ()
35+ await client .publish (* args , ** kwargs )
3436
3537 @enapter .async_ .generator
36- async def subscribe (self , * topics ):
38+ async def subscribe (self , topic ):
3739 while True :
40+ client = await self ._wait_client ()
41+
3842 try :
39- async with self . _connect () as subscriber :
40- for topic in topics :
41- await subscriber . subscribe ( topic )
42- self . _logger . info ( "subscriber [%s] connected" , "," . join ( topics ))
43- async for msg in subscriber . messages :
44- yield msg
43+ async with client . messages () as messages :
44+ async with self . _subscribe ( client , topic ) :
45+ async for msg in messages :
46+ if msg . topic . matches ( topic ):
47+ yield msg
48+
4549 except aiomqtt .MqttError as e :
4650 self ._logger .error (e )
4751 retry_interval = 5
4852 await asyncio .sleep (retry_interval )
49- finally :
50- self ._logger .info ("subscriber disconnected" )
53+
54+ @contextlib .asynccontextmanager
55+ async def _subscribe (self , client , topic ):
56+ first_subscriber = not self ._subscribers [topic ]
57+ self ._subscribers [topic ] += 1
58+ try :
59+ if first_subscriber :
60+ await client .subscribe (topic )
61+ yield
62+ finally :
63+ self ._subscribers [topic ] -= 1
64+ assert not self ._subscribers [topic ] < 0
65+ last_unsubscriber = not self ._subscribers [topic ]
66+ if last_unsubscriber :
67+ del self ._subscribers [topic ]
68+ await client .unsubscribe (topic )
69+
70+ async def _wait_client (self ):
71+ await self ._client_ready .wait ()
72+ assert self ._client_ready .is_set ()
73+ return self ._client
5174
5275 async def _run (self ):
5376 self ._logger .info ("starting" )
77+
5478 self ._started .set ()
79+
5580 while True :
5681 try :
57- async with self ._connect () as publisher :
58- self ._logger .info ("publisher connected" )
59- self ._publisher = publisher
60- self ._publisher_connected .set ()
61- async for msg in publisher .messages :
62- pass
82+ async with self ._connect () as client :
83+ self ._client = client
84+ self ._client_ready .set ()
85+ self ._logger .info ("client ready" )
86+
87+ # tracking disconnect
88+ async with client .messages () as messages :
89+ async for msg in messages :
90+ pass
6391 except aiomqtt .MqttError as e :
6492 self ._logger .error (e )
6593 retry_interval = 5
6694 await asyncio .sleep (retry_interval )
6795 finally :
68- self ._publisher_connected .clear ()
69- self ._publisher = None
70- self ._logger .info ("publisher disconnected " )
96+ self ._client_ready .clear ()
97+ self ._client = None
98+ self ._logger .info ("client not ready " )
7199
72100 @contextlib .asynccontextmanager
73101 async def _connect (self ):
74102 host = await self ._maybe_resolve_mdns (self ._config .host )
75- async with aiomqtt .Client (
76- hostname = host ,
77- port = self ._config .port ,
78- username = self ._config .user ,
79- password = self ._config .password ,
80- logger = self ._logger ,
81- tls_context = self ._tls_context ,
82- ) as client :
83- yield client
103+
104+ try :
105+ async with aiomqtt .Client (
106+ hostname = host ,
107+ port = self ._config .port ,
108+ username = self ._config .user ,
109+ password = self ._config .password ,
110+ logger = self ._logger ,
111+ tls_context = self ._tls_context ,
112+ ) as client :
113+ yield client
114+ except asyncio .CancelledError :
115+ # FIXME: A cancelled `aiomqtt.Client.connect` leaks resources.
116+ raise
84117
85118 @staticmethod
86119 def _new_tls_context (config ):
0 commit comments