diff --git a/poetry.lock b/poetry.lock index 2899fdd..8f41bea 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,19 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. + +[[package]] +name = "aiomqtt" +version = "2.4.0" +description = "The idiomatic asyncio MQTT client" +optional = false +python-versions = "<4.0,>=3.8" +groups = ["main"] +files = [ + {file = "aiomqtt-2.4.0-py3-none-any.whl", hash = "sha256:721296e2b79df5f6c7c4dfc91700ae0166953a4127735c92637859619dbd84e4"}, + {file = "aiomqtt-2.4.0.tar.gz", hash = "sha256:ab0f18fc5b7ffaa57451c407417d674db837b00a9c7d953cccd02be64f046c17"}, +] + +[package.dependencies] +paho-mqtt = ">=2.1.0,<3.0.0" [[package]] name = "anyio" @@ -217,21 +232,6 @@ files = [ graph = ["objgraph (>=1.7.2)"] profile = ["gprof2dot (>=2022.7.29)"] -[[package]] -name = "gmqtt" -version = "0.7.0" -description = "Client for MQTT protocol" -optional = false -python-versions = ">=3.5" -groups = ["main"] -files = [ - {file = "gmqtt-0.7.0-py3-none-any.whl", hash = "sha256:3e5571a20e9c115d83d600caa228b06f716087653e241035e29cec73277b52cc"}, - {file = "gmqtt-0.7.0.tar.gz", hash = "sha256:bedfec7bac26b6b4ce1f0c4c32cff3d663526a54c882d323d41560fc3b9b44a2"}, -] - -[package.extras] -test = ["atomicwrites (>=1.3.0)", "attrs (>=19.1.0)", "codecov (>=2.0.15)", "coverage (>=4.5.3)", "more-itertools (>=7.0.0)", "pluggy (>=0.11.0)", "py (>=1.8.0)", "pytest (>=5.4.0)", "pytest-asyncio (>=0.12.0)", "pytest-cov (>=2.7.1)", "six (>=1.12.0)", "uvloop (>=0.14.0)"] - [[package]] name = "h11" version = "0.16.0" @@ -459,6 +459,21 @@ files = [ {file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"}, ] +[[package]] +name = "paho-mqtt" +version = "2.1.0" +description = "MQTT version 5.0/3.1.1 client class" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "paho_mqtt-2.1.0-py3-none-any.whl", hash = "sha256:6db9ba9b34ed5bc6b6e3812718c7e06e2fd7444540df2455d2c51bd58808feee"}, + {file = "paho_mqtt-2.1.0.tar.gz", hash = "sha256:12d6e7511d4137555a3f6ea167ae846af2c7357b10bc6fa4f7c3968fc1723834"}, +] + +[package.extras] +proxy = ["pysocks"] + [[package]] name = "pathspec" version = "0.12.1" @@ -824,4 +839,4 @@ devenv = ["check-manifest", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3) [metadata] lock-version = "2.1" python-versions = ">=3.12,<4.0" -content-hash = "f78ad950e86b99428d134a345e435585c43bf2bbef1d53363cb45981c281631c" +content-hash = "4c6200f0f8d9c5a40267f08e5fc43dc221eef59613208bcb3b2fcb5a490b3c7f" diff --git a/pyproject.toml b/pyproject.toml index 3c803f3..0621a12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,10 +18,10 @@ requires-python = '>=3.12,<4.0' dependencies = [ "saic-ismart-client-ng (>=0.9.2,<0.10.0)", 'httpx (>=0.28.1,<0.29.0)', - 'gmqtt (>=0.7.0,<0.8.0)', 'inflection (>=0.5.1,<0.6.0)', 'apscheduler (>=3.11.0,<4.0.0)', 'python-dotenv (>=1.1.1,<2.0.0)', + "aiomqtt (>=2.4.0,<3.0.0)", ] [project.urls] diff --git a/src/configuration/__init__.py b/src/configuration/__init__.py index de30dad..0e8fb93 100644 --- a/src/configuration/__init__.py +++ b/src/configuration/__init__.py @@ -1,14 +1,17 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: from integrations.openwb.charging_station import ChargingStation +Transport = Literal["tcp", "websockets"] + + class TransportProtocol(Enum): - def __init__(self, transport_mechanism: str, with_tls: bool) -> None: + def __init__(self, transport_mechanism: Transport, with_tls: bool) -> None: self.transport_mechanism = transport_mechanism self.with_tls = with_tls diff --git a/src/configuration/parser.py b/src/configuration/parser.py index 6ceffba..f028782 100644 --- a/src/configuration/parser.py +++ b/src/configuration/parser.py @@ -102,15 +102,18 @@ def __parse_mqtt_transport(args: Namespace, config: Configuration) -> None: args.tls_server_cert_check_hostname ) else: - msg = f"Invalid MQTT URI scheme: {parse_result.scheme}, use tcp or ws" + msg = f"Invalid MQTT URI scheme: {parse_result.scheme}, use tls, tcp or ws" raise SystemExit(msg) if parse_result.port: config.mqtt_port = parse_result.port - elif config.mqtt_transport_protocol == TransportProtocol.TCP: - config.mqtt_port = 1883 - else: + elif config.mqtt_transport_protocol == TransportProtocol.TLS: + config.mqtt_port = 8883 + elif config.mqtt_transport_protocol == TransportProtocol.WS: config.mqtt_port = 9001 + else: + # fallback to default mqtt port + config.mqtt_port = 1883 config.mqtt_host = str(parse_result.hostname) diff --git a/src/log_config.py b/src/log_config.py index 9c96cae..2a49e3a 100644 --- a/src/log_config.py +++ b/src/log_config.py @@ -7,7 +7,7 @@ MODULES_DEFAULT_LOG_LEVEL = { "asyncio": "WARNING", - "gmqtt": "WARNING", + "aiomqtt": "WARNING", "httpcore": "WARNING", "httpx": "WARNING", "saic_ismart_client_ng": "WARNING", diff --git a/src/main.py b/src/main.py index dc47e09..e2d90d9 100644 --- a/src/main.py +++ b/src/main.py @@ -27,4 +27,5 @@ configuration = process_command_line() mqtt_gateway = MqttGateway(configuration) + asyncio.run(mqtt_gateway.run(), debug=debug_log_enabled()) diff --git a/src/publisher/core.py b/src/publisher/core.py index b0ffa66..77ca96c 100644 --- a/src/publisher/core.py +++ b/src/publisher/core.py @@ -55,28 +55,61 @@ def is_connected(self) -> bool: @abstractmethod def publish_json( - self, key: str, data: dict[str, Any], no_prefix: bool = False + self, + key: str, + data: dict[str, Any], + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, ) -> None: raise NotImplementedError @abstractmethod - def publish_str(self, key: str, value: str, no_prefix: bool = False) -> None: + def publish_str( + self, + key: str, + value: str, + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, + ) -> None: raise NotImplementedError @abstractmethod - def publish_int(self, key: str, value: int, no_prefix: bool = False) -> None: + def publish_int( + self, + key: str, + value: int, + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, + ) -> None: raise NotImplementedError @abstractmethod - def publish_bool(self, key: str, value: bool, no_prefix: bool = False) -> None: + def publish_bool( + self, + key: str, + value: bool, + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, + ) -> None: raise NotImplementedError @abstractmethod - def publish_float(self, key: str, value: float, no_prefix: bool = False) -> None: + def publish_float( + self, + key: str, + value: float, + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, + ) -> None: raise NotImplementedError @abstractmethod - def clear_topic(self, key: str, no_prefix: bool = False) -> None: + def clear_topic(self, key: str, no_prefix: bool = False, qos: int = 0) -> None: raise NotImplementedError def get_mqtt_account_prefix(self) -> str: @@ -154,7 +187,9 @@ def __anonymize(self, data: T) -> T: return data def keepalive(self) -> None: - self.publish_str(mqtt_topics.INTERNAL_LWT, "online", False) + self.publish_str( + mqtt_topics.INTERNAL_LWT, "online", no_prefix=False, retain=True, qos=1 + ) @staticmethod def anonymize_str(value: str) -> str: diff --git a/src/publisher/log_publisher.py b/src/publisher/log_publisher.py index efb6b24..6a717b9 100644 --- a/src/publisher/log_publisher.py +++ b/src/publisher/log_publisher.py @@ -24,29 +24,62 @@ def enable_commands(self) -> None: @override def publish_json( - self, key: str, data: dict[str, Any], no_prefix: bool = False + self, + key: str, + data: dict[str, Any], + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, ) -> None: anonymized_json = self.dict_to_anonymized_json(data) self.internal_publish(key, anonymized_json) @override - def publish_str(self, key: str, value: str, no_prefix: bool = False) -> None: + def publish_str( + self, + key: str, + value: str, + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, + ) -> None: self.internal_publish(key, value) @override - def publish_int(self, key: str, value: int, no_prefix: bool = False) -> None: + def publish_int( + self, + key: str, + value: int, + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, + ) -> None: self.internal_publish(key, value) @override - def publish_bool(self, key: str, value: bool, no_prefix: bool = False) -> None: + def publish_bool( + self, + key: str, + value: bool, + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, + ) -> None: self.internal_publish(key, value) @override - def publish_float(self, key: str, value: float, no_prefix: bool = False) -> None: + def publish_float( + self, + key: str, + value: float, + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, + ) -> None: self.internal_publish(key, value) @override - def clear_topic(self, key: str, no_prefix: bool = False) -> None: + def clear_topic(self, key: str, no_prefix: bool = False, qos: int = 0) -> None: self.internal_publish(key, None) def internal_publish(self, key: str, value: Any) -> None: diff --git a/src/publisher/mqtt_publisher.py b/src/publisher/mqtt_publisher.py index cb2809a..dc6f9b0 100644 --- a/src/publisher/mqtt_publisher.py +++ b/src/publisher/mqtt_publisher.py @@ -1,10 +1,11 @@ from __future__ import annotations +import asyncio import logging import ssl -from typing import TYPE_CHECKING, Any, Final, cast, override +from typing import TYPE_CHECKING, Any, override -import gmqtt +import aiomqtt import mqtt_topics from publisher.core import Publisher @@ -17,7 +18,10 @@ class MqttPublisher(Publisher): - def __init__(self, configuration: Configuration) -> None: + def __init__( + self, + configuration: Configuration, + ) -> None: super().__init__(configuration) self.publisher_id = configuration.mqtt_client_id self.host = self.configuration.mqtt_host @@ -27,111 +31,150 @@ def __init__(self, configuration: Configuration) -> None: self.last_charge_state_by_vin: dict[str, str] = {} self.vin_by_charger_connected_topic: dict[str, str] = {} self.first_connection = True + self.client: None | aiomqtt.Client = None + self.__running: asyncio.Task[None] | None = None + self.__connected = asyncio.Event() - mqtt_client = gmqtt.Client( - client_id=str(self.publisher_id), - transport=self.transport_protocol.transport_mechanism, - will_message=gmqtt.Message( - topic=self.get_topic(mqtt_topics.INTERNAL_LWT, False), - payload="offline", - retain=True, - ), - ) - mqtt_client.on_connect = self.__on_connect - mqtt_client.on_message = self.__on_message - self.client: Final[gmqtt.Client] = mqtt_client - - @override - async def connect(self) -> None: - if self.configuration.mqtt_user is not None: - if self.configuration.mqtt_password is not None: - self.client.set_auth_credentials( - username=self.configuration.mqtt_user, - password=self.configuration.mqtt_password, - ) - else: - self.client.set_auth_credentials(username=self.configuration.mqtt_user) - + async def __run_loop(self) -> None: + if not self.host: + LOG.info("MQTT host is not configured") + return + ssl_context: ssl.SSLContext | None = None if self.transport_protocol.with_tls: ssl_context = ssl.create_default_context() - cert_uri = self.configuration.tls_server_cert_path - if cert_uri: - LOG.debug(f"Using custom CA file {cert_uri}") - ssl_context.load_verify_locations(cafile=cert_uri) + if self.configuration.tls_server_cert_path: + LOG.debug( + f"Using custom CA file {self.configuration.tls_server_cert_path}" + ) + ssl_context.load_verify_locations( + cafile=self.configuration.tls_server_cert_path + ) if not self.configuration.tls_server_cert_check_hostname: LOG.warning( f"Skipping hostname check for TLS connection to {self.host}" ) - ssl_context.check_hostname = False - else: - ssl_context = None - await self.client.connect( - host=self.host, + + client = aiomqtt.Client( + hostname=self.host, port=self.port, - version=gmqtt.constants.MQTTv311, - ssl=ssl_context, + identifier=str(self.publisher_id) + "a", + transport=self.transport_protocol.transport_mechanism, + username=self.configuration.mqtt_user or None, + password=self.configuration.mqtt_password or None, + clean_session=True, + tls_context=ssl_context, + tls_insecure=bool( + ssl_context and not self.configuration.tls_server_cert_check_hostname + ), + will=aiomqtt.Will( + topic=self.get_topic(mqtt_topics.INTERNAL_LWT, False), + payload="offline", + retain=True, + qos=1, + ), ) - - def __on_connect( - self, _client: Any, _flags: Any, rc: int, _properties: Any - ) -> None: - if rc == gmqtt.constants.CONNACK_ACCEPTED: - LOG.info("Connected to MQTT broker") - if not self.first_connection: - self.enable_commands() - self.first_connection = False - self.keepalive() - else: - if rc == gmqtt.constants.CONNACK_REFUSED_BAD_USERNAME_PASSWORD: - LOG.error( - f"MQTT connection error: bad username or password. Return code {rc}" + client.pending_calls_threshold = 150 + reconnect_interval = 5 + while True: + try: + LOG.debug( + "Connecting to %s:%s as %s", + self.host, + self.port, + self.publisher_id, ) - elif rc == gmqtt.constants.CONNACK_REFUSED_PROTOCOL_VERSION: - LOG.error( - f"MQTT connection error: refused protocol version. Return code {rc}" + async with client as client_context: + self.client = client_context + self.__connected.set() + await self.__on_connect() + async for message in client_context.messages: + await self._on_message( + client_context, + str(message.topic), + message.payload, + message.qos, + message.properties, + ) + except aiomqtt.MqttError: + LOG.warning( + "Connection to %s:%s lost; Reconnecting in %d seconds ...", + self.host, + self.port, + reconnect_interval, ) - else: - LOG.error(f"MQTT connection error.Return code {rc}") - msg = f"Unable to connect to MQTT broker. Return code: {rc}" - raise SystemExit(msg) + await asyncio.sleep(reconnect_interval) + except asyncio.exceptions.CancelledError: + LOG.debug("MQTT publisher loop cancelled") + raise + finally: + self.__connected.clear() + LOG.info("MQTT client disconnected") + + @override + async def connect(self) -> None: + if self.__running and not self.__running.done(): + LOG.warning("MQTT client is already running") + return + self.__running = asyncio.create_task(self.__run_loop()) + await self.__connected.wait() + + async def __on_connect(self) -> None: + LOG.info("Connected to MQTT broker") + if not self.first_connection: + await self.__enable_commands() + self.first_connection = False + self.keepalive() @override def enable_commands(self) -> None: - LOG.info("Subscribing to MQTT command topics") - mqtt_account_prefix = self.get_mqtt_account_prefix() - self.client.subscribe( - f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/+/+/{mqtt_topics.SET_SUFFIX}" - ) - self.client.subscribe( - f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/+/+/+/{mqtt_topics.SET_SUFFIX}" - ) - self.client.subscribe( - f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/{mqtt_topics.REFRESH_MODE}/{mqtt_topics.SET_SUFFIX}" - ) - self.client.subscribe( - f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/{mqtt_topics.REFRESH_PERIOD}/+/{mqtt_topics.SET_SUFFIX}" - ) - for charging_station in self.configuration.charging_stations_by_vin.values(): - LOG.debug( - f"Subscribing to MQTT topic {charging_station.charge_state_topic}" + loop = asyncio.get_running_loop() + asyncio.run_coroutine_threadsafe(self.__enable_commands(), loop) + + async def __enable_commands(self) -> None: + if not self.__connected.is_set() or not self.client: + LOG.error("Failed to enable commands: MQTT client is not connected") + return + try: + LOG.info("Subscribing to MQTT command topics") + mqtt_account_prefix = self.get_mqtt_account_prefix() + await self.client.subscribe( + f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/+/+/{mqtt_topics.SET_SUFFIX}" + ) + await self.client.subscribe( + f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/+/+/+/{mqtt_topics.SET_SUFFIX}" ) - self.vin_by_charge_state_topic[charging_station.charge_state_topic] = ( - charging_station.vin + await self.client.subscribe( + f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/{mqtt_topics.REFRESH_MODE}/{mqtt_topics.SET_SUFFIX}" ) - self.client.subscribe(charging_station.charge_state_topic) - if charging_station.connected_topic: + await self.client.subscribe( + f"{mqtt_account_prefix}/{mqtt_topics.VEHICLES}/+/{mqtt_topics.REFRESH_PERIOD}/+/{mqtt_topics.SET_SUFFIX}" + ) + for ( + charging_station + ) in self.configuration.charging_stations_by_vin.values(): LOG.debug( - f"Subscribing to MQTT topic {charging_station.connected_topic}" + f"Subscribing to MQTT topic {charging_station.charge_state_topic}" + ) + self.vin_by_charge_state_topic[charging_station.charge_state_topic] = ( + charging_station.vin ) - self.vin_by_charger_connected_topic[ - charging_station.connected_topic - ] = charging_station.vin - self.client.subscribe(charging_station.connected_topic) - if self.configuration.ha_discovery_enabled: - # enable dynamic discovery pushing in case ha reconnects - self.client.subscribe(self.configuration.ha_lwt_topic) - - async def __on_message( + await self.client.subscribe(charging_station.charge_state_topic) + if charging_station.connected_topic: + LOG.debug( + f"Subscribing to MQTT topic {charging_station.connected_topic}" + ) + self.vin_by_charger_connected_topic[ + charging_station.connected_topic + ] = charging_station.vin + await self.client.subscribe(charging_station.connected_topic) + if self.configuration.ha_discovery_enabled: + # enable dynamic discovery pushing in case ha reconnects + await self.client.subscribe(self.configuration.ha_lwt_topic) + except aiomqtt.MqttError as e: + LOG.error("Failed to subscribe to MQTT command topics: {e}") + raise e + + async def _on_message( self, _client: Any, topic: str, payload: Any, _qos: Any, _properties: Any ) -> None: try: @@ -178,39 +221,104 @@ async def __on_message_real(self, *, topic: str, payload: str) -> None: vin=vin, topic=topic, payload=payload ) - def __publish(self, topic: str, payload: Any) -> None: - self.client.publish(topic, payload, retain=True) + def __publish( + self, topic: str, payload: Any, retain: bool = False, qos: int = 0 + ) -> None: + LOG.debug("Publishing to MQTT topic %s with payload %s", topic, payload) + loop = asyncio.get_running_loop() + asyncio.run_coroutine_threadsafe( + self.__async_publish(topic, payload, retain=retain, qos=qos), loop + ) + + async def __async_publish( + self, topic: str, payload: Any, retain: bool, qos: int + ) -> None: + if not (self.client and self.is_connected()): + LOG.error("Failed to publish: MQTT client is not connected") + return + try: + await self.client.publish(topic, payload, retain=retain, qos=qos) + except aiomqtt.MqttError as e: + LOG.error( + f"Failed to publish to MQTT topic {topic} with payload {payload}: {e}" + ) @override def is_connected(self) -> bool: - return cast("bool", self.client.is_connected) + return self.__connected.is_set() @override def publish_json( - self, key: str, data: dict[str, Any], no_prefix: bool = False + self, + key: str, + data: dict[str, Any], + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, ) -> None: payload = self.dict_to_anonymized_json(data) - self.__publish(topic=self.get_topic(key, no_prefix), payload=payload) + self.__publish( + topic=self.get_topic(key, no_prefix), + payload=payload, + retain=retain, + qos=qos, + ) @override - def publish_str(self, key: str, value: str, no_prefix: bool = False) -> None: - self.__publish(topic=self.get_topic(key, no_prefix), payload=value) + def publish_str( + self, + key: str, + value: str, + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, + ) -> None: + self.__publish( + topic=self.get_topic(key, no_prefix), payload=value, retain=retain, qos=qos + ) @override - def publish_int(self, key: str, value: int, no_prefix: bool = False) -> None: - self.__publish(topic=self.get_topic(key, no_prefix), payload=value) + def publish_int( + self, + key: str, + value: int, + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, + ) -> None: + self.__publish( + topic=self.get_topic(key, no_prefix), payload=value, retain=retain, qos=qos + ) @override - def publish_bool(self, key: str, value: bool, no_prefix: bool = False) -> None: - self.__publish(topic=self.get_topic(key, no_prefix), payload=value) + def publish_bool( + self, + key: str, + value: bool, + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, + ) -> None: + self.__publish( + topic=self.get_topic(key, no_prefix), payload=value, retain=retain, qos=qos + ) @override - def publish_float(self, key: str, value: float, no_prefix: bool = False) -> None: - self.__publish(topic=self.get_topic(key, no_prefix), payload=value) + def publish_float( + self, + key: str, + value: float, + no_prefix: bool = False, + retain: bool = False, + qos: int = 0, + ) -> None: + self.__publish( + topic=self.get_topic(key, no_prefix), payload=value, retain=retain, qos=qos + ) @override - def clear_topic(self, key: str, no_prefix: bool = False) -> None: - self.__publish(topic=self.get_topic(key, no_prefix), payload=None) + def clear_topic(self, key: str, no_prefix: bool = False, qos: int = 0) -> None: + self.__publish(topic=self.get_topic(key, no_prefix), payload=None, qos=qos) def get_vin_from_topic(self, topic: str) -> str: global_topic_removed = topic[len(self.configuration.mqtt_topic) + 1 :] diff --git a/tests/test_mqtt_publisher.py b/tests/test_mqtt_publisher.py index f8ac595..61022aa 100644 --- a/tests/test_mqtt_publisher.py +++ b/tests/test_mqtt_publisher.py @@ -68,7 +68,7 @@ async def test_update_rear_window_heat_state(self) -> None: assert self.received_payload == REAR_WINDOW_HEAT_STATE async def send_message(self, topic: str, payload: Any) -> None: - await self.mqtt_client.client.on_message("client", topic, payload, 0, {}) + await self.mqtt_client._on_message("client", topic, payload, 0, {}) async def on_charging_detected(self, vin: str) -> None: pass