|
38 | 38 | get_hassjob_callable_job_type, |
39 | 39 | ) |
40 | 40 | from homeassistant.exceptions import HomeAssistantError |
41 | | -from homeassistant.helpers.dispatcher import async_dispatcher_send |
| 41 | +from homeassistant.helpers.dispatcher import ( |
| 42 | + async_dispatcher_connect, |
| 43 | + async_dispatcher_send, |
| 44 | +) |
42 | 45 | from homeassistant.helpers.importlib import async_import_module |
43 | 46 | from homeassistant.helpers.start import async_at_started |
44 | 47 | from homeassistant.helpers.typing import ConfigType |
|
71 | 74 | DEFAULT_WS_PATH, |
72 | 75 | DOMAIN, |
73 | 76 | MQTT_CONNECTION_STATE, |
| 77 | + MQTT_PROCESSED_SUBSCRIPTIONS, |
74 | 78 | PROTOCOL_5, |
75 | 79 | PROTOCOL_31, |
76 | 80 | TRANSPORT_WEBSOCKETS, |
|
109 | 113 | SUBSCRIBE_COOLDOWN = 0.1 |
110 | 114 | UNSUBSCRIBE_COOLDOWN = 0.1 |
111 | 115 | TIMEOUT_ACK = 10 |
| 116 | +SUBSCRIBE_TIMEOUT = 10 |
112 | 117 | RECONNECT_INTERVAL_SECONDS = 10 |
113 | 118 |
|
114 | 119 | MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1 |
@@ -184,19 +189,71 @@ async def async_publish( |
184 | 189 | ) |
185 | 190 |
|
186 | 191 |
|
| 192 | +@callback |
| 193 | +def async_on_subscribe_done( |
| 194 | + hass: HomeAssistant, |
| 195 | + topic: str, |
| 196 | + qos: int, |
| 197 | + on_subscribe_status: CALLBACK_TYPE, |
| 198 | +) -> CALLBACK_TYPE: |
| 199 | + """Call on_subscribe_done when the matched subscription was completed. |
| 200 | +
|
| 201 | + If a subscription is already present the callback will call |
| 202 | + on_subscribe_status directly. |
| 203 | + Call the returned callback to stop and cleanup status monitoring. |
| 204 | + """ |
| 205 | + |
| 206 | + async def _sync_mqtt_subscribe(subscriptions: list[tuple[str, int]]) -> None: |
| 207 | + if (topic, qos) not in subscriptions: |
| 208 | + return |
| 209 | + hass.loop.call_soon(on_subscribe_status) |
| 210 | + |
| 211 | + mqtt_data = hass.data[DATA_MQTT] |
| 212 | + if ( |
| 213 | + mqtt_data.client.connected |
| 214 | + and mqtt_data.client.is_active_subscription(topic) |
| 215 | + and not mqtt_data.client.is_pending_subscription(topic) |
| 216 | + ): |
| 217 | + hass.loop.call_soon(on_subscribe_status) |
| 218 | + |
| 219 | + return async_dispatcher_connect( |
| 220 | + hass, MQTT_PROCESSED_SUBSCRIPTIONS, _sync_mqtt_subscribe |
| 221 | + ) |
| 222 | + |
| 223 | + |
187 | 224 | @bind_hass |
188 | 225 | async def async_subscribe( |
189 | 226 | hass: HomeAssistant, |
190 | 227 | topic: str, |
191 | 228 | msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None], |
192 | 229 | qos: int = DEFAULT_QOS, |
193 | 230 | encoding: str | None = DEFAULT_ENCODING, |
| 231 | + on_subscribe: CALLBACK_TYPE | None = None, |
194 | 232 | ) -> CALLBACK_TYPE: |
195 | 233 | """Subscribe to an MQTT topic. |
196 | 234 |
|
| 235 | + If the on_subcribe callback hook is set, it will be called once |
| 236 | + when the subscription has been completed. |
| 237 | +
|
197 | 238 | Call the return value to unsubscribe. |
198 | 239 | """ |
199 | | - return async_subscribe_internal(hass, topic, msg_callback, qos, encoding) |
| 240 | + handler: CALLBACK_TYPE | None = None |
| 241 | + |
| 242 | + def _on_subscribe_done() -> None: |
| 243 | + """Call once when the subscription was completed.""" |
| 244 | + if TYPE_CHECKING: |
| 245 | + assert on_subscribe is not None and handler is not None |
| 246 | + |
| 247 | + handler() |
| 248 | + on_subscribe() |
| 249 | + |
| 250 | + subscription_handler = async_subscribe_internal( |
| 251 | + hass, topic, msg_callback, qos, encoding |
| 252 | + ) |
| 253 | + if on_subscribe is not None: |
| 254 | + handler = async_on_subscribe_done(hass, topic, qos, _on_subscribe_done) |
| 255 | + |
| 256 | + return subscription_handler |
200 | 257 |
|
201 | 258 |
|
202 | 259 | @callback |
@@ -640,12 +697,16 @@ def _async_on_socket_unregister_write( |
640 | 697 | if fileno > -1: |
641 | 698 | self.loop.remove_writer(sock) |
642 | 699 |
|
643 | | - def _is_active_subscription(self, topic: str) -> bool: |
| 700 | + def is_active_subscription(self, topic: str) -> bool: |
644 | 701 | """Check if a topic has an active subscription.""" |
645 | 702 | return topic in self._simple_subscriptions or any( |
646 | 703 | other.topic == topic for other in self._wildcard_subscriptions |
647 | 704 | ) |
648 | 705 |
|
| 706 | + def is_pending_subscription(self, topic: str) -> bool: |
| 707 | + """Check if a topic has a pending subscription.""" |
| 708 | + return topic in self._pending_subscriptions |
| 709 | + |
649 | 710 | async def async_publish( |
650 | 711 | self, topic: str, payload: PublishPayloadType, qos: int, retain: bool |
651 | 712 | ) -> None: |
@@ -899,7 +960,7 @@ def _async_remove(self, subscription: Subscription) -> None: |
899 | 960 | @callback |
900 | 961 | def _async_unsubscribe(self, topic: str) -> None: |
901 | 962 | """Unsubscribe from a topic.""" |
902 | | - if self._is_active_subscription(topic): |
| 963 | + if self.is_active_subscription(topic): |
903 | 964 | if self._max_qos[topic] == 0: |
904 | 965 | return |
905 | 966 | subs = self._matching_subscriptions(topic) |
@@ -963,6 +1024,7 @@ async def _async_perform_subscriptions(self) -> None: |
963 | 1024 | self._last_subscribe = time.monotonic() |
964 | 1025 |
|
965 | 1026 | await self._async_wait_for_mid_or_raise(mid, result) |
| 1027 | + async_dispatcher_send(self.hass, MQTT_PROCESSED_SUBSCRIPTIONS, chunk_list) |
966 | 1028 |
|
967 | 1029 | async def _async_perform_unsubscribes(self) -> None: |
968 | 1030 | """Perform pending MQTT client unsubscribes.""" |
|
0 commit comments