From 6e05e36b7ad8b07f6b4bba5f04582ed961c79ba8 Mon Sep 17 00:00:00 2001 From: Thom Wright Date: Fri, 26 Sep 2025 17:50:45 +0100 Subject: [PATCH] Cache PG salted password and client key --- sqlx-postgres/src/connection/mod.rs | 2 + sqlx-postgres/src/connection/sasl.rs | 133 +++++++++++++++++++++++++-- sqlx-postgres/src/message/mod.rs | 2 +- sqlx-postgres/src/options/mod.rs | 9 +- 4 files changed, 133 insertions(+), 13 deletions(-) diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 4e05cd867b..e636270e32 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -23,6 +23,8 @@ use sqlx_core::sql_str::SqlSafeStr; pub use self::stream::PgStream; +pub use sasl::ClientKeyCache; + pub(crate) mod describe; mod establish; mod executor; diff --git a/sqlx-postgres/src/connection/sasl.rs b/sqlx-postgres/src/connection/sasl.rs index 94fdfc689f..5190eb1ef3 100644 --- a/sqlx-postgres/src/connection/sasl.rs +++ b/sqlx-postgres/src/connection/sasl.rs @@ -1,6 +1,12 @@ +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; + use crate::connection::stream::PgStream; use crate::error::Error; -use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse}; +use crate::message::{ + Authentication, AuthenticationSasl, AuthenticationSaslContinue, SaslInitialResponse, + SaslResponse, +}; use crate::rt; use crate::PgConnectOptions; use hmac::{Hmac, Mac}; @@ -16,6 +22,100 @@ const USERNAME_ATTR: &str = "n"; const CLIENT_PROOF_ATTR: &str = "p"; const NONCE_ATTR: &str = "r"; +/// A single-entry cache for the client key derived from the password. +/// +/// Salting the password and deriving the client key can be expensive, so this cache stores the most +/// recently used client key along with the parameters used to derive it. +/// +/// According to [RFC-7677](https://datatracker.ietf.org/doc/html/rfc7677): +/// +/// > This computational cost can be avoided by caching the ClientKey (assuming the Salt and hash +/// > iteration-count is stable). +#[derive(Debug, Clone)] +pub struct ClientKeyCache { + inner: Arc>>, +} + +#[derive(Debug, PartialEq, Eq)] +struct CacheKey { + host: String, + port: u16, + socket: Option, + database: Option, + username: String, + password: String, + salt: Vec, + iterations: u32, +} + +impl From<(&PgConnectOptions, &AuthenticationSaslContinue)> for CacheKey { + fn from((options, cont): (&PgConnectOptions, &AuthenticationSaslContinue)) -> Self { + CacheKey { + host: options.host.clone(), + port: options.port, + socket: options.socket.clone(), + database: options.database.clone(), + username: options.username.clone(), + password: options.password.clone().unwrap_or_default(), + salt: cont.salt.clone(), + iterations: cont.iterations, + } + } +} + +#[derive(Debug)] +struct CacheInner { + cache_key: CacheKey, + salted_password: [u8; 32], + client_key: Hmac, +} + +impl ClientKeyCache { + pub fn new() -> Self { + ClientKeyCache { + inner: Arc::new(Mutex::new(None)), + } + } + + fn get( + &self, + options: &PgConnectOptions, + cont: &AuthenticationSaslContinue, + ) -> Option<([u8; 32], Hmac)> { + let key = CacheKey::from((options, cont)); + + self.inner + .lock() + .expect("BUG: panicked while holding a lock") + .as_ref() + .and_then(|inner| { + if inner.cache_key == key { + Some((inner.salted_password, inner.client_key.clone())) + } else { + None + } + }) + } + + fn set( + &self, + options: &PgConnectOptions, + cont: &AuthenticationSaslContinue, + salted_password: [u8; 32], + client_key: Hmac, + ) { + let mut inner = self + .inner + .lock() + .expect("BUG: panicked while holding a lock"); + *inner = Some(CacheInner { + cache_key: CacheKey::from((options, cont)), + salted_password, + client_key, + }); + } +} + pub(crate) async fn authenticate( stream: &mut PgStream, options: &PgConnectOptions, @@ -86,16 +186,29 @@ pub(crate) async fn authenticate( } }; - // SaltedPassword := Hi(Normalize(password), salt, i) - let salted_password = hi( - options.password.as_deref().unwrap_or_default(), - &cont.salt, - cont.iterations, - ) - .await?; + let (salted_password, mut mac) = { + if let Some(cached) = options.sasl_client_key_cache.get(options, &cont) { + cached + } else { + // SaltedPassword := Hi(Normalize(password), salt, i) + let salted_password = hi( + options.password.as_deref().unwrap_or_default(), + &cont.salt, + cont.iterations, + ) + .await?; + + // ClientKey := HMAC(SaltedPassword, "Client Key") + let mac = Hmac::::new_from_slice(&salted_password).map_err(Error::protocol)?; + + options + .sasl_client_key_cache + .set(options, &cont, salted_password, mac.clone()); + + (salted_password, mac) + } + }; - // ClientKey := HMAC(SaltedPassword, "Client Key") - let mut mac = Hmac::::new_from_slice(&salted_password).map_err(Error::protocol)?; mac.update(b"Client Key"); let client_key = mac.finalize().into_bytes(); diff --git a/sqlx-postgres/src/message/mod.rs b/sqlx-postgres/src/message/mod.rs index e62f9bebb3..e7648f4419 100644 --- a/sqlx-postgres/src/message/mod.rs +++ b/sqlx-postgres/src/message/mod.rs @@ -30,7 +30,7 @@ mod startup; mod sync; mod terminate; -pub use authentication::{Authentication, AuthenticationSasl}; +pub use authentication::{Authentication, AuthenticationSasl, AuthenticationSaslContinue}; pub use backend_key_data::BackendKeyData; pub use bind::Bind; pub use close::Close; diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs index efbc43989b..765b0159b9 100644 --- a/sqlx-postgres/src/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -5,7 +5,10 @@ use std::path::{Path, PathBuf}; pub use ssl_mode::PgSslMode; -use crate::{connection::LogSettings, net::tls::CertificateInput}; +use crate::{ + connection::{ClientKeyCache, LogSettings}, + net::tls::CertificateInput, +}; mod connect; mod parse; @@ -30,6 +33,7 @@ pub struct PgConnectOptions { pub(crate) log_settings: LogSettings, pub(crate) extra_float_digits: Option>, pub(crate) options: Option, + pub(crate) sasl_client_key_cache: ClientKeyCache, } impl Default for PgConnectOptions { @@ -90,6 +94,7 @@ impl PgConnectOptions { extra_float_digits: Some("2".into()), log_settings: Default::default(), options: var("PGOPTIONS").ok(), + sasl_client_key_cache: ClientKeyCache::new(), } } @@ -267,7 +272,7 @@ impl PgConnectOptions { /// -----BEGIN CERTIFICATE----- /// /// -----END CERTIFICATE-----"; - /// + /// /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa)