Skip to content

Commit 1ccab3a

Browse files
committed
Improve DRY and add unsupported errors
1 parent 994c78e commit 1ccab3a

File tree

8 files changed

+113
-363
lines changed

8 files changed

+113
-363
lines changed

airos/airos6.py

Lines changed: 40 additions & 221 deletions
Original file line numberDiff line numberDiff line change
@@ -2,83 +2,28 @@
22

33
from __future__ import annotations
44

5-
import asyncio
6-
from http.cookies import SimpleCookie
7-
import json
85
import logging
96
from typing import Any
10-
from urllib.parse import urlparse
11-
12-
import aiohttp
13-
from mashumaro.exceptions import InvalidFieldValue, MissingField
14-
15-
from .data import (
16-
AirOS6Data as AirOSData,
17-
DerivedWirelessMode,
18-
DerivedWirelessRole,
19-
redact_data_smart,
20-
)
21-
from .exceptions import (
22-
AirOSConnectionAuthenticationError,
23-
AirOSConnectionSetupError,
24-
AirOSDataMissingError,
25-
AirOSDeviceConnectionError,
26-
AirOSKeyDataMissingError,
27-
)
7+
8+
from .airos8 import AirOS
9+
from .data import AirOS6Data, DerivedWirelessRole
10+
from .exceptions import AirOSNotSupportedError
2811

2912
_LOGGER = logging.getLogger(__name__)
3013

3114

32-
class AirOS:
15+
class AirOS6(AirOS):
3316
"""AirOS 6 connection class."""
3417

35-
def __init__(
36-
self,
37-
host: str,
38-
username: str,
39-
password: str,
40-
session: aiohttp.ClientSession,
41-
use_ssl: bool = True,
42-
):
43-
"""Initialize AirOS6 class."""
44-
self.username = username
45-
self.password = password
46-
47-
parsed_host = urlparse(host)
48-
scheme = (
49-
parsed_host.scheme
50-
if parsed_host.scheme
51-
else ("https" if use_ssl else "http")
52-
)
53-
hostname = parsed_host.hostname if parsed_host.hostname else host
54-
55-
self.base_url = f"{scheme}://{hostname}"
56-
57-
self.session = session
58-
59-
self._login_url = f"{self.base_url}/api/auth"
60-
self._status_cgi_url = f"{self.base_url}/status.cgi"
61-
self.current_csrf_token: str | None = None
62-
63-
self._use_json_for_login_post = False
64-
65-
self._auth_cookie: str | None = None
66-
self._csrf_id: str | None = None
67-
self.connected: bool = False
18+
data_model = AirOS6Data
6819

6920
@staticmethod
70-
def derived_data(response: dict[str, Any]) -> dict[str, Any]:
71-
"""Add derived data to the device response."""
72-
derived: dict[str, Any] = {
73-
"station": False,
74-
"access_point": False,
75-
"ptp": False,
76-
"ptmp": False,
77-
"role": DerivedWirelessRole.STATION,
78-
"mode": DerivedWirelessMode.PTP,
79-
}
80-
21+
def derived_wireless_data(
22+
derived: dict[str, Any], response: dict[str, Any]
23+
) -> dict[str, Any]:
24+
"""Add derived wireless data to the device response."""
8125
# Access Point / Station - no info on ptp/ptmp
26+
# assuming ptp for station mode
8227
derived["ptp"] = True
8328
wireless_mode = response.get("wireless", {}).get("mode", "")
8429
match wireless_mode:
@@ -88,158 +33,32 @@ def derived_data(response: dict[str, Any]) -> dict[str, Any]:
8833
case "sta":
8934
derived["station"] = True
9035

91-
# INTERFACES
92-
addresses = {}
93-
interface_order = ["br0", "eth0", "ath0"]
94-
95-
interfaces = response.get("interfaces", [])
96-
97-
# No interfaces, no mac, no usability
98-
if not interfaces:
99-
_LOGGER.error("Failed to determine interfaces from AirOS data")
100-
raise AirOSKeyDataMissingError from None
101-
102-
for interface in interfaces:
103-
if interface["enabled"]: # Only consider if enabled
104-
addresses[interface["ifname"]] = interface["hwaddr"]
105-
106-
# Fallback take fist alternate interface found
107-
derived["mac"] = interfaces[0]["hwaddr"]
108-
derived["mac_interface"] = interfaces[0]["ifname"]
109-
110-
for interface in interface_order:
111-
if interface in addresses:
112-
derived["mac"] = addresses[interface]
113-
derived["mac_interface"] = interface
114-
break
115-
116-
response["derived"] = derived
117-
118-
return response
119-
120-
def _get_authenticated_headers(
121-
self,
122-
ct_json: bool = False,
123-
ct_form: bool = False,
124-
) -> dict[str, str]:
125-
"""Construct headers for an authenticated request."""
126-
headers = {}
127-
if ct_json:
128-
headers["Content-Type"] = "application/json"
129-
elif ct_form:
130-
headers["Content-Type"] = "application/x-www-form-urlencoded"
131-
132-
if self._csrf_id:
133-
headers["X-CSRF-ID"] = self._csrf_id
134-
135-
if self._auth_cookie:
136-
headers["Cookie"] = f"AIROS_{self._auth_cookie}"
137-
138-
return headers
139-
140-
def _store_auth_data(self, response: aiohttp.ClientResponse) -> None:
141-
"""Parse the response from a successful login and store auth data."""
142-
self._csrf_id = response.headers.get("X-CSRF-ID")
143-
144-
# Parse all Set-Cookie headers to ensure we don't miss AIROS_* cookie
145-
cookie = SimpleCookie()
146-
for set_cookie in response.headers.getall("Set-Cookie", []):
147-
cookie.load(set_cookie)
148-
for key, morsel in cookie.items():
149-
if key.startswith("AIROS_"):
150-
self._auth_cookie = morsel.key[6:] + "=" + morsel.value
151-
break
152-
153-
async def _request_json(
154-
self,
155-
method: str,
156-
url: str,
157-
headers: dict[str, Any] | None = None,
158-
json_data: dict[str, Any] | None = None,
159-
form_data: dict[str, Any] | None = None,
160-
authenticated: bool = False,
161-
ct_json: bool = False,
162-
ct_form: bool = False,
163-
) -> dict[str, Any] | Any:
164-
"""Make an authenticated API request and return JSON response."""
165-
# Pass the content type flags to the header builder
166-
request_headers = (
167-
self._get_authenticated_headers(ct_json=ct_json, ct_form=ct_form)
168-
if authenticated
169-
else {}
170-
)
171-
if headers:
172-
request_headers.update(headers)
173-
174-
try:
175-
if url != self._login_url and not self.connected:
176-
_LOGGER.error("Not connected, login first")
177-
raise AirOSDeviceConnectionError from None
178-
179-
async with self.session.request(
180-
method,
181-
url,
182-
json=json_data,
183-
data=form_data,
184-
headers=request_headers, # Pass the constructed headers
185-
) as response:
186-
response.raise_for_status()
187-
response_text = await response.text()
188-
_LOGGER.debug("Successfully fetched JSON from %s", url)
189-
190-
# If this is the login request, we need to store the new auth data
191-
if url == self._login_url:
192-
self._store_auth_data(response)
193-
self.connected = True
194-
195-
return json.loads(response_text)
196-
except aiohttp.ClientResponseError as err:
197-
_LOGGER.error(
198-
"Request to %s failed with status %s: %s", url, err.status, err.message
199-
)
200-
if err.status == 401:
201-
raise AirOSConnectionAuthenticationError from err
202-
raise AirOSConnectionSetupError from err
203-
except (TimeoutError, aiohttp.ClientError) as err:
204-
_LOGGER.exception("Error during API call to %s", url)
205-
raise AirOSDeviceConnectionError from err
206-
except json.JSONDecodeError as err:
207-
_LOGGER.error("Failed to decode JSON from %s", url)
208-
raise AirOSDataMissingError from err
209-
except asyncio.CancelledError:
210-
_LOGGER.warning("Request to %s was cancelled", url)
211-
raise
212-
213-
async def login(self) -> None:
214-
"""Login to AirOS device."""
215-
payload = {"username": self.username, "password": self.password}
216-
try:
217-
await self._request_json("POST", self._login_url, json_data=payload)
218-
except (AirOSConnectionAuthenticationError, AirOSConnectionSetupError) as err:
219-
raise AirOSConnectionSetupError("Failed to login to AirOS device") from err
220-
221-
async def status(self) -> AirOSData:
222-
"""Retrieve status from the device."""
223-
response = await self._request_json(
224-
"GET", self._status_cgi_url, authenticated=True
225-
)
226-
227-
try:
228-
adjusted_json = self.derived_data(response)
229-
return AirOSData.from_dict(adjusted_json)
230-
except InvalidFieldValue as err:
231-
# Log with .error() as this is a specific, known type of issue
232-
redacted_data = redact_data_smart(response)
233-
_LOGGER.error(
234-
"Failed to deserialize AirOS data due to an invalid field value: %s",
235-
redacted_data,
236-
)
237-
raise AirOSKeyDataMissingError from err
238-
except MissingField as err:
239-
# Log with .exception() for a full stack trace
240-
redacted_data = redact_data_smart(response)
241-
_LOGGER.exception(
242-
"Failed to deserialize AirOS data due to a missing field: %s",
243-
redacted_data,
244-
)
245-
raise AirOSKeyDataMissingError from err
36+
return derived
37+
38+
async def update_check(self, force: bool = False) -> dict[str, Any]:
39+
"""Check for firmware updates. Not supported on AirOS6."""
40+
raise AirOSNotSupportedError("Firmware update check not supported on AirOS6.")
41+
42+
async def stakick(self, mac_address: str | None = None) -> bool:
43+
"""Kick a station off the AP. Not supported on AirOS6."""
44+
raise AirOSNotSupportedError("Station kick not supported on AirOS6.")
45+
46+
async def provmode(self, active: bool = False) -> bool:
47+
"""Enable/Disable provisioning mode. Not supported on AirOS6."""
48+
raise AirOSNotSupportedError("Provisioning mode not supported on AirOS6.")
49+
50+
async def warnings(self) -> dict[str, Any]:
51+
"""Get device warnings. Not supported on AirOS6."""
52+
raise AirOSNotSupportedError("Device warnings not supported on AirOS6.")
53+
54+
async def progress(self) -> dict[str, Any]:
55+
"""Get firmware progress. Not supported on AirOS6."""
56+
raise AirOSNotSupportedError("Firmware progress not supported on AirOS6.")
57+
58+
async def download(self) -> dict[str, Any]:
59+
"""Download the device firmware. Not supported on AirOS6."""
60+
raise AirOSNotSupportedError("Firmware download not supported on AirOS6.")
61+
62+
async def install(self) -> dict[str, Any]:
63+
"""Install a firmware update. Not supported on AirOS6."""
64+
raise AirOSNotSupportedError("Firmware install not supported on AirOS6.")

airos/airos8.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33
from __future__ import annotations
44

55
import asyncio
6+
from collections.abc import Callable
67
from http.cookies import SimpleCookie
78
import json
89
import logging
9-
from typing import Any
10+
from typing import Any, TypeVar
1011
from urllib.parse import urlparse
1112

1213
import aiohttp
1314
from mashumaro.exceptions import InvalidFieldValue, MissingField
1415

1516
from .data import (
16-
AirOS8Data as AirOSData,
17+
AirOS8Data,
18+
AirOSDataBaseClass,
1719
DerivedWirelessMode,
1820
DerivedWirelessRole,
1921
redact_data_smart,
@@ -28,10 +30,14 @@
2830

2931
_LOGGER = logging.getLogger(__name__)
3032

33+
AirOSDataModel = TypeVar("AirOSDataModel", bound=AirOSDataBaseClass)
34+
3135

3236
class AirOS:
3337
"""AirOS 8 connection class."""
3438

39+
data_model: type[AirOSDataBaseClass] = AirOS8Data
40+
3541
def __init__(
3642
self,
3743
host: str,
@@ -74,17 +80,10 @@ def __init__(
7480
self.connected: bool = False
7581

7682
@staticmethod
77-
def derived_data(response: dict[str, Any]) -> dict[str, Any]:
78-
"""Add derived data to the device response."""
79-
derived: dict[str, Any] = {
80-
"station": False,
81-
"access_point": False,
82-
"ptp": False,
83-
"ptmp": False,
84-
"role": DerivedWirelessRole.STATION,
85-
"mode": DerivedWirelessMode.PTP,
86-
}
87-
83+
def derived_wireless_data(
84+
derived: dict[str, Any], response: dict[str, Any]
85+
) -> dict[str, Any]:
86+
"""Add derived wireless data to the device response."""
8887
# Access Point / Station vs PTP/PtMP
8988
wireless_mode = response.get("wireless", {}).get("mode", "")
9089
match wireless_mode:
@@ -104,6 +103,27 @@ def derived_data(response: dict[str, Any]) -> dict[str, Any]:
104103
case "sta-ptp":
105104
derived["station"] = True
106105
derived["ptp"] = True
106+
return derived
107+
108+
@staticmethod
109+
def _derived_data_helper(
110+
response: dict[str, Any],
111+
derived_wireless_data_func: Callable[
112+
[dict[str, Any], dict[str, Any]], dict[str, Any]
113+
],
114+
) -> dict[str, Any]:
115+
"""Add derived data to the device response."""
116+
derived: dict[str, Any] = {
117+
"station": False,
118+
"access_point": False,
119+
"ptp": False,
120+
"ptmp": False,
121+
"role": DerivedWirelessRole.STATION,
122+
"mode": DerivedWirelessMode.PTP,
123+
}
124+
125+
# WIRELESS
126+
derived = derived_wireless_data_func(derived, response)
107127

108128
# INTERFACES
109129
addresses = {}
@@ -134,6 +154,10 @@ def derived_data(response: dict[str, Any]) -> dict[str, Any]:
134154

135155
return response
136156

157+
def derived_data(self, response: dict[str, Any]) -> dict[str, Any]:
158+
"""Add derived data to the device response (instance method for polymorphism)."""
159+
return self._derived_data_helper(response, self.derived_wireless_data)
160+
137161
def _get_authenticated_headers(
138162
self,
139163
ct_json: bool = False,
@@ -235,15 +259,15 @@ async def login(self) -> None:
235259
except (AirOSConnectionAuthenticationError, AirOSConnectionSetupError) as err:
236260
raise AirOSConnectionSetupError("Failed to login to AirOS device") from err
237261

238-
async def status(self) -> AirOSData:
262+
async def status(self) -> AirOSDataBaseClass:
239263
"""Retrieve status from the device."""
240264
response = await self._request_json(
241265
"GET", self._status_cgi_url, authenticated=True
242266
)
243267

244268
try:
245269
adjusted_json = self.derived_data(response)
246-
return AirOSData.from_dict(adjusted_json)
270+
return self.data_model.from_dict(adjusted_json)
247271
except InvalidFieldValue as err:
248272
# Log with .error() as this is a specific, known type of issue
249273
redacted_data = redact_data_smart(response)

0 commit comments

Comments
 (0)