diff --git a/airos/airos8.py b/airos/airos8.py index 085a4f8..c5086c8 100644 --- a/airos/airos8.py +++ b/airos/airos8.py @@ -18,7 +18,7 @@ KeyDataMissingError, ) -logger = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) class AirOS: @@ -101,10 +101,10 @@ async def login(self) -> bool: headers=login_request_headers, ) as response: if response.status == 403: - logger.error("Authentication denied.") + _LOGGER.error("Authentication denied.") raise ConnectionAuthenticationError from None if not response.cookies: - logger.exception("Empty cookies after login, bailing out.") + _LOGGER.exception("Empty cookies after login, bailing out.") raise ConnectionSetupError from None else: for _, morsel in response.cookies.items(): @@ -155,7 +155,7 @@ async def login(self) -> bool: airos_cookie_found = False ok_cookie_found = False if not self.session.cookie_jar: # pragma: no cover - logger.exception( + _LOGGER.exception( "COOKIE JAR IS EMPTY after login POST. This is a major issue." ) raise ConnectionSetupError from None @@ -176,24 +176,24 @@ async def login(self) -> bool: self.connected = True return True except json.JSONDecodeError as err: - logger.exception("JSON Decode Error") + _LOGGER.exception("JSON Decode Error") raise DataMissingError from err else: log = f"Login failed with status {response.status}. Full Response: {response.text}" - logger.error(log) + _LOGGER.error(log) raise ConnectionAuthenticationError from None except ( aiohttp.ClientError, aiohttp.client_exceptions.ConnectionTimeoutError, ) as err: - logger.exception("Error during login") + _LOGGER.exception("Error during login") raise DeviceConnectionError from err async def status(self) -> AirOSData: """Retrieve status from the device.""" if not self.connected: - logger.error("Not connected, login first") + _LOGGER.error("Not connected, login first") raise DeviceConnectionError from None # --- Step 2: Verify authenticated access by fetching status.cgi --- @@ -213,32 +213,32 @@ async def status(self) -> AirOSData: try: airos_data = AirOSData.from_dict(response_json) except (MissingField, InvalidFieldValue) as err: - logger.exception("Failed to deserialize AirOS data") + _LOGGER.exception("Failed to deserialize AirOS data") raise KeyDataMissingError from err return airos_data except json.JSONDecodeError: - logger.exception( + _LOGGER.exception( "JSON Decode Error in authenticated status response" ) raise DataMissingError from None else: log = f"Authenticated status.cgi failed: {response.status}. Response: {response_text}" - logger.error(log) + _LOGGER.error(log) except ( aiohttp.ClientError, aiohttp.client_exceptions.ConnectionTimeoutError, ) as err: - logger.exception("Error during authenticated status.cgi call") + _LOGGER.exception("Error during authenticated status.cgi call") raise DeviceConnectionError from err async def stakick(self, mac_address: str = None) -> bool: """Reconnect client station.""" if not self.connected: - logger.error("Not connected, login first") + _LOGGER.error("Not connected, login first") raise DeviceConnectionError from None if not mac_address: - logger.error("Device mac-address missing") + _LOGGER.error("Device mac-address missing") raise DataMissingError from None kick_request_headers = {**self._common_headers} @@ -262,11 +262,11 @@ async def stakick(self, mac_address: str = None) -> bool: return True response_text = await response.text() log = f"Unable to restart connection response status {response.status} with {response_text}" - logger.error(log) + _LOGGER.error(log) return False except ( aiohttp.ClientError, aiohttp.client_exceptions.ConnectionTimeoutError, ) as err: - logger.exception("Error during reconnect stakick.cgi call") + _LOGGER.exception("Error during reconnect stakick.cgi call") raise DeviceConnectionError from err diff --git a/airos/discovery.py b/airos/discovery.py new file mode 100644 index 0000000..5f85a05 --- /dev/null +++ b/airos/discovery.py @@ -0,0 +1,275 @@ +"""Discover Ubiquiti UISP airOS device broadcasts.""" + +import asyncio +from collections.abc import Callable +import logging +import socket +import struct +from typing import Any + +from .exceptions import AirosDiscoveryError, AirosEndpointError, AirosListenerError + +_LOGGER = logging.getLogger(__name__) + +DISCOVERY_PORT: int = 10002 +BUFFER_SIZE: int = 1024 + + +class AirosDiscoveryProtocol(asyncio.DatagramProtocol): + """A UDP protocol implementation for discovering Ubiquiti airOS devices. + + This class listens for UDP broadcast announcements from airOS devices + on a specific port (10002) and parses the proprietary packet format + to extract device information. It acts as the low-level listener. + + Attributes: + callback: An asynchronous callable that will be invoked with + the parsed device information upon discovery. + transport: The UDP transport layer object, set once the connection is made. + + """ + + def __init__(self, callback: Callable[[dict[str, Any]], None]) -> None: + """Initialize AirosDiscoveryProtocol. + + Args: + callback: An asynchronous function to call when a device is discovered. + It should accept a dictionary containing device information. + + """ + self.callback = callback + self.transport: asyncio.DatagramTransport | None = None + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """Set up the UDP socket for broadcasting and reusing the address.""" + self.transport = transport # type: ignore[assignment] # transport is DatagramTransport + sock: socket.socket = self.transport.get_extra_info("socket") + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + log = f"Airos discovery listener (low-level) started on UDP port {DISCOVERY_PORT}." + _LOGGER.debug(log) + + def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: + """Parse the received UDP packet and, if successful, schedules the callback. + + Errors during parsing are logged internally by parse_airos_packet. + """ + host_ip: str = addr[0] + try: + parsed_data: dict[str, Any] | None = self.parse_airos_packet(data, host_ip) + if parsed_data: + # Schedule the user-provided callback, don't await to keep listener responsive + asyncio.create_task(self.callback(parsed_data)) # noqa: RUF006 + except (AirosEndpointError, AirosListenerError) as err: + # These are expected types of malformed packets. Log the specific error + # and then re-raise as AirosDiscoveryError. + log = f"Parsing failed for packet from {host_ip}: {err}" + _LOGGER.exception(log) + raise AirosDiscoveryError(f"Malformed packet from {host_ip}") from err + except Exception as err: + # General error during datagram reception (e.g., in callback itself) + log = f"Error processing Airos discovery packet from {host_ip}. Data hex: {data.hex()}" + _LOGGER.exception(log) + raise AirosDiscoveryError from err + + def error_received(self, exc: Exception | None) -> None: + """Handle send or receive operation raises an OSError.""" + if exc: + log = f"UDP error received in AirosDiscoveryProtocol: {exc}" + _LOGGER.error(log) + + def connection_lost(self, exc: Exception | None) -> None: + """Handle connection is lost or closed.""" + _LOGGER.debug("AirosDiscoveryProtocol connection lost.") + if exc: + _LOGGER.exception("AirosDiscoveryProtocol connection lost due to") + raise AirosDiscoveryError from None + + def parse_airos_packet(self, data: bytes, host_ip: str) -> dict[str, Any] | None: + """Parse a raw airOS discovery UDP packet. + + This method extracts various pieces of information from the proprietary + Ubiquiti airOS discovery packet format, which includes a fixed header + followed by a series of Type-Length-Value (TLV) entries. Different + TLV types use different length encoding schemes (fixed, 1-byte, 2-byte). + + Args: + data: The raw byte data of the UDP packet payload. + host_ip: The IP address of the sender, used as a fallback or initial IP. + + Returns: + A dictionary containing parsed device information if successful, + otherwise None. Values will be None if not found or cannot be parsed. + + """ + parsed_info: dict[str, str | int | None] = { + "ip_address": host_ip, + "mac_address": None, + "hostname": None, + "model": None, + "firmware_version": None, + "uptime_seconds": None, + "ssid": None, + "full_model_name": None, + } + + # --- Fixed Header (6 bytes) --- + if len(data) < 6: + log = f"Packet too short for initial fixed header. Length: {len(data)}. Data: {data.hex()}" + _LOGGER.debug(log) + raise AirosEndpointError(f"Malformed packet: {log}") + + if data[0] != 0x01 or data[1] != 0x06: + log = f"Packet does not start with expected Airos header (0x01 0x06). Actual: {data[0:2].hex()}" + _LOGGER.debug(log) + raise AirosEndpointError(f"Malformed packet: {log}") + + offset: int = 6 + + # --- Main TLV Parsing Loop --- + try: + while offset < len(data): + if (len(data) - offset) < 1: + log = f"Not enough bytes for next TLV type. Remaining: {data[offset:].hex()}" + _LOGGER.debug(log) + break + + tlv_type: int = data[offset] + offset += 1 + + if tlv_type == 0x06: # Device MAC Address (fixed 6-byte value) + expected_length: int = 6 + if (len(data) - offset) >= expected_length: + mac_bytes: bytes = data[offset : offset + expected_length] + parsed_info["mac_address"] = ":".join( + f"{b:02x}" for b in mac_bytes + ).upper() + offset += expected_length + log = f"Parsed MAC from type 0x06: {parsed_info['mac_address']}" + _LOGGER.debug(log) + else: + log = f"Truncated MAC address TLV (Type 0x06). Expected {expected_length}, got {len(data) - offset} bytes. Remaining: {data[offset:].hex()}" + _LOGGER.warning(log) + log = f"Malformed packet: {log}" + raise AirosEndpointError(log) + + elif tlv_type in [ + 0x02, + 0x03, + 0x0A, + 0x0B, + 0x0C, + 0x0D, + 0x0E, + 0x10, + 0x14, + 0x18, + ]: + if (len(data) - offset) < 2: + log = f"Truncated TLV (Type {tlv_type:#x}), no 2-byte length field. Remaining: {data[offset:].hex()}" + _LOGGER.warning(log) + log = f"Malformed packet: {log}" + raise AirosEndpointError(log) + + tlv_length: int = struct.unpack_from(">H", data, offset)[0] + offset += 2 + + if tlv_length > (len(data) - offset): + log = f"TLV type {tlv_type:#x} length {tlv_length} exceeds remaining data " + _LOGGER.warning(log) + log = f"({len(data) - offset} bytes left). Packet malformed. " + _LOGGER.warning(log) + log = f"Data from TLV start: {data[offset - 3 :].hex()}" + _LOGGER.warning(log) + log = f"Malformed packet: {log}" + raise AirosEndpointError(log) + + tlv_value: bytes = data[offset : offset + tlv_length] + + if tlv_type == 0x02: + if tlv_length == 10: + ip_bytes: bytes = tlv_value[6:10] + parsed_info["ip_address"] = ".".join(map(str, ip_bytes)) + log = f"Parsed IP from type 0x02 block: {parsed_info['ip_address']}" + _LOGGER.debug(log) + else: + log = f"Unexpected length for 0x02 TLV (MAC+IP). Expected 10, got {tlv_length}. Value: {tlv_value.hex()}" + _LOGGER.warning(log) + + elif tlv_type == 0x03: + parsed_info["firmware_version"] = tlv_value.decode( + "ascii", errors="ignore" + ) + log = f"Parsed Firmware: {parsed_info['firmware_version']}" + _LOGGER.debug(log) + + elif tlv_type == 0x0A: + if tlv_length == 4: + parsed_info["uptime_seconds"] = struct.unpack( + ">I", tlv_value + )[0] + log = f"Parsed Uptime: {parsed_info['uptime_seconds']}s" + _LOGGER.debug(log) + else: + log = f"Unexpected length for Uptime (Type 0x0A): {tlv_length}. Value: {tlv_value.hex()}" + _LOGGER.warning(log) + + elif tlv_type == 0x0B: + parsed_info["hostname"] = tlv_value.decode( + "utf-8", errors="ignore" + ) + log = f"Parsed Hostname: {parsed_info['hostname']}" + _LOGGER.debug(log) + + elif tlv_type == 0x0C: + parsed_info["model"] = tlv_value.decode( + "ascii", errors="ignore" + ) + log = f"Parsed Model: {parsed_info['model']}" + _LOGGER.debug(log) + + elif tlv_type == 0x0D: + parsed_info["ssid"] = tlv_value.decode("utf-8", errors="ignore") + log = f"Parsed SSID: {parsed_info['ssid']}" + _LOGGER.debug(log) + + elif tlv_type == 0x14: + parsed_info["full_model_name"] = tlv_value.decode( + "utf-8", errors="ignore" + ) + log = ( + f"Parsed Full Model Name: {parsed_info['full_model_name']}" + ) + _LOGGER.debug(log) + + elif tlv_type == 0x18: + if tlv_length == 4 and tlv_value == b"\x00\x00\x00\x00": + _LOGGER.debug("Detected end marker (Type 0x18).") + else: + log = f"Unhandled TLV type: {tlv_type:#x} with length {tlv_length}. Value: {tlv_value.hex()}" + _LOGGER.debug(log) + elif tlv_type in [0x0E, 0x10]: + log = f"Unhandled TLV type: {tlv_type:#x} with length {tlv_length}. Value: {tlv_value.hex()}" + _LOGGER.debug(log) + + offset += tlv_length + + else: + log = f"Unhandled TLV type: {tlv_type:#x} at offset {offset - 1}. " + log += f"Cannot determine length, stopping parsing. Remaining: {data[offset - 1 :].hex()}" + _LOGGER.warning(log) + log = f"Malformed packet: {log}" + raise AirosEndpointError(log) + + except (struct.error, IndexError) as err: + log = f"Parsing error (struct/index) in AirosDiscoveryProtocol: {err} at offset {offset}. Remaining data: {data[offset:].hex()}" + _LOGGER.debug(log) + log = f"Malformed packet: {log}" + raise AirosEndpointError(log) from err + except AirosEndpointError: # Catch AirosEndpointError specifically, re-raise it + raise + except Exception as err: + _LOGGER.exception("Unexpected error during Airos packet parsing") + raise AirosListenerError from err + + return parsed_info diff --git a/airos/exceptions.py b/airos/exceptions.py index ffcc3b8..ac199e7 100644 --- a/airos/exceptions.py +++ b/airos/exceptions.py @@ -23,3 +23,15 @@ class KeyDataMissingError(AirOSException): class DeviceConnectionError(AirOSException): """Raised when unable to connect.""" + + +class AirosDiscoveryError(AirOSException): + """Base exception for Airos discovery issues.""" + + +class AirosListenerError(AirosDiscoveryError): + """Raised when the Airos listener encounters an error.""" + + +class AirosEndpointError(AirosDiscoveryError): + """Raised when there's an issue with the network endpoint.""" diff --git a/fixtures/airos_sta_discovery_packet.bin b/fixtures/airos_sta_discovery_packet.bin new file mode 100644 index 0000000..dad4c6b Binary files /dev/null and b/fixtures/airos_sta_discovery_packet.bin differ diff --git a/pyproject.toml b/pyproject.toml index e8b92d1..66800dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "airos" -version = "0.1.8" +version = "0.2.0" license = "MIT" description = "Ubiquity airOS module(s) for Python 3." readme = "README.md" diff --git a/script/generate_discovery_fixture.py b/script/generate_discovery_fixture.py new file mode 100644 index 0000000..4ca90d0 --- /dev/null +++ b/script/generate_discovery_fixture.py @@ -0,0 +1,103 @@ +"""Generate mock discovery packet for testing.""" + +import logging +import os +import socket +import struct + +_LOGGER = logging.getLogger(__name__) + +# Define the path to save the fixture +fixture_dir = os.path.join(os.path.dirname(__file__), "../fixtures") +os.makedirs(fixture_dir, exist_ok=True) # Ensure the directory exists +fixture_path = os.path.join(fixture_dir, "airos_sta_discovery_packet.bin") + +# Header: 0x01 0x06 (2 bytes) + 4 reserved bytes = 6 bytes +HEADER = b"\x01\x06\x00\x00\x00\x00" + +# --- Scrubbed Values --- +SCRUBBED_MAC = "01:23:45:67:89:CD" +SCRUBBED_MAC_BYTES = bytes.fromhex(SCRUBBED_MAC.replace(":", "")) +SCRUBBED_IP = "192.168.1.3" +SCRUBBED_IP_BYTES = socket.inet_aton(SCRUBBED_IP) +SCRUBBED_HOSTNAME = "name" +SCRUBBED_HOSTNAME_BYTES = SCRUBBED_HOSTNAME.encode("utf-8") + +# --- Values from provided "schuur" JSON (not scrubbed) --- +FIRMWARE_VERSION = "WA.V8.7.17" +FIRMWARE_VERSION_BYTES = FIRMWARE_VERSION.encode("ascii") +UPTIME_SECONDS = 265375 +MODEL = "NanoStation 5AC loco" +MODEL_BYTES = MODEL.encode("ascii") +SSID = "DemoSSID" +SSID_BYTES = SSID.encode("utf-8") +FULL_MODEL_NAME = ( + "NanoStation 5AC loco" # Using the same as Model, as is often the case +) +FULL_MODEL_NAME_BYTES = FULL_MODEL_NAME.encode("utf-8") + +# TLV Type 0x06: MAC Address (fixed 6-byte value) +TLV_MAC_TYPE = b"\x06" +TLV_MAC = TLV_MAC_TYPE + SCRUBBED_MAC_BYTES + +# TLV Type 0x02: MAC + IP Address (10 bytes value, with 2-byte length field) +# Value contains first 6 bytes often MAC, last 4 bytes IP +TLV_IP_TYPE = b"\x02" +TLV_IP_VALUE = ( + SCRUBBED_MAC_BYTES + SCRUBBED_IP_BYTES +) # 6 bytes MAC + 4 bytes IP = 10 bytes +TLV_IP_LENGTH = len(TLV_IP_VALUE).to_bytes(2, "big") +TLV_IP = TLV_IP_TYPE + TLV_IP_LENGTH + TLV_IP_VALUE + +# TLV Type 0x03: Firmware Version (variable length string) +TLV_FW_TYPE = b"\x03" +TLV_FW_LENGTH = len(FIRMWARE_VERSION_BYTES).to_bytes(2, "big") +TLV_FW = TLV_FW_TYPE + TLV_FW_LENGTH + FIRMWARE_VERSION_BYTES + +# TLV Type 0x0A: Uptime (4-byte integer) +TLV_UPTIME_TYPE = b"\x0a" +TLV_UPTIME_VALUE = struct.pack(">I", UPTIME_SECONDS) # Unsigned int, big-endian +TLV_UPTIME_LENGTH = len(TLV_UPTIME_VALUE).to_bytes(2, "big") +TLV_UPTIME = TLV_UPTIME_TYPE + TLV_UPTIME_LENGTH + TLV_UPTIME_VALUE + +# TLV Type 0x0B: Hostname (variable length string) +TLV_HOSTNAME_TYPE = b"\x0b" +TLV_HOSTNAME_LENGTH = len(SCRUBBED_HOSTNAME_BYTES).to_bytes(2, "big") +TLV_HOSTNAME = TLV_HOSTNAME_TYPE + TLV_HOSTNAME_LENGTH + SCRUBBED_HOSTNAME_BYTES + +# TLV Type 0x0C: Model (variable length string) +TLV_MODEL_TYPE = b"\x0c" +TLV_MODEL_LENGTH = len(MODEL_BYTES).to_bytes(2, "big") +TLV_MODEL = TLV_MODEL_TYPE + TLV_MODEL_LENGTH + MODEL_BYTES + +# TLV Type 0x0D: SSID (variable length string) +TLV_SSID_TYPE = b"\x0d" +TLV_SSID_LENGTH = len(SSID_BYTES).to_bytes(2, "big") +TLV_SSID = TLV_SSID_TYPE + TLV_SSID_LENGTH + SSID_BYTES + +# TLV Type 0x14: Full Model Name (variable length string) +TLV_FULL_MODEL_TYPE = b"\x14" +TLV_FULL_MODEL_LENGTH = len(FULL_MODEL_NAME_BYTES).to_bytes(2, "big") +TLV_FULL_MODEL = TLV_FULL_MODEL_TYPE + TLV_FULL_MODEL_LENGTH + FULL_MODEL_NAME_BYTES + +# Combine all parts +FULL_PACKET = ( + HEADER + + TLV_MAC + + TLV_IP + + TLV_FW + + TLV_UPTIME + + TLV_HOSTNAME + + TLV_MODEL + + TLV_SSID + + TLV_FULL_MODEL +) + +# Write the actual binary file +with open(fixture_path, "wb") as f: + f.write(FULL_PACKET) + +log = f"Generated discovery packet fixture at: {fixture_path}" +log += f"Packet length: {len(FULL_PACKET)} bytes" +log += f"Packet hex: {FULL_PACKET.hex()}" +_LOGGER.info(log) diff --git a/tests/test_discovery.py b/tests/test_discovery.py new file mode 100644 index 0000000..5d22e46 --- /dev/null +++ b/tests/test_discovery.py @@ -0,0 +1,207 @@ +"""Test discovery of Ubiquiti airOS devices.""" + +import asyncio +import os +import socket # Add this import +from unittest.mock import AsyncMock, MagicMock, patch + +from airos.discovery import DISCOVERY_PORT, AirosDiscoveryProtocol +from airos.exceptions import AirosDiscoveryError, AirosEndpointError +import pytest + + +# Helper to load binary fixture +async def _read_binary_fixture(fixture_name: str) -> bytes: + """Read a binary fixture file.""" + fixture_dir = os.path.join(os.path.dirname(__file__), "../fixtures") + path = os.path.join(fixture_dir, fixture_name) + try: + + def _read_file(): + with open(path, "rb") as f: + return f.read() + + return await asyncio.to_thread(_read_file) + except FileNotFoundError: + pytest.fail(f"Fixture file not found: {path}") + except Exception as e: + pytest.fail(f"Error reading fixture file {path}: {e}") + + +@pytest.fixture +async def mock_airos_packet() -> bytes: + """Fixture for a valid airos discovery packet with scrubbed data.""" + return await _read_binary_fixture("airos_sta_discovery_packet.bin") + + +@pytest.mark.asyncio +async def test_parse_airos_packet_success(mock_airos_packet): + """Test parse_airos_packet with a valid packet containing scrubbed data.""" + protocol = AirosDiscoveryProtocol( + AsyncMock() + ) # Callback won't be called directly in this unit test + host_ip = ( + "192.168.1.3" # The IP address from the packet sender (as per scrubbed data) + ) + + # Directly call the parsing method + parsed_data = protocol.parse_airos_packet(mock_airos_packet, host_ip) + + assert parsed_data is not None + assert parsed_data["ip_address"] == "192.168.1.3" + assert parsed_data["mac_address"] == "01:23:45:67:89:CD" # Expected scrubbed MAC + assert parsed_data["hostname"] == "name" # Expected scrubbed hostname + assert parsed_data["model"] == "NanoStation 5AC loco" + assert parsed_data["firmware_version"] == "WA.V8.7.17" + assert parsed_data["uptime_seconds"] == 265375 + assert parsed_data["ssid"] == "DemoSSID" + assert parsed_data["full_model_name"] == "NanoStation 5AC loco" + + +@pytest.mark.asyncio +async def test_parse_airos_packet_invalid_header(): + """Test parse_airos_packet with an invalid header.""" + protocol = AirosDiscoveryProtocol(AsyncMock()) + invalid_data = b"\x00\x00\x00\x00\x00\x00" + b"someotherdata" + host_ip = "192.168.1.100" + + # Patch the _LOGGER.debug to verify the log message + with patch("airos.discovery._LOGGER.debug") as mock_log_debug: + with pytest.raises(AirosEndpointError): + protocol.parse_airos_packet(invalid_data, host_ip) + mock_log_debug.assert_called_once() + assert ( + "does not start with expected Airos header" + in mock_log_debug.call_args[0][0] + ) + + +@pytest.mark.asyncio +async def test_parse_airos_packet_too_short(): + """Test parse_airos_packet with data too short for header.""" + protocol = AirosDiscoveryProtocol(AsyncMock()) + too_short_data = b"\x01\x06\x00" + host_ip = "192.168.1.100" + + # Patch the _LOGGER.debug to verify the log message + with patch("airos.discovery._LOGGER.debug") as mock_log_debug: + with pytest.raises(AirosEndpointError): + protocol.parse_airos_packet(too_short_data, host_ip) + mock_log_debug.assert_called_once() + assert ( + "Packet too short for initial fixed header" + in mock_log_debug.call_args[0][0] + ) + + +@pytest.mark.asyncio +async def test_parse_airos_packet_truncated_tlv(): + """Test parse_airos_packet with a truncated TLV.""" + protocol = AirosDiscoveryProtocol(AsyncMock()) + # Header + MAC TLV (valid) + then a truncated TLV_IP + truncated_data = ( + b"\x01\x06\x00\x00\x00\x00" # Header + + b"\x06" + + bytes.fromhex("0123456789CD") # Valid MAC (scrubbed) + + b"\x02\x00" # TLV type 0x02, followed by only 1 byte for length (should be 2) + ) + host_ip = "192.168.1.100" + + # Expect AirosEndpointError due to struct.error or IndexError + with pytest.raises(AirosEndpointError): + protocol.parse_airos_packet(truncated_data, host_ip) + + +@pytest.mark.asyncio +async def test_datagram_received_calls_callback(mock_airos_packet): + """Test that datagram_received correctly calls the callback.""" + mock_callback = AsyncMock() + protocol = AirosDiscoveryProtocol(mock_callback) + host_ip = "192.168.1.3" # Sender IP + + with patch("asyncio.create_task") as mock_create_task: + protocol.datagram_received(mock_airos_packet, (host_ip, DISCOVERY_PORT)) + + # Verify the task was created and get the coroutine + mock_create_task.assert_called_once() + task_coro = mock_create_task.call_args[0][0] + + # Manually await the coroutine to test the callback + await task_coro + + mock_callback.assert_called_once() + called_args, _ = mock_callback.call_args + parsed_data = called_args[0] + assert parsed_data["ip_address"] == "192.168.1.3" + assert parsed_data["mac_address"] == "01:23:45:67:89:CD" # Verify scrubbed MAC + + +@pytest.mark.asyncio +async def test_datagram_received_handles_parsing_error(): + """Test datagram_received handles exceptions during parsing.""" + mock_callback = AsyncMock() + protocol = AirosDiscoveryProtocol(mock_callback) + invalid_data = b"\x00\x00" # Too short, will cause parsing error + host_ip = "192.168.1.100" + + with patch("airos.discovery._LOGGER.exception") as mock_log_exception: + # datagram_received catches errors internally and re-raises AirosDiscoveryError + with pytest.raises(AirosDiscoveryError): + protocol.datagram_received(invalid_data, (host_ip, DISCOVERY_PORT)) + mock_callback.assert_not_called() + mock_log_exception.assert_called_once() # Ensure exception is logged + + +@pytest.mark.asyncio +async def test_connection_made_sets_transport(): + """Test connection_made sets up transport and socket options.""" + protocol = AirosDiscoveryProtocol(AsyncMock()) + mock_transport = MagicMock(spec=asyncio.DatagramTransport) + mock_sock = MagicMock(spec=socket.socket) # Corrected: socket import added + mock_transport.get_extra_info.return_value = mock_sock + + with patch("airos.discovery._LOGGER.debug") as mock_log_debug: + protocol.connection_made(mock_transport) + + assert protocol.transport is mock_transport + mock_sock.setsockopt.assert_any_call(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + mock_sock.setsockopt.assert_any_call(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + mock_log_debug.assert_called_once() + + +@pytest.mark.asyncio +async def test_connection_lost_without_exception(): + """Test connection_lost without an exception.""" + protocol = AirosDiscoveryProtocol(AsyncMock()) + with patch("airos.discovery._LOGGER.debug") as mock_log_debug: + protocol.connection_lost(None) + mock_log_debug.assert_called_once_with( + "AirosDiscoveryProtocol connection lost." + ) + + +@pytest.mark.asyncio +async def test_connection_lost_with_exception(): + """Test connection_lost with an exception.""" + protocol = AirosDiscoveryProtocol(AsyncMock()) + test_exception = Exception("Test connection lost error") + with ( + patch("airos.discovery._LOGGER.exception") as mock_log_exception, + pytest.raises( + AirosDiscoveryError + ), # connection_lost now re-raises AirosDiscoveryError + ): + protocol.connection_lost(test_exception) + mock_log_exception.assert_called_once() + + +@pytest.mark.asyncio +async def test_error_received(): + """Test error_received logs the error.""" + protocol = AirosDiscoveryProtocol(AsyncMock()) + test_exception = Exception("Test network error") + with patch("airos.discovery._LOGGER.error") as mock_log_error: + protocol.error_received(test_exception) + mock_log_error.assert_called_once_with( + f"UDP error received in AirosDiscoveryProtocol: {test_exception}" + )