Skip to content

Commit 50fdf36

Browse files
committed
pybricksdev.dfu: add some type hints
Add some type hints to DFU code. It still isn't perfect since pyusb doesn't have type hints.
1 parent be626ad commit 50fdf36

File tree

3 files changed

+78
-62
lines changed

3 files changed

+78
-62
lines changed

pybricksdev/_vendored/dfu_create.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,29 @@
1313
import sys
1414
import zlib
1515
from optparse import OptionParser
16+
from typing import Any, TypedDict
1617

1718
DEFAULT_DEVICE = "0x0483:0xdf11"
1819

1920

20-
def named(tuple, names):
21+
def named(tuple: tuple[Any], names: str) -> dict[str, Any]:
2122
return dict(zip(names.split(), tuple))
2223

2324

24-
def consume(fmt, data, names):
25+
def consume(fmt: str, data: bytes, names: str) -> tuple[dict[str, Any], bytes]:
2526
n = struct.calcsize(fmt)
2627
return named(struct.unpack(fmt, data[:n]), names), data[n:]
2728

2829

29-
def cstring(string):
30+
def cstring(string: str) -> str:
3031
return string.split("\0", 1)[0]
3132

3233

33-
def compute_crc(data):
34+
def compute_crc(data: bytes) -> int:
3435
return 0xFFFFFFFF & -zlib.crc32(data) - 1
3536

3637

37-
def parse(file, dump_images=False):
38+
def parse(file: str, dump_images: bool = False):
3839
print('File: "%s"' % file)
3940
data = open(file, "rb").read()
4041
crc = compute_crc(data[:-4])
@@ -84,10 +85,16 @@ def parse(file, dump_images=False):
8485
print("PARSE ERROR")
8586

8687

87-
def build(file, targets, device=DEFAULT_DEVICE):
88+
class Image(TypedDict):
89+
address: int
90+
data: bytes
91+
92+
93+
def build(file: str, targets: list[list[Image]], device: str = DEFAULT_DEVICE) -> None:
8894
data = b""
89-
for t, target in enumerate(targets):
95+
for target in targets:
9096
tdata = b""
97+
9198
for image in target:
9299
# pad image to 8 bytes (needed at least for L476)
93100
pad = (8 - len(image["data"]) % 8) % 8
@@ -143,7 +150,7 @@ def build(file, targets, device=DEFAULT_DEVICE):
143150
(options, args) = parser.parse_args()
144151

145152
if options.binfiles and len(args) == 1:
146-
target = []
153+
target: list[Image] = []
147154
for arg in options.binfiles:
148155
try:
149156
address, binfile = arg.split(":", 1)

pybricksdev/_vendored/dfu_upload.py

Lines changed: 57 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
from __future__ import print_function
1919

2020
import argparse
21-
import collections
22-
import inspect
2321
import re
2422
import struct
2523
import sys
2624
import zlib
25+
from collections.abc import Callable
26+
from typing import Any, NamedTuple, TypedDict
2727

2828
import usb.core
2929
import usb.util
@@ -84,31 +84,27 @@
8484
# USB DFU interface
8585
__DFU_INTERFACE = 0
8686

87-
if "length" in inspect.getfullargspec(usb.util.get_string).args:
88-
# PyUSB 1.0.0.b1 has the length argument
89-
def get_string(dev, index):
90-
return usb.util.get_string(dev, 255, index)
9187

92-
else:
93-
# PyUSB 1.0.0.b2 dropped the length argument
94-
def get_string(dev, index):
95-
return usb.util.get_string(dev, index)
88+
class Element(TypedDict):
89+
num: int
90+
addr: int
91+
size: int
92+
data: bytes
9693

9794

98-
def find_dfu_cfg_descr(descr):
95+
class CfgDescr(NamedTuple):
96+
bLength: int
97+
bDescriptorType: int
98+
bmAttributes: int
99+
wDetachTimeOut: int
100+
wTransferSize: int
101+
bcdDFUVersion: int
102+
103+
104+
def find_dfu_cfg_descr(descr: bytes) -> CfgDescr | None:
99105
if len(descr) == 9 and descr[0] == 9 and descr[1] == _DFU_DESCRIPTOR_TYPE:
100-
nt = collections.namedtuple(
101-
"CfgDescr",
102-
[
103-
"bLength",
104-
"bDescriptorType",
105-
"bmAttributes",
106-
"wDetachTimeOut",
107-
"wTransferSize",
108-
"bcdDFUVersion",
109-
],
110-
)
111-
return nt(*struct.unpack("<BBBHHH", bytearray(descr)))
106+
return CfgDescr(*struct.unpack("<BBBHHH", bytearray(descr)))
107+
112108
return None
113109

114110

@@ -138,7 +134,7 @@ def init():
138134
break
139135

140136
# Get device into idle state
141-
for attempt in range(4):
137+
for _ in range(4):
142138
status = get_status()
143139
if status == __DFU_STATE_DFU_IDLE:
144140
break
@@ -151,30 +147,30 @@ def init():
151147
clr_status()
152148

153149

154-
def abort_request():
150+
def abort_request() -> None:
155151
"""Sends an abort request."""
156152
__dev.ctrl_transfer(0x21, __DFU_ABORT, 0, __DFU_INTERFACE, None, __TIMEOUT)
157153

158154

159-
def clr_status():
155+
def clr_status() -> None:
160156
"""Clears any error status (perhaps left over from a previous session)."""
161157
__dev.ctrl_transfer(0x21, __DFU_CLRSTATUS, 0, __DFU_INTERFACE, None, __TIMEOUT)
162158

163159

164-
def get_status():
160+
def get_status() -> int:
165161
"""Get the status of the last operation."""
166162
stat = __dev.ctrl_transfer(0xA1, __DFU_GETSTATUS, 0, __DFU_INTERFACE, 6, 20000)
167163

168164
# firmware can provide an optional string for any error
169165
if stat[5]:
170-
message = get_string(__dev, stat[5])
166+
message = usb.util.get_string(__dev, stat[5])
171167
if message:
172168
print(message)
173169

174170
return stat[4]
175171

176172

177-
def check_status(stage, expected):
173+
def check_status(stage: str, expected: int) -> None:
178174
status = get_status()
179175
if status != expected:
180176
raise SystemExit(
@@ -194,7 +190,7 @@ def mass_erase():
194190
check_status("erase", __DFU_STATE_DFU_DOWNLOAD_IDLE)
195191

196192

197-
def page_erase(addr):
193+
def page_erase(addr: int) -> None:
198194
"""Erases a single page."""
199195
if __verbose:
200196
print("Erasing page: 0x%x..." % (addr))
@@ -210,7 +206,7 @@ def page_erase(addr):
210206
check_status("erase", __DFU_STATE_DFU_DOWNLOAD_IDLE)
211207

212208

213-
def set_address(addr):
209+
def set_address(addr: int) -> None:
214210
"""Sets the address for the next operation."""
215211
# Send DNLOAD with first byte=0x21 and page address
216212
buf = struct.pack("<BI", 0x21, addr)
@@ -223,7 +219,13 @@ def set_address(addr):
223219
check_status("set address", __DFU_STATE_DFU_DOWNLOAD_IDLE)
224220

225221

226-
def write_memory(addr, buf, progress=None, progress_addr=0, progress_size=0):
222+
def write_memory(
223+
addr: int,
224+
buf: bytes,
225+
progress: Callable[[int, int, int], None] | None = None,
226+
progress_addr: int = 0,
227+
progress_size: int = 0,
228+
) -> None:
227229
"""Writes a buffer into memory. This routine assumes that memory has
228230
already been erased.
229231
"""
@@ -268,7 +270,7 @@ def write_memory(addr, buf, progress=None, progress_addr=0, progress_size=0):
268270
xfer_bytes += chunk
269271

270272

271-
def write_page(buf, xfer_offset):
273+
def write_page(buf: bytes, xfer_offset: int) -> None:
272274
"""Writes a single page. This routine assumes that memory has already
273275
been erased.
274276
"""
@@ -291,7 +293,7 @@ def write_page(buf, xfer_offset):
291293
print("Write: 0x%x " % (xfer_base + xfer_offset))
292294

293295

294-
def exit_dfu():
296+
def exit_dfu() -> None:
295297
"""Exit DFU mode, and start running the program."""
296298
# Set jump address
297299
set_address(0x08000000)
@@ -310,12 +312,12 @@ def exit_dfu():
310312
pass
311313

312314

313-
def named(values, names):
315+
def named(values: tuple[Any], names: str) -> dict[str, Any]:
314316
"""Creates a dict with `names` as fields, and `values` as values."""
315317
return dict(zip(names.split(), values))
316318

317319

318-
def consume(fmt, data, names):
320+
def consume(fmt: str, data: bytes, names: str) -> tuple[dict[str, Any], bytes]:
319321
"""Parses the struct defined by `fmt` from `data`, stores the parsed fields
320322
into a named tuple using `names`. Returns the named tuple, and the data
321323
with the struct stripped off."""
@@ -324,17 +326,17 @@ def consume(fmt, data, names):
324326
return named(struct.unpack(fmt, data[:size]), names), data[size:]
325327

326328

327-
def cstring(string):
329+
def cstring(string: bytes) -> str:
328330
"""Extracts a null-terminated string from a byte array."""
329331
return string.decode("utf-8").split("\0", 1)[0]
330332

331333

332-
def compute_crc(data):
334+
def compute_crc(data: bytes) -> int:
333335
"""Computes the CRC32 value for the data passed in."""
334336
return 0xFFFFFFFF & -zlib.crc32(data) - 1
335337

336338

337-
def read_dfu_file(filename):
339+
def read_dfu_file(filename: str) -> list[Element] | None:
338340
"""Reads a DFU file, and parses the individual elements from the file.
339341
Returns an array of elements. Each element is a dictionary with the
340342
following keys:
@@ -349,7 +351,7 @@ def read_dfu_file(filename):
349351
with open(filename, "rb") as fin:
350352
data = fin.read()
351353
crc = compute_crc(data[:-4])
352-
elements = []
354+
elements: list[Element] = []
353355

354356
# Decode the DFU Prefix
355357
#
@@ -430,11 +432,11 @@ def read_dfu_file(filename):
430432
)
431433
if crc != dfu_suffix["crc"]:
432434
print("CRC ERROR: computed crc32 is 0x%08x" % crc)
433-
return
435+
return None
434436
data = data[16:]
435437
if data:
436438
print("PARSE ERROR")
437-
return
439+
return None
438440

439441
return elements
440442

@@ -444,13 +446,15 @@ class FilterDFU(object):
444446
mode.
445447
"""
446448

447-
def __call__(self, device):
449+
def __call__(self, device: usb.core.Device) -> bool | None:
448450
for cfg in device:
449451
for intf in cfg:
450452
return intf.bInterfaceClass == 0xFE and intf.bInterfaceSubClass == 1
451453

454+
return None
452455

453-
def get_dfu_devices(*args, **kwargs):
456+
457+
def get_dfu_devices(*args: Any, **kwargs: Any) -> list[usb.core.Device]:
454458
"""Returns a list of USB devices which are currently in DFU mode.
455459
Additional filters (like idProduct and idVendor) can be passed in
456460
to refine the search.
@@ -460,7 +464,7 @@ def get_dfu_devices(*args, **kwargs):
460464
return list(usb.core.find(*args, find_all=True, custom_match=FilterDFU(), **kwargs))
461465

462466

463-
def get_memory_layout(device):
467+
def get_memory_layout(device: usb.core.Device) -> list[dict[str, Any]]:
464468
"""Returns an array which identifies the memory layout. Each entry
465469
of the array will contain a dictionary with the following keys:
466470
addr - Address of this memory segment.
@@ -472,7 +476,7 @@ def get_memory_layout(device):
472476

473477
cfg = device[0]
474478
intf = cfg[(0, 0)]
475-
mem_layout_str = get_string(device, intf.iInterface)
479+
mem_layout_str = usb.util.get_string(device, intf.iInterface)
476480
mem_layout = mem_layout_str.split("/")
477481
result = []
478482
for mem_layout_index in range(1, len(mem_layout), 2):
@@ -521,7 +525,11 @@ def list_dfu_devices(*args, **kwargs):
521525
)
522526

523527

524-
def write_elements(elements, mass_erase_used, progress=None):
528+
def write_elements(
529+
elements: list[Element],
530+
mass_erase_used: bool,
531+
progress: Callable[[int, int, int], None] | None = None,
532+
):
525533
"""Writes the indicated elements into the target memory,
526534
erasing as needed.
527535
"""
@@ -556,7 +564,7 @@ def write_elements(elements, mass_erase_used, progress=None):
556564
progress(elem_addr, addr - elem_addr, elem_size)
557565

558566

559-
def cli_progress(addr, offset, size):
567+
def cli_progress(addr: int, offset: int, size: int) -> None:
560568
"""Prints a progress report suitable for use on the command line."""
561569
width = 25
562570
done = offset * width // size
@@ -574,7 +582,7 @@ def cli_progress(addr, offset, size):
574582
print("")
575583

576584

577-
def main():
585+
def main() -> None:
578586
"""Test program for verifying this files functionality."""
579587
global __verbose
580588
global __VID
@@ -617,7 +625,6 @@ def main():
617625
args = parser.parse_args()
618626

619627
__verbose = args.verbose
620-
621628
__VID = args.vid
622629
__PID = args.pid
623630

0 commit comments

Comments
 (0)