diff --git a/mqtt/.gitignore b/mqtt/.gitignore new file mode 100644 index 0000000..3b2664d --- /dev/null +++ b/mqtt/.gitignore @@ -0,0 +1,178 @@ +# Databricks-specific Zone +.DS_Store +.python-version + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc diff --git a/mqtt/LICENSE.md b/mqtt/LICENSE.md new file mode 100644 index 0000000..05e6a45 --- /dev/null +++ b/mqtt/LICENSE.md @@ -0,0 +1,24 @@ +# DB license +**Definitions.** + +Agreement: The agreement between Databricks, Inc., and you governing the use of the Databricks Services, as that term is defined in the Master Cloud Services Agreement (MCSA) located at www.databricks.com/legal/mcsa. + +Licensed Materials: The source code, object code, data, and/or other works to which this license applies. + +**Scope of Use.** You may not use the Licensed Materials except in connection with your use of the Databricks Services pursuant to the Agreement. Your use of the Licensed Materials must comply at all times with any restrictions applicable to the Databricks Services, generally, and must be used in accordance with any applicable documentation. You may view, use, copy, modify, publish, and/or distribute the Licensed Materials solely for the purposes of using the Licensed Materials within or connecting to the Databricks Services. If you do not agree to these terms, you may not view, use, copy, modify, publish, and/or distribute the Licensed Materials. + +**Redistribution.** You may redistribute and sublicense the Licensed Materials so long as all use is in compliance with these terms. In addition: + +- You must give any other recipients a copy of this License; +- You must cause any modified files to carry prominent notices stating that you changed the files; +- You must retain, in any derivative works that you distribute, all copyright, patent, trademark, and attribution notices, excluding those notices that do not pertain to any part of the derivative works; and +- If a "NOTICE" text file is provided as part of its distribution, then any derivative works that you distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the derivative works. + + +You may add your own copyright statement to your modifications and may provide additional license terms and conditions for use, reproduction, or distribution of your modifications, or for any such derivative works as a whole, provided your use, reproduction, and distribution of the Licensed Materials otherwise complies with the conditions stated in this License. + +**Termination.** This license terminates automatically upon your breach of these terms or upon the termination of your Agreement. Additionally, Databricks may terminate this license at any time on notice. 
Upon termination, you must permanently delete the Licensed Materials and all copies thereof.

**DISCLAIMER; LIMITATION OF LIABILITY.**

THE LICENSED MATERIALS ARE PROVIDED “AS-IS” AND WITH ALL FAULTS. DATABRICKS, ON BEHALF OF ITSELF AND ITS LICENSORS, SPECIFICALLY DISCLAIMS ALL WARRANTIES RELATING TO THE LICENSED MATERIALS, EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, IMPLIED WARRANTIES, CONDITIONS AND OTHER TERMS OF MERCHANTABILITY, SATISFACTORY QUALITY OR FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. DATABRICKS AND ITS LICENSORS TOTAL AGGREGATE LIABILITY RELATING TO OR ARISING OUT OF YOUR USE OF OR DATABRICKS’ PROVISIONING OF THE LICENSED MATERIALS SHALL BE LIMITED TO ONE THOUSAND ($1,000) DOLLARS. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE LICENSED MATERIALS OR THE USE OR OTHER DEALINGS IN THE LICENSED MATERIALS.
\ No newline at end of file
diff --git a/mqtt/Makefile b/mqtt/Makefile
new file mode 100644
index 0000000..4c5f1b9
--- /dev/null
+++ b/mqtt/Makefile
@@ -0,0 +1,28 @@
.PHONY: all clean dev test style check

all: clean style test

clean: ## Remove build artifacts and cache files
	rm -rf build/
	rm -rf dist/
	rm -rf *.egg-info/
	rm -rf htmlcov/
	rm -rf .coverage
	rm -rf coverage.xml
	rm -rf .pytest_cache/
	rm -rf .mypy_cache/
	rm -rf .ruff_cache/
	find . -type d -name __pycache__ -delete
	find . -type f -name "*.pyc" -delete

test:
	pip install -r requirements.txt
	pytest .

dev:
	pip install -r requirements.txt

style:
	pre-commit run --all-files

check: style test
\ No newline at end of file
diff --git a/mqtt/README.md b/mqtt/README.md
new file mode 100644
index 0000000..95c68dc
--- /dev/null
+++ b/mqtt/README.md
@@ -0,0 +1,138 @@
# MQTT Data Source Connectors for PySpark
[![Unity Catalog](https://img.shields.io/badge/Unity_Catalog-Enabled-00A1C9?style=for-the-badge)](https://docs.databricks.com/en/data-governance/unity-catalog/index.html)
[![Serverless](https://img.shields.io/badge/Serverless-Compute-00C851?style=for-the-badge)](https://docs.databricks.com/en/compute/serverless.html)
# Databricks Python Data Sources

Introduced in Spark 4.x, the Python Data Source API lets you build PySpark data sources on top of long-standing Python libraries for handling unique file types or specialized interfaces, and use them with the Spark read, readStream, write, and writeStream APIs.
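For orientation, the sketch below shows the minimal shape of such a data source: a `DataSource` subclass that declares a name and schema, and a `DataSourceReader` that yields rows. The names (`FakeDataSource`, `FakeDataSourceReader`, the `fake` format) are purely illustrative and not part of this repository, and an active `spark` session is assumed; the MQTT connector in this repo follows the same pattern for streaming reads.

```python
from pyspark.sql.datasource import DataSource, DataSourceReader


class FakeDataSource(DataSource):
    """Toy batch data source that emits a fixed set of rows."""

    @classmethod
    def name(cls):
        # The string used with spark.read.format(...)
        return "fake"

    def schema(self):
        # Either a DDL string or a StructType describing the rows yielded below
        return "id int, value string"

    def reader(self, schema):
        return FakeDataSourceReader()


class FakeDataSourceReader(DataSourceReader):
    def read(self, partition):
        # Yield plain tuples matching the declared schema
        for i in range(3):
            yield (i, f"row-{i}")


# Register once per session, then use it like any other format
spark.dataSource.register(FakeDataSource)
spark.read.format("fake").load().show()
```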
+ +| Data Source Name | Purpose | +| --- | --- | +| [MQTT](https://pypi.org/project/paho-mqtt/) | Read MQTT messages from a broker | + +--- + +## Configuration Options + +The MQTT data source supports the following configuration options, which can be set via Spark options or environment variables: + +| Option | Description | Required | Default | +|--------|-------------|----------|---------| +| `broker_address` | Hostname or IP address of the MQTT broker | Yes | - | +| `port` | Port number of the MQTT broker | No | 8883 | +| `username` | Username for broker authentication | No | "" | +| `password` | Password for broker authentication | No | "" | +| `topic` | MQTT topic to subscribe/publish to | No | "#" | +| `qos` | Quality of Service level (0, 1, or 2) | No | 0 | +| `require_tls` | Enable SSL/TLS (true/false) | No | true | +| `keepalive` | Keep alive interval in seconds | No | 60 | +| `clean_session` | Clean session flag (true/false) | No | false | +| `conn_time` | Connection timeout in seconds | No | 1 | +| `ca_certs` | Path to CA certificate file | No | - | +| `certfile` | Path to client certificate file | No | - | +| `keyfile` | Path to client key file | No | - | +| `tls_disable_certs` | Disable certificate verification | No | - | + +You can set these options in your PySpark code, for example: +```python +display( + spark.readStream.format("mqtt_pub_sub") + .option("topic", "#") + .option("broker_address", "host") + .option("username", "secret_user") + .option("password", "secret_password") + .option("qos", 2) + .option("require_tls", False) + .load() +) +``` + +--- + +## Building and Running Tests + +* Clone repo +* Create Virtual environment (Python 3.11) +* Ensure Docker/Podman is installed and properly configured +* Spin up a Docker container for a local MQTT Server: +```yaml +version: "3.7" +services: + mqtt5: + userns_mode: keep-id + image: eclipse-mosquitto + container_name: mqtt5 + ports: + - "1883:1883" # default mqtt port + - "9001:9001" # default mqtt port for websockets + volumes: + - ./config:/mosquitto/config:rw + - ./data:/mosquitto/data:rw + - ./log:/mosquitto/log:rw + restart: unless-stopped +``` + +* Create .env file at the project root directory: +```dotenv +MQTT_BROKER_HOST= +MQTT_BROKER_PORT= +MQTT_USERNAME= +MQTT_PASSWORD= +MQTT_BROKER_TOPIC_PREFIX= +``` + +* Run tests from project root directory +```shell +make test +``` + +* Build package +```shell +python -m build +``` + +--- + +## Example Usage + +```python +spark.dataSource.register(MqttDataSource) + +display( + spark.readStream.format("mqtt_pub_sub") + .option("topic", "#") + .option("broker_address", "host") + .option("username", "secret_user") + .option("password", "secret_password") + .option("qos", 2) + .option("require_tls", False) + .load() +) + +df.writeStream.format("console").start().awaitTermination() +``` + +--- + +## Project Support + +The code in this project is provided **for exploration purposes only** and is **not formally supported** by Databricks under any Service Level Agreements (SLAs). It is provided **AS-IS**, without any warranties or guarantees. + +Please **do not submit support tickets** to Databricks for issues related to the use of this project. + +The source code provided is subject to the Databricks [LICENSE](https://github.com/databricks-industry-solutions/python-data-sources/blob/main/LICENSE.md) . All third-party libraries included or referenced are subject to their respective licenses set forth in the project license. 
Any issues or bugs found should be submitted as **GitHub Issues** on the project repository. While these will be reviewed as time permits, there are **no formal SLAs** for support.

## 📄 Third-Party Package Licenses

© 2025 Databricks, Inc. All rights reserved. The source in this project is provided subject to the Databricks License [https://databricks.com/db-license-source]. All included or referenced third party libraries are subject to the licenses set forth below.

| Data Source | Package | Purpose | License | Source |
| ----------- | --------- | ------------------- | ----------------- | ----------------------------------- |
| paho-mqtt | paho-mqtt | Python API for MQTT | EPL-2.0 & EDL-1.0 | https://pypi.org/project/paho-mqtt/ |

## References

- [Paho MQTT Python Client](https://pypi.org/project/paho-mqtt/)
- [Eclipse Mosquitto](https://mosquitto.org/)
- [Databricks Python Data Source API](https://docs.databricks.com/en/data-engineering/data-sources/python-data-sources.html)
\ No newline at end of file
diff --git a/mqtt/pyproject.toml b/mqtt/pyproject.toml
new file mode 100644
index 0000000..63cc447
--- /dev/null
+++ b/mqtt/pyproject.toml
@@ -0,0 +1,60 @@
[build-system]
requires = ["hatchling >= 1.26.3"]
build-backend = "hatchling.build"

[project]
name = "python_datasource_connectors"
version = "0.0.2"
authors = [
    { name="Jeffery Annor", email="jeffery.annor@databricks.com" },
    { name="Gaston Guitart", email="gaston.guitart@databricks.com" },
]
description = "A Collection of Custom Data Source Connectors for PySpark"
readme = "README.md"
license = "EPL-2.0"
license-files = ["LICEN[CS]E*"]
requires-python = ">=3.11"
dependencies = [
    "pyspark >= 4.0.0.dev2",
    "paho-mqtt >= 2.1.0",
]
classifiers = [
    "Programming Language :: Python :: 3",
    "Operating System :: OS Independent",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "pytest-cov>=4.0.0",
    "pytest-mock>=3.10.0",
    "black>=23.0.0",
    "isort>=5.12.0",
    "flake8>=6.0.0",
    "mypy>=1.0.0",
    "pre-commit>=3.0.0",
    "build>=0.10.0",
    "twine>=4.0.0",
]
doc = [
    "sphinx>=6.0.0",
    "sphinx-rtd-theme>=1.2.0",
]
test = [
    "pytest>=7.0.0",
    "pytest-cov>=4.0.0",
    "pytest-mock>=3.10.0",
]

[project.urls]
Homepage = "https://github.com/jefferyann-db/python-datasource-connectors/mqtt"
Issues = "https://github.com/jefferyann-db/python-datasource-connectors/issues"

[tool.pytest.ini_options]
addopts = [
    "--import-mode=importlib",
]
pythonpath = [
    "src",
]
\ No newline at end of file
diff --git a/mqtt/requirements.txt b/mqtt/requirements.txt
new file mode 100644
index 0000000..d173f65
--- /dev/null
+++ b/mqtt/requirements.txt
@@ -0,0 +1 @@
paho-mqtt
\ No newline at end of file
diff --git a/mqtt/src/python_datasource_connectors/__init__.py b/mqtt/src/python_datasource_connectors/__init__.py
new file mode 100644
index 0000000..5cb92f7
--- /dev/null
+++ b/mqtt/src/python_datasource_connectors/__init__.py
@@ -0,0 +1 @@
from python_datasource_connectors.mqtt_streaming import MqttDataSource
\ No newline at end of file
diff --git a/mqtt/src/python_datasource_connectors/mqtt_streaming.py b/mqtt/src/python_datasource_connectors/mqtt_streaming.py
new file mode 100644
index 0000000..6f944ba
--- /dev/null
+++ b/mqtt/src/python_datasource_connectors/mqtt_streaming.py
@@ -0,0 +1,280 @@
import datetime
import logging
import random
import subprocess
import sys
import time

from pyspark.errors import PySparkException
+from pyspark.sql.datasource import DataSource, InputPartition, SimpleDataSourceStreamReader +from pyspark.sql.types import StructType, StructField, StringType + +logging.basicConfig() +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class RangePartition(InputPartition): + def __init__(self, start, end): + self.start = start + self.end = end + + +class MqttDataSource(DataSource): + """ + A PySpark DataSource for reading MQTT messages from a broker. + + This data source allows you to stream MQTT messages into Spark DataFrames, + supporting various MQTT broker configurations including authentication, + SSL/TLS encryption, and different quality of service levels. + + Supported options: + - broker_address: MQTT broker hostname or IP address (required) + - port: Broker port number (default: 8883) + - username: Authentication username (optional) + - password: Authentication password (optional) + - topic: MQTT topic to subscribe to (default: "#" for all topics) + - qos: Quality of Service level 0-2 (default: 0) + - require_tls: Enable SSL/TLS encryption (default: true) + - keepalive: Keep alive interval in seconds (default: 60) + + Example usage: + spark.readStream.format("mqtt_pub_sub") + .option("broker_address", "mqtt.example.com") + .option("topic", "sensors/+/temperature") + .option("username", "user") + .option("password", "pass") + .load() + """ + + @classmethod + def name(cls): + """Returns the name of the data source.""" + return "mqtt_pub_sub" + + def __init__(self, options): + """ + Initialize the MQTT data source with configuration options. + + Args: + options (dict): Configuration options for the MQTT connection. + See class docstring for supported options. + """ + self.options = options + + def schema(self): + """ + Define the schema of the data source. + + Returns: + StructType: The schema of the data source. + """ + return StructType([ + StructField("received_time", StringType(), True), + StructField("topic", StringType(), True), + StructField("message", StringType(), True), + StructField("is_duplicate", StringType(), True), + StructField("qos", StringType(), True), + StructField("is_retained", StringType(), True) + ]) + + def streamReader(self, schema: StructType): + """ + Create and return a stream reader for MQTT data. + + Args: + schema (StructType): The schema for the streaming data. + + Returns: + MqttSimpleStreamReader: A stream reader instance configured for MQTT. + """ + return MqttSimpleStreamReader(schema, self.options) + + +class MqttSimpleStreamReader(SimpleDataSourceStreamReader): + + def __init__(self, schema, options): + """ + Initialize the MQTT simple stream reader with configuration options. + + Args: + schema (StructType): The schema for the streaming data. + options (dict): Configuration options for the MQTT connection. + See class docstring for supported options. 
+ """ + self._install_paho_mqtt() + super().__init__() + self.topic = self._parse_topic(options.get("topic", "#")) + self.broker_address = options.get("broker_address") + str_tls = options.get("require_tls", True).lower() + self.require_tls = True if str_tls == "true" else False + self.port = int(options.get("port", 8883)) + self.username = options.get("username", "") + self.password = options.get("password", "") + self.qos = int(options.get("qos", 0)) + self.keep_alive = int(options.get("keepalive", 60)) + self.clean_session = options.get("clean_session", False) + self.conn_timeout = int(options.get("conn_time", 1)) + self.clean_sesion = options.get("clean_sesion", False) + self.ca_certs = options.get("ca_certs", None) + self.certfile = options.get("certfile", None) + self.keyfile = options.get("keyfile", None) + self.tls_disable_certs = options.get("tls_disable_certs", None) + if self.clean_sesion not in [True, False]: + raise ValueError(f"Unsupported sesion: {self.clean_sesion}") + self.client_id = f'spark-data-source-mqtt-{random.randint(0, 1000000)}' + self.current = 0 + self.new_data = [] + + def _install_paho_mqtt(self): + try: + import paho.mqtt.client + except ImportError: + logger.warn("Installing paho-mqtt...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "paho-mqtt"]) + # importlib.reload(sys.modules[__name__]) + + def _parse_topic(self, topic_str: str): + """ + TODO: add docs, implement parsing of topic string + """ + return topic_str + + def _configure_tls(self, client): + """ + Configure TLS settings on the MQTT client based on provided certificate options. + """ + if self.require_tls: + # Build tls_set arguments based on provided certificates + tls_args = {} + + if self.ca_certs: + tls_args['ca_certs'] = self.ca_certs + + if self.certfile: + tls_args['certfile'] = self.certfile + + if self.keyfile: + tls_args['keyfile'] = self.keyfile + + # Call tls_set with the appropriate parameters + if tls_args: + client.tls_set(**tls_args) + logger.info(f"TLS configured with certificates: {list(tls_args.keys())}") + else: + # Basic TLS without custom certificates + client.tls_set() + logger.info("Basic TLS enabled") + else: + logger.info("TLS disabled") + + def initialOffset(self): + return {"offset": 0} + + def latestOffset(self) -> dict: + """ + Returns the current latest offset that the next microbatch will read to. + """ + self.current += 1 + return {"offset": self.current} + + def partitions(self, start: dict, end: dict): + + """ + Plans the partitioning of the current microbatch defined by start and end offset. It + needs to return a sequence of :class:`InputPartition` objects. + """ + return [RangePartition(start["offset"], end["offset"])] + + def read(self, partition): + """ + Read MQTT messages from the broker. + + Args: + partition (RangePartition): The partition to read from. + + Returns: + Iterator[list]: An iterator of lists containing the MQTT message data. + The list contains the following elements: + - received_time: The time the message was received. + - topic: The topic of the message. + - message: The payload of the message. + - is_duplicate: Whether the message is a duplicate. + - qos: The quality of service level of the message. + - is_retained: Whether the message is retained. + + Raises: + Exception: If the connection to the broker fails. 
+ """ + import paho.mqtt.client as mqttClient + + def _get_mqtt_client(): + return mqttClient.Client(mqttClient.CallbackAPIVersion.VERSION1, self.client_id, + clean_session=self.clean_session) + + client = _get_mqtt_client() + + # Configure TLS with certificates if provided + self._configure_tls(client) + + client.username_pw_set(self.username, self.password) + + def on_connect(client, userdata, flags, rc): + if rc == 0: + client.subscribe(self.topic, qos=self.qos) + logger.warning(f"Connected to broker {self.broker_address} on port {self.port} with topic {self.topic}") + else: + logger.error(f"Connection failed to broker {self.broker_address} on port {self.port} with topic {self.topic}") + + def on_message(client, userdata, message): + msg_data = [ + str(datetime.datetime.now()), + message.topic, + str(message.payload.decode("utf-8", "ignore")), + message.dup, + message.qos, + message.retain + ] + logger.warning(msg_data) + self.new_data.append(msg_data) + + client.on_connect = on_connect + client.on_message = on_message + + try: + client.connect(self.broker_address, self.port, self.keep_alive) + except Exception as e: + connection_context = { + "broker_address": self.broker_address, + "port": self.port, + "topic": self.topic, + "client_id": self.client_id, + "require_tls": self.require_tls, + "keepalive": self.keep_alive, + "qos": self.qos, + "clean_session": self.clean_session, + "conn_timeout": self.conn_timeout + } + + error_msg = f"Failed to connect to MQTT broker. Connection details: {connection_context}" + logger.exception(error_msg, exc_info=e) + + # Re-raise with enhanced context + raise ConnectionError(error_msg) from e + client.loop_start() # Use loop_start to run the loop in a separate thread + + time.sleep(self.conn_timeout) # Wait for messages for the specified timeout + + client.loop_stop() # Stop the loop after the timeout + client.disconnect() + logger.warning("current state of data: %s", self.new_data) + + return (iter(self.new_data)) + + + + +class MqttSimpleStreamWriter(): + #To be implemented + def __init__(self, schema, options): + pass diff --git a/mqtt/tests/test_mqtt_streaming_pubsub.py b/mqtt/tests/test_mqtt_streaming_pubsub.py new file mode 100644 index 0000000..437498b --- /dev/null +++ b/mqtt/tests/test_mqtt_streaming_pubsub.py @@ -0,0 +1,168 @@ +import ssl + +import pytest +import os + +from pyspark.sql import SparkSession +from paho.mqtt import client as mqtt +import datetime +import time + +from python_datasource_connectors import MqttDataSource + +@pytest.fixture(scope="module") +def spark(): + spark = (SparkSession.builder + .master("local[2]") + .getOrCreate()) + + spark.sparkContext.setLogLevel("WARN") + yield spark + +@pytest.fixture(scope="module") +def mqtt_config(): + return { + "host": os.getenv("MQTT_LOCAL_BROKER_HOST", "localhost"), + "port": int(os.getenv("MQTT_LOCAL_BROKER_PORT", 1883)), + "username": os.getenv("MQTT_LOCAL_USERNAME", "root"), + "password": os.getenv("MQTT_LOCAL_PASSWORD", ""), + "topic_prefix": os.getenv("MQTT_LOCAL_BROKER_TOPIC_PREFIX", "test/pyspark"), + } + +@pytest.fixture(scope="module") +def mqtt_server_config(): + return { + "host": os.getenv("MQTT_REMOTE_BROKER_HOST",""), + "port": int(os.getenv("MQTT_REMOTE_BROKER_PORT", 883)), + "username": os.getenv("MQTT_REMOTE_USERNAME", ""), + "password": os.getenv("MQTT_REMOTE_PASSWORD", ""), + "topic_prefix": os.getenv("MQTT_REMOTE_BROKER_TOPIC_PREFIX", "test/pyspark"), + } + +@pytest.fixture +def mqtt_remote_client(mqtt_server_config): + client = 
mqtt.Client(callback_api_version=mqtt.CallbackAPIVersion.VERSION1) + + if mqtt_server_config["username"] and mqtt_server_config["password"]: + client.username_pw_set(username=mqtt_server_config["username"], password=mqtt_server_config["password"]) + + client.connect(mqtt_server_config["host"], mqtt_server_config["port"], 60) + sslSettings = ssl.SSLContext(ssl.PROTOCOL_TLS) + client.tls_set_context(sslSettings) + client.loop_start() + yield client + client.loop_stop() + client.disconnect() + +@pytest.fixture +def mqtt_client(mqtt_config): + client = mqtt.Client(callback_api_version=mqtt.CallbackAPIVersion.VERSION1) + + if mqtt_config["username"] and mqtt_config["password"]: + client.username_pw_set(username=mqtt_config["username"], password=mqtt_config["password"]) + + client.connect(mqtt_config["host"], mqtt_config["port"], 60) + client.loop_start() + yield client + client.loop_stop() + client.disconnect() + +def test_hivemq_read_stream(spark, mqtt_remote_client, mqtt_server_config): + """ + This test implements a slightly different logic than the local one. Here we use the "availableNow" trigger, + which will pull whatever is present in the MQTT topic when the streaming query starts. + MQTT will retain only one message per topic. So we send the four messages first, then set up the streaming + query, then expect only the last message to be pulled by our connector. + """ + spark.dataSource.register(MqttDataSource) + # Prepare the Test Messages for HiveMQ Remote Service + test_messages = [ + (mqtt_server_config["topic_prefix"], f"test message at {str(datetime.datetime.now())} and value=0", 2, False), + (mqtt_server_config["topic_prefix"], f"test message at {str(datetime.datetime.now())} and value=1", 2, False), + (mqtt_server_config["topic_prefix"], f"test message at {str(datetime.datetime.now())} and value=2", 2, False), + (mqtt_server_config["topic_prefix"], f"test message at {str(datetime.datetime.now())} and value=3", 2, False), + (mqtt_server_config["topic_prefix"], f"test message at {str(datetime.datetime.now())} and value=4", 2, False), + ] + + + # Publish the test messages to remote server + # Let's retain them so that this new subscriber can catch them + for (topic, payload, qos, is_persisted) in test_messages: + mqtt_remote_client.publish(topic, payload, qos=2, retain=True) + + time.sleep(5) + + # Start the streaming query + query = (spark.readStream + .format("mqtt_pub_sub") + .option("broker_address", mqtt_server_config["host"]) + .option("username", mqtt_server_config["username"]) + .option("port", mqtt_server_config["port"]) + .option("password", mqtt_server_config["password"]) + .option("topic", mqtt_server_config["topic_prefix"]) + .option("qos", 2) + .option("require_tls", True) + .load() + .writeStream + .format("memory") + .trigger(availableNow=True) + .queryName("mqtt_results") + .start() + ) + + time.sleep(5) + # No need to stop the query anymore, since we're using the availableNow trigger + # Assert Results + results = spark.sql("select * from mqtt_results").collect() + # Since we're testing retained messages, we expect only the last one to be pulled by our connector + assert len(results) == 1 + received = {(row.topic, row.message) for row in results} + # Pull only the topic and the payload to perform the assertion, compare against last message sent + expected = set((item[0], item[1]) for item in test_messages[-1:]) + assert received == expected + +def test_mqtt_local_read_stream(spark, mqtt_client, mqtt_config): + spark.dataSource.register(MqttDataSource) + # Prepare 
Test Messages + test_messages = [ + (mqtt_config["topic_prefix"], f"test message at {str(datetime.datetime.now())} and value=10", 2, False), + (mqtt_config["topic_prefix"], f"test message at {str(datetime.datetime.now())} and value=11", 2, False), + (mqtt_config["topic_prefix"], f"test message at {str(datetime.datetime.now())} and value=12", 2, False), + (mqtt_config["topic_prefix"], f"test message at {str(datetime.datetime.now())} and value=13", 2, False), + (mqtt_config["topic_prefix"], f"test message at {str(datetime.datetime.now())} and value=14", 2, False), + ] + + # Start the streaming query + query = (spark.readStream + .format("mqtt_pub_sub") + .option("broker_address", mqtt_config["host"]) + .option("username", mqtt_config["username"]) + .option("port", mqtt_config["port"]) + .option("password", mqtt_config["password"]) + .option("topic", mqtt_config["topic_prefix"]) + .option("qos", 2) + .option("require_tls", False) + .load() + .writeStream + .format("memory") + .queryName("mqtt_results") + .start() + ) + + time.sleep(5) + + # Publish the test messages + for (topic, payload, qos, is_persisted) in test_messages: + mqtt_client.publish(topic, payload, qos=2) + + time.sleep(10) + + # Assert Results + results = spark.sql("select * from mqtt_results").collect() + query.stop() + + assert len(results) == len(test_messages) + received = {(row.topic, row.message) for row in results} + # Pull only the topic and the payload to perform the assertion + expected = set((item[0], item[1]) for item in test_messages) + assert received == expected \ No newline at end of file