Skip to content

Commit b8008ae

Browse files
committed
Add type annotations for advertising/standard.py
1 parent 652fb27 commit b8008ae

File tree

1 file changed

+40
-25
lines changed

1 file changed

+40
-25
lines changed

adafruit_ble/advertising/standard.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,29 @@
2424
)
2525
from ..uuid import StandardUUID, VendorUUID
2626

27+
try:
28+
from typing import Optional, List, Tuple, Union, Type, Iterator, Iterable, Any
29+
from typing_extensions import Protocol
30+
from adafruit_ble.uuid import UUID
31+
from adafruit_ble.characteristics import Characteristic
32+
from adafruit_ble.services import Service
33+
from _bleio import ScanEntry
34+
35+
UsesServicesAdvertisement = Union["ProvideServicesAdvertisement", "SolicitServicesAdvertisement"]
36+
37+
38+
39+
except ImportError:
40+
pass
41+
2742
__version__ = "0.0.0-auto.0"
2843
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_BLE.git"
2944

3045

3146
class BoundServiceList:
3247
"""Sequence-like object of Service UUID objects. It stores both standard and vendor UUIDs."""
3348

34-
def __init__(self, advertisement, *, standard_services, vendor_services):
49+
def __init__(self, advertisement: UsesServicesAdvertisement, *, standard_services: List[int], vendor_services: List[int]) -> None:
3550
self._advertisement = advertisement
3651
self._standard_service_fields = standard_services
3752
self._vendor_service_fields = vendor_services
@@ -50,13 +65,13 @@ def __init__(self, advertisement, *, standard_services, vendor_services):
5065
uuid = VendorUUID(data[16 * i : 16 * (i + 1)])
5166
self._vendor_services.append(uuid)
5267

53-
def __contains__(self, key):
68+
def __contains__(self, key: Union[UUID, Characteristic]) -> bool:
5469
uuid = key
5570
if hasattr(key, "uuid"):
5671
uuid = key.uuid
5772
return uuid in self._vendor_services or uuid in self._standard_services
5873

59-
def _update(self, adt, uuids):
74+
def _update(self, adt: int, uuids: List[UUID]) -> None:
6075
if not uuids:
6176
# uuids is empty
6277
del self._advertisement.data_dict[adt]
@@ -68,13 +83,13 @@ def _update(self, adt, uuids):
6883
i += uuid_length
6984
self._advertisement.data_dict[adt] = b
7085

71-
def __iter__(self):
86+
def __iter__(self) -> Iterator[UUID]:
7287
all_services = list(self._standard_services)
7388
all_services.extend(self._vendor_services)
7489
return iter(all_services)
7590

7691
# TODO: Differentiate between complete and incomplete lists.
77-
def append(self, service):
92+
def append(self, service: Service) -> None:
7893
"""Append a service to the list."""
7994
if (
8095
isinstance(service.uuid, StandardUUID)
@@ -90,7 +105,7 @@ def append(self, service):
90105
self._update(self._vendor_service_fields[0], self._vendor_services)
91106

92107
# TODO: Differentiate between complete and incomplete lists.
93-
def extend(self, services):
108+
def extend(self, services: Iterable[Service]) -> None:
94109
"""Appends all services in the iterable to the list."""
95110
standard = False
96111
vendor = False
@@ -113,7 +128,7 @@ def extend(self, services):
113128
if vendor:
114129
self._update(self._vendor_service_fields[0], self._vendor_services)
115130

116-
def __str__(self):
131+
def __str__(self) -> str:
117132
data = []
118133
for service_uuid in self._standard_services:
119134
data.append(str(service_uuid))
@@ -125,11 +140,11 @@ def __str__(self):
125140
class ServiceList(AdvertisingDataField):
126141
"""Descriptor for a list of Service UUIDs that lazily binds a corresponding BoundServiceList."""
127142

128-
def __init__(self, *, standard_services, vendor_services):
143+
def __init__(self, *, standard_services: List[int], vendor_services: List[int]) -> None:
129144
self.standard_services = standard_services
130145
self.vendor_services = vendor_services
131146

132-
def _present(self, obj):
147+
def _present(self, obj: UsesServicesAdvertisement) -> bool:
133148
for adt in self.standard_services:
134149
if adt in obj.data_dict:
135150
return True
@@ -138,7 +153,7 @@ def _present(self, obj):
138153
return True
139154
return False
140155

141-
def __get__(self, obj, cls):
156+
def __get__(self, obj: Optional[UsesServicesAdvertisement], cls: Type[UsesServicesAdvertisement]) -> Union[UsesServicesAdvertisement, Tuple[()], "ServiceList"]:
142157
if obj is None:
143158
return self
144159
if not self._present(obj) and not obj.mutable:
@@ -159,7 +174,7 @@ class ProvideServicesAdvertisement(Advertisement):
159174
services = ServiceList(standard_services=[0x02, 0x03], vendor_services=[0x06, 0x07])
160175
"""List of services the device can provide."""
161176

162-
def __init__(self, *services, entry=None):
177+
def __init__(self, *services: Service, entry: Optional[ScanEntry] = None) -> None:
163178
super().__init__(entry=entry)
164179
if entry:
165180
if services:
@@ -173,7 +188,7 @@ def __init__(self, *services, entry=None):
173188
self.flags.le_only = True
174189

175190
@classmethod
176-
def matches(cls, entry):
191+
def matches(cls, entry: ScanEntry) -> bool:
177192
"""Only one kind of service list need be present in a ProvideServicesAdvertisement,
178193
so override the default behavior and match any prefix, not all.
179194
"""
@@ -189,7 +204,7 @@ class SolicitServicesAdvertisement(Advertisement):
189204
solicited_services = ServiceList(standard_services=[0x14], vendor_services=[0x15])
190205
"""List of services the device would like to use."""
191206

192-
def __init__(self, *services, entry=None):
207+
def __init__(self, *services: Service, entry: Optional[ScanEntry] = None) -> None:
193208
super().__init__(entry=entry)
194209
if entry:
195210
if services:
@@ -212,8 +227,8 @@ class ManufacturerData(AdvertisingDataField):
212227
"""
213228

214229
def __init__(
215-
self, obj, *, advertising_data_type=0xFF, company_id, key_encoding="B"
216-
):
230+
self, obj: UsesServicesAdvertisement, *, advertising_data_type: int = 0xFF, company_id: int, key_encoding: str = "B"
231+
) -> None:
217232
self._obj = obj
218233
self._company_id = company_id
219234
self._adt = advertising_data_type
@@ -231,15 +246,15 @@ def __init__(
231246
self.data = decode_data(existing_data[2:], key_encoding=key_encoding)
232247
self._key_encoding = key_encoding
233248

234-
def __len__(self):
249+
def __len__(self) -> int:
235250
return 2 + compute_length(self.data, key_encoding=self._key_encoding)
236251

237-
def __bytes__(self):
252+
def __bytes__(self) -> bytes:
238253
return struct.pack("<H", self.company_id) + encode_data(
239254
self.data, key_encoding=self._key_encoding
240255
)
241256

242-
def __str__(self):
257+
def __str__(self) -> str:
243258
hex_data = to_hex(encode_data(self.data, key_encoding=self._key_encoding))
244259
return "<ManufacturerData company_id={:04x} data={} >".format(
245260
self.company_id, hex_data
@@ -249,7 +264,7 @@ def __str__(self):
249264
class ManufacturerDataField:
250265
"""A single piece of data within the manufacturer specific data. The format can be repeated."""
251266

252-
def __init__(self, key, value_format, field_names=None):
267+
def __init__(self, key: int, value_format: str, field_names: Optional[Iterable[str]] = None) -> None:
253268
self._key = key
254269
self._format = value_format
255270
# TODO: Support format strings that use numbers to repeat a given type. For now, we strip
@@ -267,7 +282,7 @@ def __init__(self, key, value_format, field_names=None):
267282
# Mostly, this is to raise a ValueError if field_names has invalid entries
268283
self.mdf_tuple = namedtuple("mdf_tuple", self.field_names)
269284

270-
def __get__(self, obj, cls):
285+
def __get__(self, obj: Optional[Advertisement], cls: Type[Advertisement])-> Optional[Union["ManufacturerDataField", Tuple, namedtuple]]:
271286
if obj is None:
272287
return self
273288
if self._key not in obj.manufacturer_data.data:
@@ -293,7 +308,7 @@ def __get__(self, obj, cls):
293308
unpacked[i] = unpacked[i][0]
294309
return tuple(unpacked)
295310

296-
def __set__(self, obj, value):
311+
def __set__(self, obj: "Advertisement", value: Any) -> None:
297312
if not obj.mutable:
298313
raise AttributeError()
299314
if isinstance(value, tuple) and (
@@ -317,16 +332,16 @@ class ServiceData(AdvertisingDataField):
317332
"""Encapsulates service data. It is read as a memoryview which can be manipulated or set as a
318333
bytearray to change the size."""
319334

320-
def __init__(self, service):
335+
def __init__(self, service: Characteristic) -> None:
321336
if isinstance(service.uuid, StandardUUID):
322337
self._adt = 0x16
323338
elif isinstance(service.uuid, VendorUUID):
324339
self._adt = 0x21
325340
self._prefix = bytes(service.uuid)
326341

327342
def __get__(
328-
self, obj, cls
329-
): # pylint: disable=too-many-return-statements,too-many-branches
343+
self, obj: Optional[Service], cls: Type[Service]
344+
) -> Optional[Union["ServiceData", memoryview]]: # pylint: disable=too-many-return-statements,too-many-branches
330345
if obj is None:
331346
return self
332347
# If not present at all and mutable, then we init it, otherwise None.
@@ -366,7 +381,7 @@ def __get__(
366381

367382
return None
368383

369-
def __set__(self, obj, value):
384+
def __set__(self, obj: Advertisement, value: bytearray) -> None:
370385
if not obj.mutable:
371386
raise RuntimeError("Advertisement immutable")
372387
if not isinstance(value, bytearray):

0 commit comments

Comments
 (0)