Skip to content

Commit 67e2351

Browse files
Desktop MFA mobile approve (#138)
* compiles * update proto * add mfa token validation * Update src/handlers/desktop_client_mfa.rs Co-authored-by: Adam <[email protected]> * fix cargo lock * remove dashmap * review changes * fmt * review changes * Update defguard-ui * update proto * upgrade ui module * Update desktop_client_mfa.rs * Update desktop_client_mfa.rs --------- Co-authored-by: Adam <[email protected]>
1 parent 29059aa commit 67e2351

File tree

20 files changed

+944
-708
lines changed

20 files changed

+944
-708
lines changed

Cargo.lock

Lines changed: 53 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ repository = "https://github.com/DefGuard/proxy"
88

99
[dependencies]
1010
# base `axum` deps
11-
axum = { version = "0.7", features = ["macros", "tracing"] }
11+
axum = { version = "0.7", features = ["macros", "tracing", "ws"] }
1212
axum-client-ip = "0.6"
1313
axum-extra = { version = "0.9", features = [
1414
"cookie",
@@ -48,6 +48,8 @@ tower_governor = "0.4"
4848
rust-embed = { version = "8.5", features = ["include-exclude"] }
4949
mime_guess = "2.0"
5050
base64 = "0.22.1"
51+
futures = "0.3.31"
52+
futures-util = "0.3.31"
5153

5254
[build-dependencies]
5355
tonic-prost-build = "0.14"

proto

src/handlers/desktop_client_mfa.rs

Lines changed: 166 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,17 @@
1-
use axum::{extract::State, routing::post, Json, Router};
1+
use axum::{
2+
extract::{
3+
ws::{Message, WebSocket},
4+
Query, State, WebSocketUpgrade,
5+
},
6+
response::{IntoResponse, Response},
7+
routing::{get, post},
8+
Json, Router,
9+
};
10+
use futures_util::{sink::SinkExt, stream::StreamExt};
11+
use serde::Deserialize;
12+
use serde_json::json;
13+
use std::collections::hash_map::Entry;
14+
use tokio::{sync::oneshot, task::JoinSet};
215

316
use crate::{
417
error::ApiError,
@@ -14,9 +27,116 @@ pub(crate) fn router() -> Router<AppState> {
1427
Router::new()
1528
.route("/start", post(start_client_mfa))
1629
.route("/finish", post(finish_client_mfa))
30+
.route("/remote", get(await_remote_auth))
31+
.route("/finish-remote", post(finish_remote_mfa))
1732
}
1833

19-
#[instrument(level = "debug", skip(state))]
34+
#[derive(Debug, Clone, Deserialize)]
35+
pub(crate) struct RemoteMfaRequestQuery {
36+
pub token: String,
37+
}
38+
39+
// Allows desktop client to await for another device to complete MFA for it via mobile client
40+
#[instrument(level = "debug", skip(state, req))]
41+
async fn await_remote_auth(
42+
ws: WebSocketUpgrade,
43+
Query(req): Query<RemoteMfaRequestQuery>,
44+
State(state): State<AppState>,
45+
device_info: DeviceInfo,
46+
) -> Result<Response, impl IntoResponse> {
47+
let token = req.token;
48+
// let core validate token first
49+
let rx = state.grpc_server.send(
50+
core_request::Payload::ClientMfaTokenValidation(
51+
crate::proto::ClientMfaTokenValidationRequest {
52+
token: token.clone(),
53+
},
54+
),
55+
device_info,
56+
)?;
57+
let payload = get_core_response(rx).await?;
58+
if let core_response::Payload::ClientMfaTokenValidation(response) = payload {
59+
if !response.token_valid {
60+
return Err(ApiError::Unauthorized(String::new()));
61+
}
62+
// check if its already in the map
63+
let contains_key = {
64+
let sessions = state.remote_mfa_sessions.lock().await;
65+
sessions.contains_key(&token)
66+
};
67+
if contains_key {
68+
return Err(ApiError::Unauthorized(String::new()));
69+
};
70+
Ok(ws.on_upgrade(move |socket| handle_remote_auth_socket(socket, state.clone(), token)))
71+
} else {
72+
Err(ApiError::InvalidResponseType)
73+
}
74+
}
75+
76+
// handle axum ws socket upgrade for await_remote_auth
77+
async fn handle_remote_auth_socket(socket: WebSocket, state: AppState, token: String) {
78+
let (tx, rx) = oneshot::channel::<String>();
79+
let (mut ws_tx, mut ws_rx) = socket.split();
80+
81+
let occupied = {
82+
let mut sessions = state.remote_mfa_sessions.lock().await;
83+
match sessions.entry(token.clone()) {
84+
Entry::Occupied(_) => true,
85+
Entry::Vacant(v) => {
86+
v.insert(tx);
87+
false
88+
}
89+
}
90+
};
91+
if occupied {
92+
let _ = ws_tx.close().await;
93+
return;
94+
}
95+
96+
let mut set = JoinSet::new();
97+
98+
set.spawn(async move {
99+
if let Ok(msg) = rx.await {
100+
let payload = json!({
101+
"type": "mfa_success",
102+
"preshared_key": &msg,
103+
});
104+
if let Ok(serialized) = serde_json::to_string(&payload) {
105+
let message = Message::Text(serialized);
106+
if ws_tx.send(message).await.is_err() {
107+
error!("Failed to send preshared key via ws");
108+
}
109+
} else {
110+
error!("Failed to serialize remote mfa ws client response message");
111+
}
112+
} else {
113+
error!("Failed to receive preshared key from receiver")
114+
}
115+
let _ = ws_tx.close().await;
116+
});
117+
set.spawn(async move {
118+
while let Some(msg_result) = ws_rx.next().await {
119+
match msg_result {
120+
Ok(msg) => {
121+
if let Message::Close(_) = msg {
122+
break;
123+
}
124+
}
125+
Err(e) => {
126+
error!("Remote desktop mfa WS client listen error {e}");
127+
break;
128+
}
129+
}
130+
}
131+
});
132+
133+
let _ = set.join_next().await;
134+
set.shutdown().await;
135+
// will remove token if it's still there
136+
state.remote_mfa_sessions.lock().await.remove(&token);
137+
}
138+
139+
#[instrument(level = "debug", skip(state, req))]
20140
async fn start_client_mfa(
21141
State(state): State<AppState>,
22142
device_info: DeviceInfo,
@@ -38,7 +158,7 @@ async fn start_client_mfa(
38158
}
39159
}
40160

41-
#[instrument(level = "debug", skip(state))]
161+
#[instrument(level = "debug", skip(state, req))]
42162
async fn finish_client_mfa(
43163
State(state): State<AppState>,
44164
device_info: DeviceInfo,
@@ -50,10 +170,52 @@ async fn finish_client_mfa(
50170
.send(core_request::Payload::ClientMfaFinish(req), device_info)?;
51171
let payload = get_core_response(rx).await?;
52172
if let core_response::Payload::ClientMfaFinish(response) = payload {
53-
info!("Finished desktop client authorization");
54173
Ok(Json(response))
55174
} else {
56175
error!("Received invalid gRPC response type: {payload:#?}");
57176
Err(ApiError::InvalidResponseType)
58177
}
59178
}
179+
180+
#[instrument(level = "debug", skip(state, req))]
181+
async fn finish_remote_mfa(
182+
State(state): State<AppState>,
183+
device_info: DeviceInfo,
184+
Json(req): Json<ClientMfaFinishRequest>,
185+
) -> Result<Json<serde_json::Value>, ApiError> {
186+
info!("Finishing desktop client authorization");
187+
let rx = state
188+
.grpc_server
189+
.send(core_request::Payload::ClientMfaFinish(req), device_info)?;
190+
let payload = get_core_response(rx).await?;
191+
if let core_response::Payload::ClientMfaFinish(response) = payload {
192+
// check if this needs to be forwarded
193+
match response.token {
194+
Some(token) => {
195+
let sender_option = {
196+
let mut sessions = state.remote_mfa_sessions.lock().await;
197+
sessions.remove(&token)
198+
};
199+
match sender_option {
200+
Some(sender) => {
201+
let _ = sender.send(response.preshared_key);
202+
}
203+
// if desktop stopped listening for the result there will be no palce to send the result
204+
None => {
205+
error!("Remote MFA approve finished but session was not found.");
206+
return Err(ApiError::Unexpected(String::new()));
207+
}
208+
}
209+
info!("Finished desktop client authorization via mobile device");
210+
Ok(Json(json!({})))
211+
}
212+
None => {
213+
error!("Remote MFA Unexpected core response, token was not returned");
214+
Err(ApiError::Unexpected(String::new()))
215+
}
216+
}
217+
} else {
218+
error!("Received invalid gRPC response type: {payload:#?}");
219+
Err(ApiError::InvalidResponseType)
220+
}
221+
}

src/http.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use std::{
2+
collections::HashMap,
23
fs::read_to_string,
34
net::{IpAddr, Ipv4Addr, SocketAddr},
4-
sync::atomic::Ordering,
5+
sync::{atomic::Ordering, Arc},
56
time::Duration,
67
};
78

@@ -16,7 +17,7 @@ use axum::{
1617
use axum_extra::extract::cookie::Key;
1718
use clap::crate_version;
1819
use serde::Serialize;
19-
use tokio::{net::TcpListener, task::JoinSet};
20+
use tokio::{net::TcpListener, sync::oneshot, task::JoinSet};
2021
use tonic::transport::{Identity, Server, ServerTlsConfig};
2122
use tower_governor::{
2223
governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer,
@@ -42,6 +43,8 @@ const RATE_LIMITER_CLEANUP_PERIOD: Duration = Duration::from_secs(60);
4243
#[derive(Clone)]
4344
pub(crate) struct AppState {
4445
pub(crate) grpc_server: ProxyServer,
46+
pub(crate) remote_mfa_sessions:
47+
Arc<tokio::sync::Mutex<HashMap<String, oneshot::Sender<String>>>>,
4548
key: Key,
4649
url: Url,
4750
}
@@ -129,6 +132,7 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
129132
debug!("Setting up API server");
130133
let shared_state = AppState {
131134
grpc_server: grpc_server.clone(),
135+
remote_mfa_sessions: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
132136
// Generate secret key for encrypting cookies.
133137
key: Key::generate(),
134138
url: config.url.clone(),

web/biome.json

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"$schema": "https://biomejs.dev/schemas/2.1.2/schema.json",
2+
"$schema": "https://biomejs.dev/schemas/2.2.0/schema.json",
33
"vcs": { "enabled": false, "clientKind": "git", "useIgnoreFile": false },
44
"files": {
55
"ignoreUnknown": false,
@@ -8,7 +8,7 @@
88
"!src/i18n/*.ts",
99
"!src/i18n/*.tsx",
1010
"!src/i18n/i18n-util",
11-
"!dist/**"
11+
"!dist"
1212
]
1313
},
1414
"formatter": {
@@ -40,7 +40,8 @@
4040
"noUnusedVariables": "error",
4141
"useExhaustiveDependencies": "error",
4242
"useHookAtTopLevel": "error",
43-
"useJsxKeyInIterable": "error"
43+
"useJsxKeyInIterable": "error",
44+
"useUniqueElementIds": "off"
4445
},
4546
"security": { "noDangerouslySetInnerHtmlWithChildren": "error" },
4647
"style": {

0 commit comments

Comments
 (0)