1- use axum:: { extract:: State , routing:: post, Json , Router } ;
1+ use axum:: {
2+ extract:: {
3+ ws:: { Message , WebSocket } ,
4+ Query , State , WebSocketUpgrade ,
5+ } ,
6+ response:: { IntoResponse , Response } ,
7+ routing:: { get, post} ,
8+ Json , Router ,
9+ } ;
10+ use futures_util:: { sink:: SinkExt , stream:: StreamExt } ;
11+ use serde:: Deserialize ;
12+ use serde_json:: json;
13+ use std:: collections:: hash_map:: Entry ;
14+ use tokio:: { sync:: oneshot, task:: JoinSet } ;
215
316use crate :: {
417 error:: ApiError ,
@@ -14,9 +27,116 @@ pub(crate) fn router() -> Router<AppState> {
1427 Router :: new ( )
1528 . route ( "/start" , post ( start_client_mfa) )
1629 . route ( "/finish" , post ( finish_client_mfa) )
30+ . route ( "/remote" , get ( await_remote_auth) )
31+ . route ( "/finish-remote" , post ( finish_remote_mfa) )
1732}
1833
19- #[ instrument( level = "debug" , skip( state) ) ]
34+ #[ derive( Debug , Clone , Deserialize ) ]
35+ pub ( crate ) struct RemoteMfaRequestQuery {
36+ pub token : String ,
37+ }
38+
39+ // Allows desktop client to await for another device to complete MFA for it via mobile client
40+ #[ instrument( level = "debug" , skip( state, req) ) ]
41+ async fn await_remote_auth (
42+ ws : WebSocketUpgrade ,
43+ Query ( req) : Query < RemoteMfaRequestQuery > ,
44+ State ( state) : State < AppState > ,
45+ device_info : DeviceInfo ,
46+ ) -> Result < Response , impl IntoResponse > {
47+ let token = req. token ;
48+ // let core validate token first
49+ let rx = state. grpc_server . send (
50+ core_request:: Payload :: ClientMfaTokenValidation (
51+ crate :: proto:: ClientMfaTokenValidationRequest {
52+ token : token. clone ( ) ,
53+ } ,
54+ ) ,
55+ device_info,
56+ ) ?;
57+ let payload = get_core_response ( rx) . await ?;
58+ if let core_response:: Payload :: ClientMfaTokenValidation ( response) = payload {
59+ if !response. token_valid {
60+ return Err ( ApiError :: Unauthorized ( String :: new ( ) ) ) ;
61+ }
62+ // check if its already in the map
63+ let contains_key = {
64+ let sessions = state. remote_mfa_sessions . lock ( ) . await ;
65+ sessions. contains_key ( & token)
66+ } ;
67+ if contains_key {
68+ return Err ( ApiError :: Unauthorized ( String :: new ( ) ) ) ;
69+ } ;
70+ Ok ( ws. on_upgrade ( move |socket| handle_remote_auth_socket ( socket, state. clone ( ) , token) ) )
71+ } else {
72+ Err ( ApiError :: InvalidResponseType )
73+ }
74+ }
75+
76+ // handle axum ws socket upgrade for await_remote_auth
77+ async fn handle_remote_auth_socket ( socket : WebSocket , state : AppState , token : String ) {
78+ let ( tx, rx) = oneshot:: channel :: < String > ( ) ;
79+ let ( mut ws_tx, mut ws_rx) = socket. split ( ) ;
80+
81+ let occupied = {
82+ let mut sessions = state. remote_mfa_sessions . lock ( ) . await ;
83+ match sessions. entry ( token. clone ( ) ) {
84+ Entry :: Occupied ( _) => true ,
85+ Entry :: Vacant ( v) => {
86+ v. insert ( tx) ;
87+ false
88+ }
89+ }
90+ } ;
91+ if occupied {
92+ let _ = ws_tx. close ( ) . await ;
93+ return ;
94+ }
95+
96+ let mut set = JoinSet :: new ( ) ;
97+
98+ set. spawn ( async move {
99+ if let Ok ( msg) = rx. await {
100+ let payload = json ! ( {
101+ "type" : "mfa_success" ,
102+ "preshared_key" : & msg,
103+ } ) ;
104+ if let Ok ( serialized) = serde_json:: to_string ( & payload) {
105+ let message = Message :: Text ( serialized) ;
106+ if ws_tx. send ( message) . await . is_err ( ) {
107+ error ! ( "Failed to send preshared key via ws" ) ;
108+ }
109+ } else {
110+ error ! ( "Failed to serialize remote mfa ws client response message" ) ;
111+ }
112+ } else {
113+ error ! ( "Failed to receive preshared key from receiver" )
114+ }
115+ let _ = ws_tx. close ( ) . await ;
116+ } ) ;
117+ set. spawn ( async move {
118+ while let Some ( msg_result) = ws_rx. next ( ) . await {
119+ match msg_result {
120+ Ok ( msg) => {
121+ if let Message :: Close ( _) = msg {
122+ break ;
123+ }
124+ }
125+ Err ( e) => {
126+ error ! ( "Remote desktop mfa WS client listen error {e}" ) ;
127+ break ;
128+ }
129+ }
130+ }
131+ } ) ;
132+
133+ let _ = set. join_next ( ) . await ;
134+ set. shutdown ( ) . await ;
135+ // will remove token if it's still there
136+ state. remote_mfa_sessions . lock ( ) . await . remove ( & token) ;
137+ }
138+
139+ #[ instrument( level = "debug" , skip( state, req) ) ]
20140async fn start_client_mfa (
21141 State ( state) : State < AppState > ,
22142 device_info : DeviceInfo ,
@@ -38,7 +158,7 @@ async fn start_client_mfa(
38158 }
39159}
40160
41- #[ instrument( level = "debug" , skip( state) ) ]
161+ #[ instrument( level = "debug" , skip( state, req ) ) ]
42162async fn finish_client_mfa (
43163 State ( state) : State < AppState > ,
44164 device_info : DeviceInfo ,
@@ -50,10 +170,52 @@ async fn finish_client_mfa(
50170 . send ( core_request:: Payload :: ClientMfaFinish ( req) , device_info) ?;
51171 let payload = get_core_response ( rx) . await ?;
52172 if let core_response:: Payload :: ClientMfaFinish ( response) = payload {
53- info ! ( "Finished desktop client authorization" ) ;
54173 Ok ( Json ( response) )
55174 } else {
56175 error ! ( "Received invalid gRPC response type: {payload:#?}" ) ;
57176 Err ( ApiError :: InvalidResponseType )
58177 }
59178}
179+
180+ #[ instrument( level = "debug" , skip( state, req) ) ]
181+ async fn finish_remote_mfa (
182+ State ( state) : State < AppState > ,
183+ device_info : DeviceInfo ,
184+ Json ( req) : Json < ClientMfaFinishRequest > ,
185+ ) -> Result < Json < serde_json:: Value > , ApiError > {
186+ info ! ( "Finishing desktop client authorization" ) ;
187+ let rx = state
188+ . grpc_server
189+ . send ( core_request:: Payload :: ClientMfaFinish ( req) , device_info) ?;
190+ let payload = get_core_response ( rx) . await ?;
191+ if let core_response:: Payload :: ClientMfaFinish ( response) = payload {
192+ // check if this needs to be forwarded
193+ match response. token {
194+ Some ( token) => {
195+ let sender_option = {
196+ let mut sessions = state. remote_mfa_sessions . lock ( ) . await ;
197+ sessions. remove ( & token)
198+ } ;
199+ match sender_option {
200+ Some ( sender) => {
201+ let _ = sender. send ( response. preshared_key ) ;
202+ }
203+ // if desktop stopped listening for the result there will be no palce to send the result
204+ None => {
205+ error ! ( "Remote MFA approve finished but session was not found." ) ;
206+ return Err ( ApiError :: Unexpected ( String :: new ( ) ) ) ;
207+ }
208+ }
209+ info ! ( "Finished desktop client authorization via mobile device" ) ;
210+ Ok ( Json ( json ! ( { } ) ) )
211+ }
212+ None => {
213+ error ! ( "Remote MFA Unexpected core response, token was not returned" ) ;
214+ Err ( ApiError :: Unexpected ( String :: new ( ) ) )
215+ }
216+ }
217+ } else {
218+ error ! ( "Received invalid gRPC response type: {payload:#?}" ) ;
219+ Err ( ApiError :: InvalidResponseType )
220+ }
221+ }
0 commit comments