1- use std:: sync:: { Arc , Mutex , OnceLock } ;
1+ use std:: {
2+ sync:: { Arc , Mutex , OnceLock } ,
3+ time:: Duration ,
4+ } ;
25
36use libwebauthn:: {
47 self ,
58 ops:: webauthn:: { GetAssertionResponse , MakeCredentialResponse } ,
6- transport:: Device as _,
9+ transport:: { hid :: HidDevice , Device as _} ,
710 webauthn:: { Error as WebAuthnError , WebAuthn } ,
811 UxUpdate ,
912} ;
@@ -62,32 +65,67 @@ impl CredentialService {
6265
6366 pub ( crate ) async fn poll_device_discovery_usb ( & mut self ) -> Result < UsbState , String > {
6467 debug ! ( "polling for USB status" ) ;
65- let prev_usb_state = * self . usb_state . lock ( ) . await ;
68+ let prev_usb_state = self . usb_state . lock ( ) . await . clone ( ) ;
6669 let next_usb_state = match prev_usb_state {
6770 UsbState :: Idle | UsbState :: Waiting => {
68- let devices = libwebauthn:: transport:: hid:: list_devices ( ) . await . unwrap ( ) ;
69- if devices . is_empty ( ) {
71+ let mut hid_devices = libwebauthn:: transport:: hid:: list_devices ( ) . await . unwrap ( ) ;
72+ if hid_devices . is_empty ( ) {
7073 let state = UsbState :: Waiting ;
71- * self . usb_state . lock ( ) . await = state;
74+ * self . usb_state . lock ( ) . await = state. clone ( ) ;
7275 return Ok ( state) ;
73- }
74- if devices. is_empty ( ) {
75- Ok ( UsbState :: Waiting )
76+ } else if hid_devices. len ( ) == 1 {
77+ Ok ( UsbState :: Connected ( hid_devices. swap_remove ( 0 ) ) )
7678 } else {
77- Ok ( UsbState :: Connected )
79+ Ok ( UsbState :: SelectingDevice ( hid_devices ) )
7880 }
7981 }
80- UsbState :: Connected => {
81- // TODO: I'm not sure how we want to handle multiple usb devices
82- // just take the first one found for now.
83- // TODO: store this device reference, perhaps in the enum itself
82+ UsbState :: SelectingDevice ( hid_devices) => {
83+ let ( blinking_tx, mut blinking_rx) =
84+ tokio:: sync:: mpsc:: channel :: < Option < HidDevice > > ( hid_devices. len ( ) ) ;
85+ let mut expected_answers = hid_devices. len ( ) ;
86+ for mut device in hid_devices {
87+ let tx = blinking_tx. clone ( ) ;
88+ tokio ( ) . spawn ( async move {
89+ let ( mut channel, _state_rx) = device. channel ( ) . await . unwrap ( ) ;
90+ let res = channel
91+ . blink_and_wait_for_user_presence ( Duration :: from_secs ( 300 ) )
92+ . await ;
93+ drop ( channel) ;
94+ match res {
95+ Ok ( true ) => {
96+ let _ = tx. send ( Some ( device) ) . await ;
97+ }
98+ Ok ( false ) | Err ( _) => {
99+ let _ = tx. send ( None ) . await ;
100+ }
101+ }
102+ } ) ;
103+ }
104+ let mut state = UsbState :: Idle ;
105+ while let Some ( msg) = blinking_rx. recv ( ) . await {
106+ expected_answers -= 1 ;
107+ match msg {
108+ Some ( device) => {
109+ state = UsbState :: Connected ( device) ;
110+ break ;
111+ }
112+ None => {
113+ if expected_answers == 0 {
114+ break ;
115+ } else {
116+ continue ;
117+ }
118+ }
119+ }
120+ }
121+ Ok ( state)
122+ }
123+ UsbState :: Connected ( mut device) => {
84124 let handler = self . usb_uv_handler . clone ( ) ;
85125 let cred_request = self . cred_request . clone ( ) ;
86126 let signal_tx = self . usb_uv_handler . signal_tx . clone ( ) ;
87127 let pin_rx = self . usb_uv_handler . pin_rx . clone ( ) ;
88128 tokio ( ) . spawn ( async move {
89- let mut devices = libwebauthn:: transport:: hid:: list_devices ( ) . await . unwrap ( ) ;
90- let device = devices. first_mut ( ) . unwrap ( ) ;
91129 let ( mut channel, state_rx) = device. channel ( ) . await . unwrap ( ) ;
92130 tokio ( ) . spawn ( async move {
93131 handle_usb_updates ( signal_tx, pin_rx, state_rx) . await ;
@@ -252,7 +290,7 @@ impl CredentialService {
252290 UsbState :: Completed => Ok ( prev_usb_state) ,
253291 } ?;
254292
255- * self . usb_state . lock ( ) . await = next_usb_state;
293+ * self . usb_state . lock ( ) . await = next_usb_state. clone ( ) ;
256294 Ok ( next_usb_state)
257295 }
258296
@@ -263,7 +301,7 @@ impl CredentialService {
263301 }
264302
265303 pub ( crate ) async fn validate_usb_device_pin ( & mut self , pin : & str ) -> Result < ( ) , ( ) > {
266- let current_state = * self . usb_state . lock ( ) . await ;
304+ let current_state = self . usb_state . lock ( ) . await . clone ( ) ;
267305 match current_state {
268306 UsbState :: NeedsPin {
269307 attempts_left : Some ( attempts_left) ,
@@ -281,7 +319,7 @@ impl CredentialService {
281319 }
282320}
283321
284- #[ derive( Copy , Clone , Debug , Default , PartialEq ) ]
322+ #[ derive( Clone , Debug , Default ) ]
285323pub enum UsbState {
286324 /// Not polling for FIDO USB device.
287325 #[ default]
@@ -291,13 +329,17 @@ pub enum UsbState {
291329 Waiting ,
292330
293331 /// USB device connected, prompt user to tap
294- Connected ,
332+ Connected ( HidDevice ) ,
295333
296334 /// The device needs the PIN to be entered.
297- NeedsPin { attempts_left : Option < u32 > } ,
335+ NeedsPin {
336+ attempts_left : Option < u32 > ,
337+ } ,
298338
299339 /// The device needs on-device user verification.
300- NeedsUserVerification { attempts_left : Option < u32 > } ,
340+ NeedsUserVerification {
341+ attempts_left : Option < u32 > ,
342+ } ,
301343
302344 /// The device needs evidence of user presence (e.g. touch) to release the credential.
303345 NeedsUserPresence ,
@@ -306,7 +348,11 @@ pub enum UsbState {
306348 Completed ,
307349 // TODO: implement cancellation
308350 // This isn't actually sent from the server.
309- // UserCancelled,
351+ //UserCancelled,
352+
353+ // When we encounter multiple devices, we let all of them blink and continue
354+ // with the one that was tapped.
355+ SelectingDevice ( Vec < HidDevice > ) ,
310356}
311357
312358#[ derive( Clone , Debug ) ]
0 commit comments