Skip to content

Commit d61eafd

Browse files
committed
Merge branch 'dev' of https://github.com/DefGuard/proxy into proxy-pairing-1
2 parents 48134f2 + 7490067 commit d61eafd

File tree

3 files changed

+68
-54
lines changed

3 files changed

+68
-54
lines changed

proto

src/grpc.rs

Lines changed: 22 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::{
44
net::SocketAddr,
55
sync::{
66
atomic::{AtomicBool, AtomicU64, Ordering},
7-
Arc, Mutex,
7+
Arc, Mutex, RwLock,
88
},
99
};
1010

@@ -31,7 +31,6 @@ use crate::{
3131

3232
// connected clients
3333
type ClientMap = HashMap<SocketAddr, mpsc::UnboundedSender<Result<CoreRequest, Status>>>;
34-
static COOKIE_KEY_HEADER: &str = "dg-cookie-key-bin";
3534

3635
#[derive(Debug, Clone, Default)]
3736
pub struct Configuration {
@@ -43,18 +42,18 @@ pub(crate) struct ProxyServer {
4342
current_id: Arc<AtomicU64>,
4443
clients: Arc<Mutex<ClientMap>>,
4544
results: Arc<Mutex<HashMap<u64, oneshot::Sender<core_response::Payload>>>>,
46-
http_channel: mpsc::UnboundedSender<Key>,
4745
pub(crate) connected: Arc<AtomicBool>,
4846
pub(crate) core_version: Arc<Mutex<Option<Version>>>,
4947
config: Arc<Mutex<Option<Configuration>>>,
48+
cookie_key: Arc<RwLock<Option<Key>>>,
5049
}
5150

5251
impl ProxyServer {
5352
#[must_use]
5453
/// Create new `ProxyServer`.
55-
pub(crate) fn new(http_channel: mpsc::UnboundedSender<Key>) -> Self {
54+
pub(crate) fn new(cookie_key: Arc<RwLock<Option<Key>>>) -> Self {
5655
Self {
57-
http_channel,
56+
cookie_key,
5857
current_id: Arc::new(AtomicU64::new(1)),
5958
clients: Arc::new(Mutex::new(HashMap::new())),
6059
results: Arc::new(Mutex::new(HashMap::new())),
@@ -175,7 +174,7 @@ impl Clone for ProxyServer {
175174
results: Arc::clone(&self.results),
176175
connected: Arc::clone(&self.connected),
177176
core_version: Arc::clone(&self.core_version),
178-
http_channel: self.http_channel.clone(),
177+
cookie_key: Arc::clone(&self.cookie_key),
179178
config: Arc::clone(&self.config),
180179
}
181180
}
@@ -213,31 +212,6 @@ impl proxy_server::Proxy for ProxyServer {
213212
let _guard = span.enter();
214213

215214
info!("Defguard Core gRPC client connected from: {address}");
216-
217-
// Retrieve private cookies key from the header.
218-
let cookie_key = request.metadata().get_bin(COOKIE_KEY_HEADER);
219-
let key = match cookie_key {
220-
Some(key) => Key::from(&key.to_bytes().map_err(|err| {
221-
error!("Failed to decode private cookie key: {err:?}");
222-
Status::internal("Failed to decode private cookie key")
223-
})?),
224-
// If the header is missing, fall back to generating a local key.
225-
// This preserves compatibility with older Core versions that did not
226-
// provide a shared cookie key. In this mode, cookie-based sessions will
227-
// not be shared across proxy instances and HA won't work.
228-
None => {
229-
warn!(
230-
"Private cookie key not provided by Core; falling back to a locally generated key. \
231-
This typically indicates an older Core version and disables cookie sharing across proxies."
232-
);
233-
Key::generate()
234-
}
235-
};
236-
self.http_channel.send(key).map_err(|err| {
237-
error!("Failed to send private cookies key to HTTP server: {err:?}");
238-
Status::internal("Failed to send private cookies key to HTTP server")
239-
})?;
240-
241215
let (tx, rx) = mpsc::unbounded_channel();
242216
self.clients
243217
.lock()
@@ -250,22 +224,32 @@ impl proxy_server::Proxy for ProxyServer {
250224
let clients = Arc::clone(&self.clients);
251225
let results = Arc::clone(&self.results);
252226
let connected = Arc::clone(&self.connected);
253-
let mut stream = request.into_inner();
227+
let cookie_key = Arc::clone(&self.cookie_key);
254228
tokio::spawn(
255229
async move {
230+
let mut stream = request.into_inner();
256231
loop {
257232
match stream.message().await {
258233
Ok(Some(response)) => {
259234
debug!("Received message from Defguard Core ID={}", response.id);
260235
connected.store(true, Ordering::Relaxed);
261236
if let Some(payload) = response.payload {
262-
let maybe_rx = results.lock().expect("Failed to acquire lock on results hashmap when processing response").remove(&response.id);
263-
if let Some(rx) = maybe_rx {
264-
if let Err(err) = rx.send(payload) {
265-
error!("Failed to send message to rx {:?}", err.type_id());
237+
match payload {
238+
core_response::Payload::InitialInfo(payload) => {
239+
info!("Received private cookies key");
240+
let key = Key::from(&payload.private_cookies_key);
241+
*cookie_key.write().unwrap() = Some(key);
242+
},
243+
_ => {
244+
let maybe_rx = results.lock().expect("Failed to acquire lock on results hashmap when processing response").remove(&response.id);
245+
if let Some(rx) = maybe_rx {
246+
if let Err(err) = rx.send(payload) {
247+
error!("Failed to send message to rx {:?}", err.type_id());
248+
}
249+
} else {
250+
error!("Missing receiver for response #{}", response.id);
251+
}
266252
}
267-
} else {
268-
error!("Missing receiver for response #{}", response.id);
269253
}
270254
}
271255
}

src/http.rs

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::{
22
collections::HashMap,
33
net::{IpAddr, Ipv4Addr, SocketAddr},
44
path::Path,
5-
sync::{atomic::Ordering, Arc},
5+
sync::{atomic::Ordering, Arc, LazyLock, RwLock},
66
time::Duration,
77
};
88

@@ -12,6 +12,7 @@ use axum::{
1212
extract::{ConnectInfo, FromRef, State},
1313
http::{header::HeaderValue, Request, Response, StatusCode},
1414
middleware::{self, Next},
15+
response::IntoResponse,
1516
routing::{get, post},
1617
serve, Json, Router,
1718
};
@@ -21,7 +22,7 @@ use defguard_version::{server::DefguardVersionLayer, Version};
2122
use serde::Serialize;
2223
use tokio::{
2324
net::TcpListener,
24-
sync::{mpsc, oneshot},
25+
sync::{mpsc, oneshot, Mutex},
2526
task::JoinSet,
2627
};
2728
use tower_governor::{
@@ -57,7 +58,7 @@ pub(crate) struct AppState {
5758
pub(crate) grpc_server: ProxyServer,
5859
pub(crate) remote_mfa_sessions:
5960
Arc<tokio::sync::Mutex<HashMap<String, oneshot::Sender<String>>>>,
60-
key: Key,
61+
cookie_key: Arc<RwLock<Option<Key>>>,
6162
url: Url,
6263
}
6364

@@ -79,7 +80,10 @@ impl AppState {
7980

8081
impl FromRef<AppState> for Key {
8182
fn from_ref(state: &AppState) -> Self {
82-
state.key.clone()
83+
let maybe_key = state.cookie_key.read().unwrap().clone();
84+
// We return the dummy key only to satisfy the `FromRef` trait, but it is never
85+
// used in practice because of the `ensure_configured` middleware.
86+
maybe_key.unwrap_or_else(|| Key::from(&[0; 64]))
8387
}
8488
}
8589

@@ -206,18 +210,45 @@ pub async fn run_setup(
206210
Ok(configuration)
207211
}
208212

213+
/// Middleware that gates all HTTP endpoints except health checks until the proxy
214+
/// is fully configured.
215+
///
216+
/// The proxy cannot safely handle requests that rely on encrypted cookies
217+
/// (e.g. OpenID / MFA flows) until it receives the cookie encryption key from
218+
/// the core. This key is provided asynchronously after the core connects.
219+
///
220+
/// Until the key is available, only health check endpoints are served and all
221+
/// other requests return HTTP 503 (Service Unavailable). Once the key is set,
222+
/// the middleware becomes a no-op and all routes are enabled.
223+
async fn ensure_configured(
224+
State(state): State<AppState>,
225+
request: Request<Body>,
226+
next: Next,
227+
) -> Response<Body> {
228+
// Allow healthchecks even before core connects and gives us the cookie key.
229+
let path = request.uri().path();
230+
if matches!(path, "/api/v1/health" | "/api/v1/health-grpc") {
231+
return next.run(request).await;
232+
}
233+
234+
// Block all other requests until cookie key is configured.
235+
if state.cookie_key.read().unwrap().is_none() {
236+
return StatusCode::SERVICE_UNAVAILABLE.into_response();
237+
}
238+
239+
next.run(request).await
240+
}
241+
209242
pub async fn run_server(env_config: EnvConfig, config: Configuration) -> anyhow::Result<()> {
210243
info!("Starting Defguard Proxy server");
211244
debug!("Using config: {env_config:?}");
212245

213246
let mut tasks = JoinSet::new();
214-
215-
// Prepare the channel for gRPC -> http server communication.
216-
// The channel sends private cookies key once core connects to gRPC.
217-
let (tx, mut rx) = mpsc::unbounded_channel::<Key>();
247+
let cookie_key = Default::default();
218248

219249
// connect to upstream gRPC server
220-
let grpc_server = ProxyServer::new(tx);
250+
let grpc_server = ProxyServer::new(Arc::clone(&cookie_key));
251+
221252
let server_clone = grpc_server.clone();
222253
grpc_server.configure(config);
223254

@@ -241,15 +272,10 @@ pub async fn run_server(env_config: EnvConfig, config: Configuration) -> anyhow:
241272
}
242273
});
243274

244-
// Wait for core to connect to gRPC and send private cookies encryption key.
245-
let Some(key) = rx.recv().await else {
246-
return Err(anyhow::Error::msg("http channel closed"));
247-
};
248-
249275
// build application
250276
debug!("Setting up API server");
251277
let shared_state = AppState {
252-
key,
278+
cookie_key,
253279
grpc_server,
254280
remote_mfa_sessions: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
255281
url: env_config.url.clone(),
@@ -309,6 +335,10 @@ pub async fn run_server(env_config: EnvConfig, config: Configuration) -> anyhow:
309335
.route("/info", get(app_info)),
310336
)
311337
.fallback_service(get(handle_404))
338+
.layer(middleware::from_fn_with_state(
339+
shared_state.clone(),
340+
ensure_configured,
341+
))
312342
.layer(middleware::map_response(powered_by_header))
313343
.layer(middleware::from_fn_with_state(
314344
shared_state.clone(),

0 commit comments

Comments
 (0)