Skip to content

Commit 3b7e669

Browse files
committed
Add just-in-case HID test mode
1 parent 50d2b76 commit 3b7e669

File tree

4 files changed

+404
-34
lines changed

4 files changed

+404
-34
lines changed
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
import logging
2+
import time
3+
from enum import IntEnum
4+
from random import randint
5+
from typing import Optional, Callable, Sequence, Dict, Tuple
6+
7+
from uhid import UHIDDevice, _ReportType, AsyncioBlockingUHID, Bus
8+
from fido2.pcsc import CtapDevice, CTAPHID, CtapError, CtapPcscDevice
9+
10+
SECONDS_TO_WAIT_FOR_AUTHENTICATOR = 10
11+
"""How long, in seconds, to poll for a USB authenticator before giving up."""
12+
VID = 0x9999
13+
"""USB vendor ID."""
14+
PID = 0x9999
15+
"""USB product ID."""
16+
17+
BROADCAST_CHANNEL = bytes([0xFF, 0xFF, 0xFF, 0xFF])
18+
"""Standard CTAP-HID broadcast channel."""
19+
20+
21+
class CommandType(IntEnum):
22+
"""Catalog of CTAP-HID command type bytes."""
23+
PING = 0x01
24+
MSG = 0x03
25+
INIT = 0x06
26+
WINK = 0x08
27+
CBOR = 0x10
28+
CANCEL = 0x11
29+
KEEPALIVE = 0x3B
30+
ERROR = 0x3F
31+
32+
33+
def _wrap_call_with_device_obj(device: UHIDDevice,
34+
call: Callable[[UHIDDevice, Sequence[int], _ReportType], None]) -> Callable:
35+
"""Pass a UHIDDevice to a given callback."""
36+
return lambda x, y: call(device, x, y)
37+
38+
39+
class CTAPHIDDevice:
40+
device: UHIDDevice
41+
"""Underlying UHID device."""
42+
channels_to_devices: Dict[str, CtapDevice] = {}
43+
"""Mapping from channel strings to CTAP devices."""
44+
channels_to_state: Dict[str, Tuple[CommandType, int, int, bytes]] = {}
45+
"""
46+
Mapping from channel strings to receive buffer state.
47+
48+
Each value consists of:
49+
1. The command type in use on the channel
50+
2. The total length of the incoming request
51+
3. The sequence number of the most recently received packet (-1 for initial)
52+
4. The accumulated data received on the channel
53+
"""
54+
reference_count = 0
55+
"""Number of open handles to the device: clear state when it hits zero."""
56+
fixed_device: Optional[CtapDevice] = None
57+
"""Optional CtapDevice to which to proxy, instead of discovering one over PC/SC"""
58+
59+
def __init__(self, fixed_device: Optional[CtapDevice] = None):
60+
self.fixed_device = fixed_device
61+
self.device = UHIDDevice(
62+
vid=VID, pid=PID, name='FIDO2 Virtual USB Device', report_descriptor=[
63+
0x06, 0xD0, 0xF1, # Usage Page (FIDO)
64+
0x09, 0x01, # Usage (CTAPHID)
65+
0xa1, 0x01, # Collection (Application)
66+
0x09, 0x20, # Usage (Data In)
67+
0x15, 0x00, # Logical min (0)
68+
0x26, 0xFF, 0x00, # Logical max (255)
69+
0x75, 0x08, # Report Size (8)
70+
0x95, 0x40, # Report count (64 bytes per packet)
71+
0x81, 0x02, # Input(HID_Data | HID_Absolute | HID_Variable)
72+
0x09, 0x21, # Usage (Data Out)
73+
0x15, 0x00, # Logical min (0)
74+
0x26, 0xFF, 0x00, # Logical max (255)
75+
0x75, 0x08, # Report Size (8)
76+
0x95, 0x40, # Report count (64 bytes per packet)
77+
0x91, 0x02, # Output(HID_Data | HID_Absolute | HID_Variable)
78+
0xc0, # End Collection
79+
],
80+
backend=AsyncioBlockingUHID,
81+
version=0,
82+
bus=Bus.USB
83+
)
84+
85+
self.device.receive_output = self.process_hid_message
86+
self.device.receive_close = self.process_close
87+
self.device.receive_open = self.process_open
88+
89+
def process_open(self):
90+
self.reference_count += 1
91+
92+
def process_close(self):
93+
self.reference_count -= 1
94+
if self.reference_count == 0:
95+
# Clear all state
96+
self.channels_to_devices = {}
97+
self.channels_to_state = {}
98+
99+
def process_hid_message(self, buffer: Sequence[int], report_type: _ReportType) -> None:
100+
"""Core method: handle incoming HID messages."""
101+
recvd_bytes = bytes(buffer)
102+
logging.debug(f"GOT MESSAGE (type {report_type}): {recvd_bytes.hex()}")
103+
104+
if self.is_initial_packet(recvd_bytes):
105+
channel, lc, cmd, data = self.parse_initial_packet(recvd_bytes)
106+
channel_key = self.get_channel_key(channel)
107+
logging.debug(f"CMD {cmd.name} CHANNEL {channel_key} len {lc} (recvd {len(data)}) data {data.hex()}")
108+
self.channels_to_state[channel_key] = cmd, lc, -1, data
109+
if lc == len(data):
110+
# Complete receive
111+
self.finish_receiving(channel)
112+
else:
113+
channel, seq, new_data = self.parse_subsequent_packet(recvd_bytes)
114+
channel_key = self.get_channel_key(channel)
115+
if channel_key not in self.channels_to_state:
116+
self.send_error(channel, 0x0B)
117+
return
118+
cmd, lc, prev_seq, existing_data = self.channels_to_state[channel_key]
119+
if seq != prev_seq + 1:
120+
self.handle_cancel(channel, b"")
121+
self.send_error(channel, 0x04)
122+
return
123+
remaining = lc - len(existing_data)
124+
data = bytes([x for x in existing_data] + [x for x in new_data[:remaining]])
125+
self.channels_to_state[channel_key] = cmd, lc, seq, data
126+
logging.debug(f"After receive, we have {len(data)} bytes out of {lc}")
127+
if lc == len(data):
128+
self.finish_receiving(channel)
129+
130+
async def start(self):
131+
await self.device.wait_for_start_asyncio()
132+
133+
def parse_initial_packet(self, buffer: bytes) -> Tuple[bytes, int, CommandType, bytes]:
134+
"""Parse an incoming initial packet."""
135+
logging.debug(f"Initial packet {buffer.hex()}")
136+
channel = buffer[1:5]
137+
cmd_byte = buffer[5] & 0x7F
138+
lc = (int(buffer[6]) << 8) + buffer[7]
139+
data = buffer[8:8+lc]
140+
cmd = CommandType(cmd_byte)
141+
return channel, lc, cmd, data
142+
143+
def is_initial_packet(self, buffer: bytes) -> bool:
144+
"""Return true if packet is the start of a new sequence."""
145+
if buffer[5] & 0x80 == 0:
146+
return False
147+
return True
148+
149+
def assign_channel_id(self) -> Sequence[int]:
150+
"""Create a new, random, channel ID."""
151+
return [randint(0, 255), randint(0, 255),
152+
randint(0, 255), randint(0, 255)]
153+
154+
def handle_init(self, channel: bytes, buffer: bytes) -> Optional[Sequence[int]]:
155+
"""Initialize or re-initialize a channel."""
156+
logging.debug(f"INIT on channel {channel}")
157+
158+
new_channel = self.assign_channel_id()
159+
160+
ctap = self.get_pcsc_device(new_channel)
161+
if ctap is None:
162+
return None
163+
164+
if channel == BROADCAST_CHANNEL:
165+
assert len(buffer) == 8
166+
return ([x for x in buffer] +
167+
new_channel +
168+
[
169+
0x02, # protocol version
170+
0x01, # device version major
171+
0x00, # device version minor
172+
0x00, # device version build/point
173+
ctap.capabilities, # capabilities, from the underlying device
174+
])
175+
else:
176+
self.handle_cancel(channel, b"")
177+
178+
def get_pcsc_device(self, channel_id: Sequence[int]) -> Optional[CtapDevice]:
179+
"""Grab a PC/SC device from python-fido2, or use a fixed one if present."""
180+
channel_key = self.get_channel_key(channel_id)
181+
182+
if channel_key not in self.channels_to_devices:
183+
if self.fixed_device is not None:
184+
self.channels_to_devices[channel_key] = self.fixed_device
185+
else:
186+
start_time = time.time()
187+
while time.time() < start_time + SECONDS_TO_WAIT_FOR_AUTHENTICATOR:
188+
devices = list(CtapPcscDevice.list_devices())
189+
if len(devices) == 0:
190+
time.sleep(0.1)
191+
continue
192+
device = devices[0]
193+
self.channels_to_devices[channel_key] = device
194+
return device
195+
# TODO: send timeout error properly
196+
raise ValueError("Could not connect to a PC/SC device in time!")
197+
# self.send_error(channel_id, 0x05)
198+
# return None
199+
200+
return self.channels_to_devices[channel_key]
201+
202+
def handle_cbor(self, channel: Sequence[int], buffer: bytes) -> Optional[Sequence[int]]:
203+
"""Handling an incoming CBOR command."""
204+
ctap = self.get_pcsc_device(channel)
205+
if ctap is None:
206+
return None
207+
logging.debug(f"Sending CBOR to device {ctap}: {buffer}")
208+
try:
209+
res = ctap.call(cmd=CommandType.CBOR, data=buffer)
210+
return [x for x in res]
211+
except CtapError as e:
212+
logging.info(f"Got CTAP error response from device: {e}")
213+
return [e.code]
214+
215+
def handle_cancel(self, channel: Sequence[int], buffer: bytes) -> Optional[Sequence[int]]:
216+
channel_key = self.get_channel_key(channel)
217+
if channel_key in self.channels_to_state:
218+
del self.channels_to_state[channel_key]
219+
if channel_key in self.channels_to_devices:
220+
del self.channels_to_devices[channel_key]
221+
return []
222+
223+
def handle_wink(channel: Sequence[int], buffer: bytes) -> Optional[Sequence[int]]:
224+
"""Do nothing; this can't be done over PC/SC."""
225+
return []
226+
227+
def handle_msg(self, channel: Sequence[int], buffer: bytes) -> Optional[Sequence[int]]:
228+
"""Process a U2F/CTAP1 message."""
229+
device = self.get_pcsc_device(channel)
230+
if device is None:
231+
return None
232+
res = device.call(CTAPHID.MSG, buffer)
233+
return [x for x in res]
234+
235+
def handle_ping(self, channel: Sequence[int], buffer: bytes) -> Optional[Sequence[int]]:
236+
"""Handle an echo request."""
237+
return [x for x in buffer]
238+
239+
def handle_keepalive(self, channel: Sequence[int], buffer: bytes) -> Optional[Sequence[int]]:
240+
"""Placeholder: always returns that the device is processing."""
241+
return [1]
242+
243+
def encode_response_packets(self, channel: Sequence[int], cmd: CommandType, data: Sequence[int]) -> Sequence[bytes]:
244+
"""Chunk response data to be delivered over USB."""
245+
offset_start = 0
246+
seq = 0
247+
responses = []
248+
while offset_start < len(data):
249+
if seq == 0:
250+
capacity = 64 - 7
251+
chunk = data[offset_start:offset_start + capacity]
252+
data_len_upper = len(data) >> 8
253+
data_len_lower = len(data) % 256
254+
response = [x for x in channel] + [cmd | 0x80, data_len_upper, data_len_lower] + chunk
255+
else:
256+
capacity = 64 - 5
257+
chunk = data[offset_start:offset_start + capacity]
258+
response = [x for x in channel] + [seq - 1] + chunk
259+
260+
while len(response) < 64:
261+
response.append(0x00)
262+
263+
responses.append(bytes(response))
264+
offset_start += capacity
265+
seq += 1
266+
267+
return responses
268+
269+
def get_channel_key(self, channel: Sequence[int]) -> str:
270+
return bytes(channel).hex()
271+
272+
def send_error(self, channel: Sequence[int], error_type: int) -> None:
273+
responses = self.encode_response_packets(channel, CommandType.ERROR, [error_type])
274+
for response in responses:
275+
self.device.send_input(response)
276+
277+
def finish_receiving(self, channel: Sequence[int]) -> None:
278+
"""When finished receiving packets, act on them."""
279+
channel_key = self.get_channel_key(channel)
280+
cmd, _, _, data = self.channels_to_state[channel_key]
281+
self.handle_cancel(channel, b"")
282+
283+
try:
284+
handler = getattr(self, f"handle_{cmd.name.lower()}", None)
285+
if handler is not None:
286+
response_body = handler(channel, data)
287+
if response_body is None:
288+
# Already dealt with
289+
return
290+
responses = self.encode_response_packets(channel, cmd, response_body)
291+
else:
292+
self.send_error(channel, 0x01)
293+
return
294+
except Exception as e:
295+
logging.warning(f"Error: {e}")
296+
self.send_error(channel, 0x7F)
297+
return
298+
299+
for response in responses:
300+
self.device.send_input(response)
301+
302+
def parse_subsequent_packet(self, data: bytes) -> Tuple[Sequence[int], int, bytes]:
303+
"""Parse a non-initial packet."""
304+
return data[1:5], data[5], bytes(data[6:])

0 commit comments

Comments
 (0)