-
Notifications
You must be signed in to change notification settings - Fork 128
Expand file tree
/
Copy pathEndpoint.py
More file actions
105 lines (86 loc) · 3.26 KB
/
Endpoint.py
File metadata and controls
105 lines (86 loc) · 3.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from enum import Enum
from typing import Any, Dict, Iterable, Optional, Type
import validators
from pydantic import constr, model_validator
from pydantic.dataclasses import dataclass
class EndpointType(str, Enum):
    """Enumerate the connection kinds a Mavlink endpoint may use.

    Because the enum mixes in ``str``, every member compares equal to its raw
    string value (for example ``EndpointType.Serial == "serial"``), so plain
    strings from user input can be checked against members directly.
    """

    # Network transports: the "...in" value is the server side, "...out" the client side.
    UDPServer = "udpin"
    UDPClient = "udpout"
    TCPServer = "tcpin"
    TCPClient = "tcpout"
    # Local serial device.
    Serial = "serial"
    # Zenoh transport.
    Zenoh = "zenoh"
@dataclass
# pylint: disable=too-many-instance-attributes
class Endpoint:
    """A Mavlink endpoint definition, validated when it is constructed.

    ``connection_type`` must be one of the ``EndpointType`` values. The meaning of
    ``argument`` depends on that type:

    * network types (UDP/TCP server or client, Zenoh): a port in 1..65535, with
      ``place`` a domain name, IPv4 or IPv6 address;
    * ``serial``: a baudrate from ``VALID_SERIAL_BAUDRATES``, with ``place`` an
      absolute path that does not end in ``/``.
    """

    name: constr(strip_whitespace=True, min_length=3, max_length=50)  # type: ignore
    owner: constr(strip_whitespace=True, min_length=3, max_length=50)  # type: ignore
    connection_type: str
    place: str
    argument: Optional[int] = None
    persistent: Optional[bool] = False
    protected: Optional[bool] = False
    enabled: Optional[bool] = True
    overwrite_settings: Optional[bool] = False

    @model_validator(mode="before")
    @classmethod
    def is_mavlink_endpoint(cls: Type["Endpoint"], values: Any) -> Any:
        """Check ``place`` and ``argument`` against ``connection_type`` on the raw input.

        Runs before field parsing (``mode="before"``). Input that is not a dict is
        returned unchanged; a valid dict is also returned unchanged.

        Raises:
            ValueError: if ``connection_type`` is not a known ``EndpointType`` value,
                or if ``place``/``argument`` are not valid for that type.
        """
        if not isinstance(values, dict):
            return values

        connection_type = values.get("connection_type")
        place = values.get("place")
        argument = values.get("argument")

        if connection_type in (
            EndpointType.UDPServer,
            EndpointType.UDPClient,
            EndpointType.TCPServer,
            EndpointType.TCPClient,
            EndpointType.Zenoh,
        ):
            if not (validators.domain(place) or validators.ipv4(place) or validators.ipv6(place)):
                raise ValueError(f"Invalid network address: {place}")
            if argument not in range(1, 65536):
                raise ValueError(f"Ports must be in the range 1:65535. Received {argument}.")
            return values

        # Compare against the member, like the network branch; the str mixin makes
        # this equal to comparing with "serial".
        if connection_type == EndpointType.Serial:
            # Guard the type first: a missing or non-string place would otherwise
            # escape as an AttributeError instead of the validation error below.
            if not isinstance(place, str) or not place.startswith("/") or place.endswith("/"):
                raise ValueError(f"Bad serial address: {place}. Make sure to use an absolute path.")
            if argument not in VALID_SERIAL_BAUDRATES:
                raise ValueError(f"Invalid serial baudrate: {argument}. Valid option are {VALID_SERIAL_BAUDRATES}.")
            return values

        raise ValueError(
            f"Invalid connection_type: {connection_type}. Valid types are: {[e.value for e in EndpointType]}."
        )

    @staticmethod
    def filter_enabled(endpoints: Iterable["Endpoint"]) -> Iterable["Endpoint"]:
        """Return a list of the endpoints whose ``enabled`` flag is exactly ``True``."""
        return [endpoint for endpoint in endpoints if endpoint.enabled is True]

    def __str__(self) -> str:
        """Render as ``"<connection_type>:<place>:<argument>"``."""
        return ":".join([self.connection_type, self.place, str(self.argument)])

    def as_dict(self) -> Dict[str, Any]:
        """Return the instance attributes as a dict, without the ``__initialised__`` entry."""
        return dict(filter(lambda field: field[0] != "__initialised__", self.__dict__.items()))

    def __hash__(self) -> int:
        # Hash the same string representation that __eq__ compares, so equal
        # endpoints hash equally.
        return hash(str(self))

    def __eq__(self, other: object) -> bool:
        """Endpoints are equal when type, place and argument match; other fields are ignored."""
        if not isinstance(other, type(self)):
            # Returning NotImplemented (rather than raising) lets Python fall back to
            # its default comparison, so `endpoint == None` or membership tests in
            # mixed-type containers evaluate to False instead of crashing.
            return NotImplemented
        return str(self) == str(other) and self.connection_type == other.connection_type and self.place == other.place
# Baudrates accepted for ``serial`` endpoints, ordered from fastest to slowest.
# The order is kept stable because the list is echoed verbatim in validation errors.
# NOTE(review): 570600 and 257600 are unusual rates (576000 and 256000 are the
# common ones) — confirm these two values are intentional.
VALID_SERIAL_BAUDRATES = [
    3_000_000, 2_000_000, 1_000_000,
    921_600, 570_600, 460_800, 257_600, 250_000, 230_400,
    115_200, 57_600, 38_400, 19_200, 9_600,
]