|
1 | 1 | import base64 |
| 2 | +import dataclasses |
| 3 | +import enum |
2 | 4 | import json |
3 | 5 | import os |
4 | | -from typing import MutableMapping |
| 6 | +from typing import Any, Dict, MutableMapping, Self |
5 | 7 |
|
6 | 8 | from enapter import mqtt |
7 | 9 |
|
8 | 10 |
|
| 11 | +@dataclasses.dataclass |
9 | 12 | class Config: |
10 | 13 |
|
| 14 | + communication_config: "CommunicationConfig" |
| 15 | + |
| 16 | + @property |
| 17 | + def communication(self) -> "CommunicationConfig": |
| 18 | + return self.communication_config |
| 19 | + |
11 | 20 | @classmethod |
12 | 21 | def from_env( |
13 | 22 | cls, env: MutableMapping[str, str] = os.environ, namespace: str = "ENAPTER_" |
14 | | - ) -> "Config": |
15 | | - prefix = namespace + "STANDALONE_" |
| 23 | + ) -> Self: |
| 24 | + communication_config = CommunicationConfig.from_env(env, namespace=namespace) |
| 25 | + return cls(communication_config=communication_config) |
16 | 26 |
|
17 | | - try: |
18 | | - blob = env[prefix + "BLOB"] |
19 | | - except KeyError: |
20 | | - pass |
21 | | - else: |
22 | | - config = cls.from_blob(blob) |
23 | | - try: |
24 | | - config.channel_id = env[prefix + "CHANNEL_ID"] |
25 | | - except KeyError: |
26 | | - pass |
27 | | - return config |
28 | 27 |
|
29 | | - hardware_id = env[prefix + "HARDWARE_ID"] |
30 | | - channel_id = env[prefix + "CHANNEL_ID"] |
| 28 | +@dataclasses.dataclass |
| 29 | +class CommunicationConfig: |
31 | 30 |
|
32 | | - mqtt_config = mqtt.Config.from_env(env, namespace=namespace) |
| 31 | + mqtt_config: mqtt.Config |
| 32 | + hardware_id: str |
| 33 | + channel_id: str |
| 34 | + ucm_needed: bool |
33 | 35 |
|
34 | | - start_ucm = env.get(prefix + "START_UCM", "1") != "0" |
| 36 | + @property |
| 37 | + def mqtt(self) -> mqtt.Config: |
| 38 | + return self.mqtt_config |
35 | 39 |
|
36 | | - return cls( |
37 | | - hardware_id=hardware_id, |
38 | | - channel_id=channel_id, |
39 | | - mqtt=mqtt_config, |
40 | | - start_ucm=start_ucm, |
41 | | - ) |
| 40 | + @classmethod |
| 41 | + def from_env( |
| 42 | + cls, env: MutableMapping[str, str] = os.environ, namespace: str = "ENAPTER_" |
| 43 | + ) -> Self: |
| 44 | + prefix = namespace + "STANDALONE_COMMUNICATION_" |
| 45 | + blob = env[prefix + "CONFIG"] |
| 46 | + return cls.from_blob(blob) |
42 | 47 |
|
43 | 48 | @classmethod |
44 | | - def from_blob(cls, blob: str) -> "Config": |
45 | | - payload = json.loads(base64.b64decode(blob)) |
| 49 | + def from_blob(cls, blob: str) -> Self: |
| 50 | + dto = json.loads(base64.b64decode(blob)) |
| 51 | + if "ucm_id" in dto: |
| 52 | + config_v1 = CommunicationConfigV1.from_dto(dto) |
| 53 | + return cls.from_config_v1(config_v1) |
| 54 | + else: |
| 55 | + config_v3 = CommunicationConfigV3.from_dto(dto) |
| 56 | + return cls.from_config_v3(config_v3) |
46 | 57 |
|
| 58 | + @classmethod |
| 59 | + def from_config_v1(cls, config: "CommunicationConfigV1") -> Self: |
47 | 60 | mqtt_config = mqtt.Config( |
48 | | - host=payload["mqtt_host"], |
49 | | - port=int(payload["mqtt_port"]), |
50 | | - tls=mqtt.TLSConfig( |
51 | | - ca_cert=payload["mqtt_ca"], |
52 | | - cert=payload["mqtt_cert"], |
53 | | - secret_key=payload["mqtt_private_key"], |
| 61 | + host=config.mqtt_host, |
| 62 | + port=config.mqtt_port, |
| 63 | + tls_config=mqtt.TLSConfig( |
| 64 | + secret_key=config.mqtt_private_key, |
| 65 | + cert=config.mqtt_cert, |
| 66 | + ca_cert=config.mqtt_ca, |
54 | 67 | ), |
55 | 68 | ) |
| 69 | + return cls( |
| 70 | + mqtt_config=mqtt_config, |
| 71 | + hardware_id=config.ucm_id, |
| 72 | + channel_id=config.channel_id, |
| 73 | + ucm_needed=True, |
| 74 | + ) |
| 75 | + |
| 76 | + @classmethod |
| 77 | + def from_config_v3(cls, config: "CommunicationConfigV3") -> Self: |
| 78 | + mqtt_config: mqtt.Config | None = None |
| 79 | + match config.mqtt_protocol: |
| 80 | + case CommunicationConfigV3.MQTTProtocol.MQTT: |
| 81 | + assert isinstance( |
| 82 | + config.mqtt_credentials, CommunicationConfigV3.MQTTCredentials |
| 83 | + ) |
| 84 | + mqtt_config = mqtt.Config( |
| 85 | + host=config.mqtt_host, |
| 86 | + port=config.mqtt_port, |
| 87 | + user=config.mqtt_credentials.username, |
| 88 | + password=config.mqtt_credentials.password, |
| 89 | + ) |
| 90 | + case CommunicationConfigV3.MQTTProtocol.MQTTS: |
| 91 | + assert isinstance( |
| 92 | + config.mqtt_credentials, CommunicationConfigV3.MQTTSCredentials |
| 93 | + ) |
| 94 | + mqtt_config = mqtt.Config( |
| 95 | + host=config.mqtt_host, |
| 96 | + port=config.mqtt_port, |
| 97 | + tls_config=mqtt.TLSConfig( |
| 98 | + secret_key=config.mqtt_credentials.private_key, |
| 99 | + cert=config.mqtt_credentials.certificate, |
| 100 | + ca_cert=config.mqtt_credentials.ca_chain, |
| 101 | + ), |
| 102 | + ) |
| 103 | + case _: |
| 104 | + raise NotImplementedError(config.mqtt_protocol) |
| 105 | + assert mqtt_config is not None |
| 106 | + return cls( |
| 107 | + mqtt_config=mqtt_config, |
| 108 | + hardware_id=config.hardware_id, |
| 109 | + channel_id=config.channel_id, |
| 110 | + ucm_needed=False, |
| 111 | + ) |
| 112 | + |
56 | 113 |
|
| 114 | +@dataclasses.dataclass |
| 115 | +class CommunicationConfigV1: |
| 116 | + |
| 117 | + mqtt_host: str |
| 118 | + mqtt_port: int |
| 119 | + mqtt_ca: str |
| 120 | + mqtt_cert: str |
| 121 | + mqtt_private_key: str |
| 122 | + ucm_id: str |
| 123 | + channel_id: str |
| 124 | + |
| 125 | + @classmethod |
| 126 | + def from_dto(cls, dto: Dict[str, Any]) -> Self: |
57 | 127 | return cls( |
58 | | - hardware_id=payload["ucm_id"], |
59 | | - channel_id=payload["channel_id"], |
60 | | - mqtt=mqtt_config, |
| 128 | + mqtt_host=dto["mqtt_host"], |
| 129 | + mqtt_port=int(dto["mqtt_port"]), |
| 130 | + mqtt_ca=dto["mqtt_ca"], |
| 131 | + mqtt_cert=dto["mqtt_cert"], |
| 132 | + mqtt_private_key=dto["mqtt_private_key"], |
| 133 | + ucm_id=dto["ucm_id"], |
| 134 | + channel_id=dto["channel_id"], |
61 | 135 | ) |
62 | 136 |
|
63 | | - def __init__( |
64 | | - self, |
65 | | - hardware_id: str, |
66 | | - channel_id: str, |
67 | | - mqtt: mqtt.Config, |
68 | | - start_ucm: bool = True, |
69 | | - ) -> None: |
70 | | - self.hardware_id = hardware_id |
71 | | - self.channel_id = channel_id |
72 | | - self.mqtt = mqtt |
73 | | - self.start_ucm = start_ucm |
| 137 | + |
| 138 | +@dataclasses.dataclass |
| 139 | +class CommunicationConfigV3: |
| 140 | + |
| 141 | + class MQTTProtocol(enum.Enum): |
| 142 | + |
| 143 | + MQTT = "MQTT" |
| 144 | + MQTTS = "MQTTS" |
| 145 | + |
| 146 | + @dataclasses.dataclass |
| 147 | + class MQTTCredentials: |
| 148 | + |
| 149 | + username: str |
| 150 | + password: str |
| 151 | + |
| 152 | + @classmethod |
| 153 | + def from_dto(cls, dto: Dict[str, Any]) -> Self: |
| 154 | + return cls(username=dto["username"], password=dto["password"]) |
| 155 | + |
| 156 | + @dataclasses.dataclass |
| 157 | + class MQTTSCredentials: |
| 158 | + |
| 159 | + private_key: str |
| 160 | + certificate: str |
| 161 | + ca_chain: str |
| 162 | + |
| 163 | + @classmethod |
| 164 | + def from_dto(cls, dto: Dict[str, Any]) -> Self: |
| 165 | + return cls( |
| 166 | + private_key=dto["private_key"], |
| 167 | + certificate=dto["certificate"], |
| 168 | + ca_chain=dto["ca_chain"], |
| 169 | + ) |
| 170 | + |
| 171 | + mqtt_host: str |
| 172 | + mqtt_port: int |
| 173 | + mqtt_protocol: MQTTProtocol |
| 174 | + mqtt_credentials: MQTTCredentials | MQTTSCredentials |
| 175 | + hardware_id: str |
| 176 | + channel_id: str |
| 177 | + |
| 178 | + @classmethod |
| 179 | + def from_dto(cls, dto: Dict[str, Any]) -> Self: |
| 180 | + mqtt_protocol = cls.MQTTProtocol(dto["mqtt_protocol"]) |
| 181 | + mqtt_credentials: ( |
| 182 | + CommunicationConfigV3.MQTTCredentials |
| 183 | + | CommunicationConfigV3.MQTTSCredentials |
| 184 | + | None |
| 185 | + ) = None |
| 186 | + match mqtt_protocol: |
| 187 | + case cls.MQTTProtocol.MQTT: |
| 188 | + mqtt_credentials = cls.MQTTCredentials.from_dto(dto["mqtt_credentials"]) |
| 189 | + case cls.MQTTProtocol.MQTTS: |
| 190 | + mqtt_credentials = cls.MQTTSCredentials.from_dto( |
| 191 | + dto["mqtt_credentials"] |
| 192 | + ) |
| 193 | + case _: |
| 194 | + raise NotImplementedError(mqtt_protocol) |
| 195 | + assert mqtt_credentials is not None |
| 196 | + return cls( |
| 197 | + mqtt_host=dto["mqtt_host"], |
| 198 | + mqtt_port=int(dto["mqtt_port"]), |
| 199 | + mqtt_credentials=mqtt_credentials, |
| 200 | + mqtt_protocol=mqtt_protocol, |
| 201 | + hardware_id=dto["hardware_id"], |
| 202 | + channel_id=dto["channel_id"], |
| 203 | + ) |
0 commit comments