Skip to content

Commit ef5573c

Browse files
authored
Allow to callback for MQTT subscription status (home-assistant#152994)
1 parent 45aecd5 commit ef5573c

File tree

4 files changed

+173
-4
lines changed

4 files changed

+173
-4
lines changed

homeassistant/components/mqtt/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from . import debug_info, discovery
4141
from .client import (
4242
MQTT,
43+
async_on_subscribe_done,
4344
async_publish,
4445
async_subscribe,
4546
async_subscribe_internal,
@@ -163,6 +164,7 @@
163164
"async_create_certificate_temp_files",
164165
"async_forward_entry_setup_and_setup_discovery",
165166
"async_migrate_entry",
167+
"async_on_subscribe_done",
166168
"async_prepare_subscribe_topics",
167169
"async_publish",
168170
"async_remove_config_entry_device",

homeassistant/components/mqtt/client.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@
3838
get_hassjob_callable_job_type,
3939
)
4040
from homeassistant.exceptions import HomeAssistantError
41-
from homeassistant.helpers.dispatcher import async_dispatcher_send
41+
from homeassistant.helpers.dispatcher import (
42+
async_dispatcher_connect,
43+
async_dispatcher_send,
44+
)
4245
from homeassistant.helpers.importlib import async_import_module
4346
from homeassistant.helpers.start import async_at_started
4447
from homeassistant.helpers.typing import ConfigType
@@ -71,6 +74,7 @@
7174
DEFAULT_WS_PATH,
7275
DOMAIN,
7376
MQTT_CONNECTION_STATE,
77+
MQTT_PROCESSED_SUBSCRIPTIONS,
7478
PROTOCOL_5,
7579
PROTOCOL_31,
7680
TRANSPORT_WEBSOCKETS,
@@ -109,6 +113,7 @@
109113
SUBSCRIBE_COOLDOWN = 0.1
110114
UNSUBSCRIBE_COOLDOWN = 0.1
111115
TIMEOUT_ACK = 10
116+
SUBSCRIBE_TIMEOUT = 10
112117
RECONNECT_INTERVAL_SECONDS = 10
113118

114119
MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
@@ -184,19 +189,71 @@ async def async_publish(
184189
)
185190

186191

192+
@callback
193+
def async_on_subscribe_done(
194+
hass: HomeAssistant,
195+
topic: str,
196+
qos: int,
197+
on_subscribe_status: CALLBACK_TYPE,
198+
) -> CALLBACK_TYPE:
199+
"""Call on_subscribe_done when the matched subscription was completed.
200+
201+
If a subscription is already present the callback will call
202+
on_subscribe_status directly.
203+
Call the returned callback to stop and cleanup status monitoring.
204+
"""
205+
206+
async def _sync_mqtt_subscribe(subscriptions: list[tuple[str, int]]) -> None:
207+
if (topic, qos) not in subscriptions:
208+
return
209+
hass.loop.call_soon(on_subscribe_status)
210+
211+
mqtt_data = hass.data[DATA_MQTT]
212+
if (
213+
mqtt_data.client.connected
214+
and mqtt_data.client.is_active_subscription(topic)
215+
and not mqtt_data.client.is_pending_subscription(topic)
216+
):
217+
hass.loop.call_soon(on_subscribe_status)
218+
219+
return async_dispatcher_connect(
220+
hass, MQTT_PROCESSED_SUBSCRIPTIONS, _sync_mqtt_subscribe
221+
)
222+
223+
187224
@bind_hass
188225
async def async_subscribe(
189226
hass: HomeAssistant,
190227
topic: str,
191228
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
192229
qos: int = DEFAULT_QOS,
193230
encoding: str | None = DEFAULT_ENCODING,
231+
on_subscribe: CALLBACK_TYPE | None = None,
194232
) -> CALLBACK_TYPE:
195233
"""Subscribe to an MQTT topic.
196234
235+
If the on_subcribe callback hook is set, it will be called once
236+
when the subscription has been completed.
237+
197238
Call the return value to unsubscribe.
198239
"""
199-
return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
240+
handler: CALLBACK_TYPE | None = None
241+
242+
def _on_subscribe_done() -> None:
243+
"""Call once when the subscription was completed."""
244+
if TYPE_CHECKING:
245+
assert on_subscribe is not None and handler is not None
246+
247+
handler()
248+
on_subscribe()
249+
250+
subscription_handler = async_subscribe_internal(
251+
hass, topic, msg_callback, qos, encoding
252+
)
253+
if on_subscribe is not None:
254+
handler = async_on_subscribe_done(hass, topic, qos, _on_subscribe_done)
255+
256+
return subscription_handler
200257

201258

202259
@callback
@@ -640,12 +697,16 @@ def _async_on_socket_unregister_write(
640697
if fileno > -1:
641698
self.loop.remove_writer(sock)
642699

643-
def _is_active_subscription(self, topic: str) -> bool:
700+
def is_active_subscription(self, topic: str) -> bool:
644701
"""Check if a topic has an active subscription."""
645702
return topic in self._simple_subscriptions or any(
646703
other.topic == topic for other in self._wildcard_subscriptions
647704
)
648705

706+
def is_pending_subscription(self, topic: str) -> bool:
707+
"""Check if a topic has a pending subscription."""
708+
return topic in self._pending_subscriptions
709+
649710
async def async_publish(
650711
self, topic: str, payload: PublishPayloadType, qos: int, retain: bool
651712
) -> None:
@@ -899,7 +960,7 @@ def _async_remove(self, subscription: Subscription) -> None:
899960
@callback
900961
def _async_unsubscribe(self, topic: str) -> None:
901962
"""Unsubscribe from a topic."""
902-
if self._is_active_subscription(topic):
963+
if self.is_active_subscription(topic):
903964
if self._max_qos[topic] == 0:
904965
return
905966
subs = self._matching_subscriptions(topic)
@@ -963,6 +1024,7 @@ async def _async_perform_subscriptions(self) -> None:
9631024
self._last_subscribe = time.monotonic()
9641025

9651026
await self._async_wait_for_mid_or_raise(mid, result)
1027+
async_dispatcher_send(self.hass, MQTT_PROCESSED_SUBSCRIPTIONS, chunk_list)
9661028

9671029
async def _async_perform_unsubscribes(self) -> None:
9681030
"""Perform pending MQTT client unsubscribes."""

homeassistant/components/mqtt/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@
375375
LOGGER = logging.getLogger(__package__)
376376

377377
MQTT_CONNECTION_STATE = "mqtt_connection_state"
378+
MQTT_PROCESSED_SUBSCRIPTIONS = "mqtt_processed_subscriptions"
378379

379380
PAYLOAD_EMPTY_JSON = "{}"
380381
PAYLOAD_NONE = "None"

tests/components/mqtt/test_client.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,100 @@ async def test_subscribe_topic(
282282
unsub()
283283

284284

285+
async def test_status_subscription_done(
286+
hass: HomeAssistant,
287+
mqtt_client_mock: MqttMockPahoClient,
288+
mqtt_mock_entry: MqttMockHAClientGenerator,
289+
recorded_calls: list[ReceiveMessage],
290+
record_calls: MessageCallbackType,
291+
) -> None:
292+
"""Test the on subscription status."""
293+
await mqtt_mock_entry()
294+
295+
on_status = asyncio.Event()
296+
on_status_calls: list[bool] = []
297+
298+
def _on_subscribe_status() -> None:
299+
on_status.set()
300+
on_status_calls.append(True)
301+
302+
subscribe_callback = await mqtt.async_subscribe(
303+
hass, "test-topic", record_calls, qos=0
304+
)
305+
handler = mqtt.async_on_subscribe_done(
306+
hass, "test-topic", 0, on_subscribe_status=_on_subscribe_status
307+
)
308+
await on_status.wait()
309+
assert ("test-topic", 0) in help_all_subscribe_calls(mqtt_client_mock)
310+
311+
await mqtt.async_publish(hass, "test-topic", "beer ready", 0)
312+
handler()
313+
assert len(recorded_calls) == 1
314+
assert recorded_calls[0].topic == "test-topic"
315+
assert recorded_calls[0].payload == "beer ready"
316+
assert recorded_calls[0].qos == 0
317+
318+
# Test as we have an existing subscription, test we get a callback
319+
recorded_calls.clear()
320+
on_status.clear()
321+
handler = mqtt.async_on_subscribe_done(
322+
hass, "test-topic", 0, on_subscribe_status=_on_subscribe_status
323+
)
324+
assert len(on_status_calls) == 1
325+
await on_status.wait()
326+
assert len(on_status_calls) == 2
327+
328+
# cleanup
329+
handler()
330+
subscribe_callback()
331+
332+
333+
async def test_subscribe_topic_with_subscribe_done(
334+
hass: HomeAssistant,
335+
mqtt_mock_entry: MqttMockHAClientGenerator,
336+
recorded_calls: list[ReceiveMessage],
337+
record_calls: MessageCallbackType,
338+
) -> None:
339+
"""Test the subscription of a topic."""
340+
await mqtt_mock_entry()
341+
342+
on_status = asyncio.Event()
343+
344+
def _on_subscribe() -> None:
345+
hass.async_create_task(mqtt.async_publish(hass, "test-topic", "beer ready", 0))
346+
on_status.set()
347+
348+
# Start a first subscription
349+
unsub1 = await mqtt.async_subscribe(
350+
hass, "test-topic", record_calls, on_subscribe=_on_subscribe
351+
)
352+
await on_status.wait()
353+
await hass.async_block_till_done()
354+
assert len(recorded_calls) == 1
355+
assert recorded_calls[0].topic == "test-topic"
356+
assert recorded_calls[0].payload == "beer ready"
357+
assert recorded_calls[0].qos == 0
358+
recorded_calls.clear()
359+
360+
# Start a second subscription to the same topic
361+
on_status.clear()
362+
unsub2 = await mqtt.async_subscribe(
363+
hass, "test-topic", record_calls, on_subscribe=_on_subscribe
364+
)
365+
await on_status.wait()
366+
await hass.async_block_till_done()
367+
assert len(recorded_calls) == 2
368+
assert recorded_calls[0].topic == "test-topic"
369+
assert recorded_calls[0].payload == "beer ready"
370+
assert recorded_calls[0].qos == 0
371+
assert recorded_calls[1].topic == "test-topic"
372+
assert recorded_calls[1].payload == "beer ready"
373+
assert recorded_calls[1].qos == 0
374+
375+
unsub1()
376+
unsub2()
377+
378+
285379
@pytest.mark.usefixtures("mqtt_mock_entry")
286380
async def test_subscribe_topic_not_initialize(
287381
hass: HomeAssistant, record_calls: MessageCallbackType
@@ -292,6 +386,16 @@ async def test_subscribe_topic_not_initialize(
292386
):
293387
await mqtt.async_subscribe(hass, "test-topic", record_calls)
294388

389+
def _on_subscribe_callback() -> None:
390+
pass
391+
392+
with pytest.raises(
393+
HomeAssistantError, match=r".*make sure MQTT is set up correctly"
394+
):
395+
await mqtt.async_subscribe(
396+
hass, "test-topic", record_calls, on_subscribe=_on_subscribe_callback
397+
)
398+
295399

296400
async def test_subscribe_mqtt_config_entry_disabled(
297401
hass: HomeAssistant, mqtt_mock: MqttMockHAClient, record_calls: MessageCallbackType

0 commit comments

Comments
 (0)