|
26 | 26 | * Adafruit CircuitPython firmware for the supported boards:
|
27 | 27 | https://github.com/adafruit/circuitpython/releases
|
28 | 28 |
|
| 29 | +* Adafruit's Connection Manager library: |
| 30 | + https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager |
| 31 | +
|
29 | 32 | """
|
30 | 33 | import errno
|
31 | 34 | import struct
|
32 | 35 | import time
|
33 | 36 | from random import randint
|
34 | 37 |
|
| 38 | +from adafruit_connectionmanager import ( |
| 39 | + get_connection_manager, |
| 40 | + SocketGetOSError, |
| 41 | + SocketConnectMemoryError, |
| 42 | +) |
| 43 | + |
35 | 44 | try:
|
36 | 45 | from typing import List, Optional, Tuple, Type, Union
|
37 | 46 | except ImportError:
|
|
78 | 87 | _default_sock = None # pylint: disable=invalid-name
|
79 | 88 | _fake_context = None # pylint: disable=invalid-name
|
80 | 89 |
|
| 90 | +TemporaryError = (SocketGetOSError, SocketConnectMemoryError) |
| 91 | + |
81 | 92 |
|
82 | 93 | class MMQTTException(Exception):
|
83 | 94 | """MiniMQTT Exception class."""
|
84 | 95 |
|
85 |
| - # pylint: disable=unnecessary-pass |
86 |
| - # pass |
87 |
| - |
88 |
| - |
89 |
| -class TemporaryError(Exception): |
90 |
| - """Temporary error class used for handling reconnects.""" |
91 |
| - |
92 |
| - |
93 |
| -# Legacy ESP32SPI Socket API |
94 |
| -def set_socket(sock, iface=None) -> None: |
95 |
| - """Legacy API for setting the socket and network interface. |
96 |
| -
|
97 |
| - :param sock: socket object. |
98 |
| - :param iface: internet interface object |
99 |
| -
|
100 |
| - """ |
101 |
| - global _default_sock # pylint: disable=invalid-name, global-statement |
102 |
| - global _fake_context # pylint: disable=invalid-name, global-statement |
103 |
| - _default_sock = sock |
104 |
| - if iface: |
105 |
| - _default_sock.set_interface(iface) |
106 |
| - _fake_context = _FakeSSLContext(iface) |
107 |
| - |
108 |
| - |
109 |
| -class _FakeSSLSocket: |
110 |
| - def __init__(self, socket, tls_mode) -> None: |
111 |
| - self._socket = socket |
112 |
| - self._mode = tls_mode |
113 |
| - self.settimeout = socket.settimeout |
114 |
| - self.send = socket.send |
115 |
| - self.recv = socket.recv |
116 |
| - self.close = socket.close |
117 |
| - |
118 |
| - def connect(self, address): |
119 |
| - """connect wrapper to add non-standard mode parameter""" |
120 |
| - try: |
121 |
| - return self._socket.connect(address, self._mode) |
122 |
| - except RuntimeError as error: |
123 |
| - raise OSError(errno.ENOMEM) from error |
124 |
| - |
125 |
| - |
126 |
| -class _FakeSSLContext: |
127 |
| - def __init__(self, iface) -> None: |
128 |
| - self._iface = iface |
129 |
| - |
130 |
| - def wrap_socket(self, socket, server_hostname=None) -> _FakeSSLSocket: |
131 |
| - """Return the same socket""" |
132 |
| - # pylint: disable=unused-argument |
133 |
| - return _FakeSSLSocket(socket, self._iface.TLS_MODE) |
134 |
| - |
135 | 96 |
|
136 | 97 | class NullLogger:
|
137 | 98 | """Fake logger class that does not do anything"""
|
138 | 99 |
|
139 | 100 | # pylint: disable=unused-argument
|
140 | 101 | def nothing(self, msg: str, *args) -> None:
|
141 | 102 | """no action"""
|
142 |
| - pass |
143 | 103 |
|
144 | 104 | def __init__(self) -> None:
|
145 | 105 | for log_level in ["debug", "info", "warning", "error", "critical"]:
|
@@ -194,6 +154,7 @@ def __init__(
|
194 | 154 | user_data=None,
|
195 | 155 | use_imprecise_time: Optional[bool] = None,
|
196 | 156 | ) -> None:
|
| 157 | + self._connection_manager = get_connection_manager(socket_pool) |
197 | 158 | self._socket_pool = socket_pool
|
198 | 159 | self._ssl_context = ssl_context
|
199 | 160 | self._sock = None
|
@@ -300,77 +261,6 @@ def get_monotonic_time(self) -> float:
|
300 | 261 |
|
301 | 262 | return time.monotonic()
|
302 | 263 |
|
303 |
| - # pylint: disable=too-many-branches |
304 |
| - def _get_connect_socket(self, host: str, port: int, *, timeout: int = 1): |
305 |
| - """Obtains a new socket and connects to a broker. |
306 |
| -
|
307 |
| - :param str host: Desired broker hostname |
308 |
| - :param int port: Desired broker port |
309 |
| - :param int timeout: Desired socket timeout, in seconds |
310 |
| - """ |
311 |
| - # For reconnections - check if we're using a socket already and close it |
312 |
| - if self._sock: |
313 |
| - self._sock.close() |
314 |
| - self._sock = None |
315 |
| - |
316 |
| - # Legacy API - use the interface's socket instead of a passed socket pool |
317 |
| - if self._socket_pool is None: |
318 |
| - self._socket_pool = _default_sock |
319 |
| - |
320 |
| - # Legacy API - fake the ssl context |
321 |
| - if self._ssl_context is None: |
322 |
| - self._ssl_context = _fake_context |
323 |
| - |
324 |
| - if not isinstance(port, int): |
325 |
| - raise RuntimeError("Port must be an integer") |
326 |
| - |
327 |
| - if self._is_ssl and not self._ssl_context: |
328 |
| - raise RuntimeError( |
329 |
| - "ssl_context must be set before using adafruit_mqtt for secure MQTT." |
330 |
| - ) |
331 |
| - |
332 |
| - if self._is_ssl: |
333 |
| - self.logger.info(f"Establishing a SECURE SSL connection to {host}:{port}") |
334 |
| - else: |
335 |
| - self.logger.info(f"Establishing an INSECURE connection to {host}:{port}") |
336 |
| - |
337 |
| - addr_info = self._socket_pool.getaddrinfo( |
338 |
| - host, port, 0, self._socket_pool.SOCK_STREAM |
339 |
| - )[0] |
340 |
| - |
341 |
| - try: |
342 |
| - sock = self._socket_pool.socket(addr_info[0], addr_info[1]) |
343 |
| - except OSError as exc: |
344 |
| - # Do not consider this for back-off. |
345 |
| - self.logger.warning( |
346 |
| - f"Failed to create socket for host {addr_info[0]} and port {addr_info[1]}" |
347 |
| - ) |
348 |
| - raise TemporaryError from exc |
349 |
| - |
350 |
| - connect_host = addr_info[-1][0] |
351 |
| - if self._is_ssl: |
352 |
| - sock = self._ssl_context.wrap_socket(sock, server_hostname=host) |
353 |
| - connect_host = host |
354 |
| - sock.settimeout(timeout) |
355 |
| - |
356 |
| - last_exception = None |
357 |
| - try: |
358 |
| - sock.connect((connect_host, port)) |
359 |
| - except MemoryError as exc: |
360 |
| - sock.close() |
361 |
| - self.logger.warning(f"Failed to allocate memory for connect: {exc}") |
362 |
| - # Do not consider this for back-off. |
363 |
| - raise TemporaryError from exc |
364 |
| - except OSError as exc: |
365 |
| - sock.close() |
366 |
| - last_exception = exc |
367 |
| - |
368 |
| - if last_exception: |
369 |
| - raise last_exception |
370 |
| - |
371 |
| - self._backwards_compatible_sock = not hasattr(sock, "recv_into") |
372 |
| - return sock |
373 |
| - |
374 | 264 | def __enter__(self):
|
375 | 265 | return self
|
376 | 266 |
|
@@ -593,8 +483,15 @@ def _connect(
|
593 | 483 | time.sleep(self._reconnect_timeout)
|
594 | 484 |
|
595 | 485 | # Get a new socket
|
596 |
| - self._sock = self._get_connect_socket( |
597 |
| - self.broker, self.port, timeout=self._socket_timeout |
| 486 | + self._sock = self._connection_manager.get_socket( |
| 487 | + self.broker, |
| 488 | + self.port, |
| 489 | + "mqtt:", |
| 490 | + timeout=self._socket_timeout, |
| 491 | + is_ssl=self._is_ssl, |
| 492 | + ssl_context=self._ssl_context, |
| 493 | + max_retries=1, # setting to 1 since we want to handle backoff internally |
| 494 | + exception_passthrough=True, |
598 | 495 | )
|
599 | 496 |
|
600 | 497 | # Fixed Header
|
@@ -689,7 +586,7 @@ def disconnect(self) -> None:
|
689 | 586 | except RuntimeError as e:
|
690 | 587 | self.logger.warning(f"Unable to send DISCONNECT packet: {e}")
|
691 | 588 | self.logger.debug("Closing socket")
|
692 |
| - self._sock.close() |
| 589 | + self._connection_manager.free_socket(self._sock) |
693 | 590 | self._is_connected = False
|
694 | 591 | self._subscribed_topics = []
|
695 | 592 | if self.on_disconnect is not None:
|
|
0 commit comments