Skip to content

Commit c3a080b

Browse files
committed
add some type annotations
1 parent 988332f commit c3a080b

File tree

1 file changed

+74
-52
lines changed

1 file changed

+74
-52
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

Lines changed: 74 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@
3232
import time
3333
from random import randint
3434

35+
try:
36+
from typing import List, Tuple, Type, Union
37+
except ImportError:
38+
pass
39+
3540
from micropython import const
3641

3742
from .matcher import MQTTMatcher
@@ -84,7 +89,7 @@ class TemporaryError(Exception):
8489

8590

8691
# Legacy ESP32SPI Socket API
87-
def set_socket(sock, iface=None):
92+
def set_socket(sock, iface=None) -> None:
8893
"""Legacy API for setting the socket and network interface.
8994
9095
:param sock: socket object.
@@ -100,7 +105,7 @@ def set_socket(sock, iface=None):
100105

101106

102107
class _FakeSSLSocket:
103-
def __init__(self, socket, tls_mode):
108+
def __init__(self, socket, tls_mode) -> None:
104109
self._socket = socket
105110
self._mode = tls_mode
106111
self.settimeout = socket.settimeout
@@ -117,10 +122,10 @@ def connect(self, address):
117122

118123

119124
class _FakeSSLContext:
120-
def __init__(self, iface):
125+
def __init__(self, iface) -> None:
121126
self._iface = iface
122127

123-
def wrap_socket(self, socket, server_hostname=None):
128+
def wrap_socket(self, socket, server_hostname=None) -> _FakeSSLSocket:
124129
"""Return the same socket"""
125130
# pylint: disable=unused-argument
126131
return _FakeSSLSocket(socket, self._iface.TLS_MODE)
@@ -134,7 +139,7 @@ def nothing(self, msg: str, *args) -> None:
134139
"""no action"""
135140
pass
136141

137-
def __init__(self):
142+
def __init__(self) -> None:
138143
for log_level in ["debug", "info", "warning", "error", "critical"]:
139144
setattr(NullLogger, log_level, self.nothing)
140145

@@ -166,21 +171,21 @@ class MQTT:
166171
def __init__(
167172
self,
168173
*,
169-
broker,
170-
port=None,
171-
username=None,
172-
password=None,
173-
client_id=None,
174-
is_ssl=None,
175-
keep_alive=60,
176-
recv_timeout=10,
174+
broker: str,
175+
port: Union[int, None] = None,
176+
username: Union[str, None] = None,
177+
password: Union[str, None] = None,
178+
client_id: Union[str, None] = None,
179+
is_ssl: Union[bool, None] = None,
180+
keep_alive: int = 60,
181+
recv_timeout: int = 10,
177182
socket_pool=None,
178183
ssl_context=None,
179-
use_binary_mode=False,
180-
socket_timeout=1,
181-
connect_retries=5,
184+
use_binary_mode: bool = False,
185+
socket_timeout: int = 1,
186+
connect_retries: int = 5,
182187
user_data=None,
183-
):
188+
) -> None:
184189

185190
self._socket_pool = socket_pool
186191
self._ssl_context = ssl_context
@@ -253,7 +258,7 @@ def __init__(
253258
self._lw_retain = False
254259

255260
# List of subscribed topics, used for tracking
256-
self._subscribed_topics = []
261+
self._subscribed_topics: List[str] = []
257262
self._on_message_filtered = MQTTMatcher()
258263

259264
# Default topic callback methods
@@ -265,7 +270,7 @@ def __init__(
265270
self.on_unsubscribe = None
266271

267272
# pylint: disable=too-many-branches
268-
def _get_connect_socket(self, host, port, *, timeout=1):
273+
def _get_connect_socket(self, host: str, port: int, *, timeout: int = 1):
269274
"""Obtains a new socket and connects to a broker.
270275
271276
:param str host: Desired broker hostname
@@ -338,20 +343,20 @@ def _get_connect_socket(self, host, port, *, timeout=1):
338343
def __enter__(self):
339344
return self
340345

341-
def __exit__(self, exception_type, exception_value, traceback):
346+
def __exit__(self, exception_type, exception_value, traceback) -> None:
342347
self.deinit()
343348

344-
def deinit(self):
349+
def deinit(self) -> None:
345350
"""De-initializes the MQTT client and disconnects from the mqtt broker."""
346351
self.disconnect()
347352

348353
@property
349-
def mqtt_msg(self):
354+
def mqtt_msg(self) -> Tuple[int, int]:
350355
"""Returns maximum MQTT payload and topic size."""
351356
return self._msg_size_lim, MQTT_TOPIC_LENGTH_LIMIT
352357

353358
@mqtt_msg.setter
354-
def mqtt_msg(self, msg_size):
359+
def mqtt_msg(self, msg_size: int) -> None:
355360
"""Sets the maximum MQTT message payload size.
356361
357362
:param int msg_size: Maximum MQTT payload size.
@@ -388,7 +393,7 @@ def will_set(self, topic=None, payload=None, qos=0, retain=False):
388393
self._lw_msg = payload
389394
self._lw_retain = retain
390395

391-
def add_topic_callback(self, mqtt_topic, callback_method):
396+
def add_topic_callback(self, mqtt_topic: str, callback_method) -> None:
392397
"""Registers a callback_method for a specific MQTT topic.
393398
394399
:param str mqtt_topic: MQTT topic identifier.
@@ -398,7 +403,7 @@ def add_topic_callback(self, mqtt_topic, callback_method):
398403
raise ValueError("MQTT topic and callback method must both be defined.")
399404
self._on_message_filtered[mqtt_topic] = callback_method
400405

401-
def remove_topic_callback(self, mqtt_topic):
406+
def remove_topic_callback(self, mqtt_topic: str) -> None:
402407
"""Removes a registered callback method.
403408
404409
:param str mqtt_topic: MQTT topic identifier string.
@@ -421,10 +426,10 @@ def on_message(self):
421426
return self._on_message
422427

423428
@on_message.setter
424-
def on_message(self, method):
429+
def on_message(self, method) -> None:
425430
self._on_message = method
426431

427-
def _handle_on_message(self, client, topic, message):
432+
def _handle_on_message(self, client, topic: str, message: str):
428433
matched = False
429434
if topic is not None:
430435
for callback in self._on_message_filtered.iter_match(topic):
@@ -434,7 +439,7 @@ def _handle_on_message(self, client, topic, message):
434439
if not matched and self.on_message: # regular on_message
435440
self.on_message(client, topic, message)
436441

437-
def username_pw_set(self, username, password=None):
442+
def username_pw_set(self, username: str, password: Union[str, None] = None) -> None:
438443
"""Set client's username and an optional password.
439444
440445
:param str username: Username to use with your MQTT broker.
@@ -447,7 +452,13 @@ def username_pw_set(self, username, password=None):
447452
if password is not None:
448453
self._password = password
449454

450-
def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
455+
def connect(
456+
self,
457+
clean_session: bool = True,
458+
host: Union[str, None] = None,
459+
port: Union[int, None] = None,
460+
keep_alive: Union[int, None] = None,
461+
) -> int:
451462
"""Initiates connection with the MQTT Broker. Will perform exponential back-off
452463
on connect failures.
453464
@@ -503,7 +514,13 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
503514
raise MMQTTException(exc_msg)
504515

505516
# pylint: disable=too-many-branches, too-many-statements, too-many-locals
506-
def _connect(self, clean_session=True, host=None, port=None, keep_alive=None):
517+
def _connect(
518+
self,
519+
clean_session: bool = True,
520+
host: Union[str, None] = None,
521+
port: Union[int, None] = None,
522+
keep_alive: Union[int, None] = None,
523+
) -> int:
507524
"""Initiates connection with the MQTT Broker.
508525
509526
:param bool clean_session: Establishes a persistent session.
@@ -616,7 +633,7 @@ def _connect(self, clean_session=True, host=None, port=None, keep_alive=None):
616633
f"No data received from broker for {self._recv_timeout} seconds."
617634
)
618635

619-
def disconnect(self):
636+
def disconnect(self) -> None:
620637
"""Disconnects the MiniMQTT client from the MQTT broker."""
621638
self._connected()
622639
self.logger.debug("Sending DISCONNECT packet to broker")
@@ -631,7 +648,7 @@ def disconnect(self):
631648
if self.on_disconnect is not None:
632649
self.on_disconnect(self, self._user_data, 0)
633650

634-
def ping(self):
651+
def ping(self) -> list[int]:
635652
"""Pings the MQTT Broker to confirm if the broker is alive or if
636653
there is an active network connection.
637654
Returns response codes of any messages received while waiting for PINGRESP.
@@ -651,7 +668,13 @@ def ping(self):
651668
return rcs
652669

653670
# pylint: disable=too-many-branches, too-many-statements
654-
def publish(self, topic, msg, retain=False, qos=0):
671+
def publish(
672+
self,
673+
topic: str,
674+
msg: Union[str, int, float, bytes],
675+
retain: bool = False,
676+
qos: int = 0,
677+
) -> None:
655678
"""Publishes a message to a topic provided.
656679
657680
:param str topic: Unique topic identifier.
@@ -740,7 +763,7 @@ def publish(self, topic, msg, retain=False, qos=0):
740763
f"No data received from broker for {self._recv_timeout} seconds."
741764
)
742765

743-
def subscribe(self, topic, qos=0):
766+
def subscribe(self, topic: str, qos: int = 0) -> None:
744767
"""Subscribes to a topic on the MQTT Broker.
745768
This method can subscribe to one topics or multiple topics.
746769
@@ -807,7 +830,7 @@ def subscribe(self, topic, qos=0):
807830
f"No data received from broker for {self._recv_timeout} seconds."
808831
)
809832

810-
def unsubscribe(self, topic):
833+
def unsubscribe(self, topic: str) -> None:
811834
"""Unsubscribes from a MQTT topic.
812835
813836
:param str|list topic: Unique MQTT topic identifier string or list.
@@ -861,7 +884,7 @@ def unsubscribe(self, topic):
861884
f"No data received from broker for {self._recv_timeout} seconds."
862885
)
863886

864-
def _recompute_reconnect_backoff(self):
887+
def _recompute_reconnect_backoff(self) -> None:
865888
"""
866889
Recompute the reconnection timeout. The self._reconnect_timeout will be used
867890
in self._connect() to perform the actual sleep.
@@ -891,7 +914,7 @@ def _recompute_reconnect_backoff(self):
891914
)
892915
self._reconnect_timeout += jitter
893916

894-
def _reset_reconnect_backoff(self):
917+
def _reset_reconnect_backoff(self) -> None:
895918
"""
896919
Reset reconnect back-off to the initial state.
897920
@@ -900,7 +923,7 @@ def _reset_reconnect_backoff(self):
900923
self._reconnect_attempt = 0
901924
self._reconnect_timeout = float(0)
902925

903-
def reconnect(self, resub_topics=True):
926+
def reconnect(self, resub_topics: bool = True) -> int:
904927
"""Attempts to reconnect to the MQTT broker.
905928
Return the value from connect() if successful. Will disconnect first if already connected.
906929
Will perform exponential back-off on connect failures.
@@ -924,13 +947,13 @@ def reconnect(self, resub_topics=True):
924947

925948
return ret
926949

927-
def loop(self, timeout=0):
950+
def loop(self, timeout: float = 0) -> Union[list[int], None]:
928951
# pylint: disable = too-many-return-statements
929952
"""Non-blocking message loop. Use this method to
930953
check incoming subscription messages.
931954
Returns response codes of any messages received.
932955
933-
:param int timeout: Socket timeout, in seconds.
956+
:param float timeout: Socket timeout, in seconds.
934957
935958
"""
936959

@@ -964,7 +987,7 @@ def loop(self, timeout=0):
964987

965988
return rcs if rcs else None
966989

967-
def _wait_for_msg(self, timeout=0.1):
990+
def _wait_for_msg(self, timeout: float = 0.1) -> Union[int, None]:
968991
# pylint: disable = too-many-return-statements
969992

970993
"""Reads and processes network events.
@@ -1004,7 +1027,7 @@ def _wait_for_msg(self, timeout=0.1):
10041027
sz = self._recv_len()
10051028
# topic length MSB & LSB
10061029
topic_len = self._sock_exact_recv(2)
1007-
topic_len = (topic_len[0] << 8) | topic_len[1]
1030+
topic_len = int((topic_len[0] << 8) | topic_len[1])
10081031

10091032
if topic_len > sz - 2:
10101033
raise MMQTTException(
@@ -1034,19 +1057,18 @@ def _wait_for_msg(self, timeout=0.1):
10341057

10351058
return res[0]
10361059

1037-
def _recv_len(self):
1060+
def _recv_len(self) -> int:
10381061
"""Unpack MQTT message length."""
10391062
n = 0
10401063
sh = 0
1041-
b = bytearray(1)
10421064
while True:
10431065
b = self._sock_exact_recv(1)[0]
10441066
n |= (b & 0x7F) << sh
10451067
if not b & 0x80:
10461068
return n
10471069
sh += 7
10481070

1049-
def _sock_exact_recv(self, bufsize):
1071+
def _sock_exact_recv(self, bufsize: int) -> bytearray:
10501072
"""Reads _exact_ number of bytes from the connected socket. Will only return
10511073
string with the exact number of bytes requested.
10521074
@@ -1100,7 +1122,7 @@ def _sock_exact_recv(self, bufsize):
11001122
)
11011123
return rc
11021124

1103-
def _send_str(self, string):
1125+
def _send_str(self, string: str) -> None:
11041126
"""Encodes a string and sends it to a socket.
11051127
11061128
:param str string: String to write to the socket.
@@ -1114,7 +1136,7 @@ def _send_str(self, string):
11141136
self._sock.send(string)
11151137

11161138
@staticmethod
1117-
def _valid_topic(topic):
1139+
def _valid_topic(topic: str) -> None:
11181140
"""Validates if topic provided is proper MQTT topic format.
11191141
11201142
:param str topic: Topic identifier
@@ -1130,7 +1152,7 @@ def _valid_topic(topic):
11301152
raise MMQTTException("Topic length is too large.")
11311153

11321154
@staticmethod
1133-
def _valid_qos(qos_level):
1155+
def _valid_qos(qos_level: int) -> None:
11341156
"""Validates if the QoS level is supported by this library
11351157
11361158
:param int qos_level: Desired QoS level.
@@ -1142,21 +1164,21 @@ def _valid_qos(qos_level):
11421164
else:
11431165
raise MMQTTException("QoS must be an integer.")
11441166

1145-
def _connected(self):
1167+
def _connected(self) -> None:
11461168
"""Returns MQTT client session status as True if connected, raises
11471169
a `MMQTTException` if `False`.
11481170
"""
11491171
if not self.is_connected():
11501172
raise MMQTTException("MiniMQTT is not connected")
11511173

1152-
def is_connected(self):
1174+
def is_connected(self) -> bool:
11531175
"""Returns MQTT client session status as True if connected, False
11541176
if not.
11551177
"""
11561178
return self._is_connected and self._sock is not None
11571179

11581180
# Logging
1159-
def enable_logger(self, log_pkg, log_level=20, logger_name="log"):
1181+
def enable_logger(self, log_pkg, log_level: int = 20, logger_name: str = "log"):
11601182
"""Enables library logging by getting logger from the specified logging package
11611183
and setting its log level.
11621184
@@ -1173,6 +1195,6 @@ def enable_logger(self, log_pkg, log_level=20, logger_name="log"):
11731195

11741196
return self.logger
11751197

1176-
def disable_logger(self):
1198+
def disable_logger(self) -> None:
11771199
"""Disables logging."""
11781200
self.logger = NullLogger()

0 commit comments

Comments
 (0)