diff --git a/src/grpc.rs b/src/grpc.rs index 3dd2c21..5c20ae2 100644 --- a/src/grpc.rs +++ b/src/grpc.rs @@ -8,6 +8,7 @@ use std::{ }, }; +use axum_extra::extract::cookie::Key; use defguard_version::{ get_tracing_variables, server::{grpc::DefguardVersionInterceptor, DefguardVersionLayer}, @@ -31,6 +32,7 @@ use crate::{ // connected clients type ClientMap = HashMap>>; +static COOKIE_KEY_HEADER: &str = "dg-cookie-key-bin"; #[derive(Debug, Clone, Default)] pub(crate) struct Configuration { @@ -42,6 +44,7 @@ pub(crate) struct ProxyServer { current_id: Arc, clients: Arc>, results: Arc>>>, + http_channel: mpsc::UnboundedSender, pub(crate) connected: Arc, pub(crate) core_version: Arc>>, config: Arc>>, @@ -51,8 +54,9 @@ pub(crate) struct ProxyServer { impl ProxyServer { #[must_use] /// Create new `ProxyServer`. - pub(crate) fn new() -> Self { + pub(crate) fn new(http_channel: mpsc::UnboundedSender) -> Self { Self { + http_channel, current_id: Arc::new(AtomicU64::new(1)), clients: Arc::new(Mutex::new(HashMap::new())), results: Arc::new(Mutex::new(HashMap::new())), @@ -189,6 +193,7 @@ impl Clone for ProxyServer { results: Arc::clone(&self.results), connected: Arc::clone(&self.connected), core_version: Arc::clone(&self.core_version), + http_channel: self.http_channel.clone(), config: Arc::clone(&self.config), setup_in_progress: Arc::clone(&self.setup_in_progress), } @@ -228,6 +233,30 @@ impl proxy_server::Proxy for ProxyServer { info!("Defguard Core gRPC client connected from: {address}"); + // Retrieve private cookies key from the header. + let cookie_key = request.metadata().get_bin(COOKIE_KEY_HEADER); + let key = match cookie_key { + Some(key) => Key::from(&key.to_bytes().map_err(|err| { + error!("Failed to decode private cookie key: {err:?}"); + Status::internal("Failed to decode private cookie key") + })?), + // If the header is missing, fall back to generating a local key. + // This preserves compatibility with older Core versions that did not + // provide a shared cookie key. In this mode, cookie-based sessions will + // not be shared across proxy instances and HA won't work. + None => { + warn!( + "Private cookie key not provided by Core; falling back to a locally generated key. \ + This typically indicates an older Core version and disables cookie sharing across proxies." + ); + Key::generate() + } + }; + self.http_channel.send(key).map_err(|err| { + error!("Failed to send private cookies key to HTTP server: {err:?}"); + Status::internal("Failed to send private cookies key to HTTP server") + })?; + let (tx, rx) = mpsc::unbounded_channel(); self.clients .lock() diff --git a/src/http.rs b/src/http.rs index f18d52d..56ee352 100644 --- a/src/http.rs +++ b/src/http.rs @@ -22,7 +22,7 @@ use defguard_version::{server::DefguardVersionLayer, Version}; use serde::Serialize; use tokio::{ net::TcpListener, - sync::{oneshot, Mutex}, + sync::{mpsc, oneshot, Mutex}, task::JoinSet, }; use tower_governor::{ @@ -178,18 +178,13 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { debug!("Using config: {config:?}"); let mut tasks = JoinSet::new(); - // connect to upstream gRPC server - let grpc_server = ProxyServer::new(); - // build application - debug!("Setting up API server"); - let shared_state = AppState { - grpc_server: grpc_server.clone(), - remote_mfa_sessions: Arc::new(tokio::sync::Mutex::new(HashMap::new())), - // Generate secret key for encrypting cookies. - key: Key::generate(), - url: config.url.clone(), - }; + // Prepare the channel for gRPC -> http server communication. + // The channel sends private cookies key once core connects to gRPC. + let (tx, mut rx) = mpsc::unbounded_channel::(); + + // connect to upstream gRPC server + let grpc_server = ProxyServer::new(tx); let server_clone = grpc_server.clone(); @@ -261,6 +256,20 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { } }); + // Wait for core to connect to gRPC and send private cookies encryption key. + let Some(key) = rx.recv().await else { + return Err(anyhow::Error::msg("http channel closed")); + }; + + // build application + debug!("Setting up API server"); + let shared_state = AppState { + key, + grpc_server: grpc_server, + remote_mfa_sessions: Arc::new(tokio::sync::Mutex::new(HashMap::new())), + url: config.url.clone(), + }; + // Setup tower_governor rate-limiter debug!( "Configuring rate limiter, per_second: {}, burst: {}",