From bd0905b08e70c377e217ab0f4eb6804e26d25410 Mon Sep 17 00:00:00 2001 From: benthecarman Date: Wed, 17 Dec 2025 23:16:55 -0600 Subject: [PATCH 1/2] Add HMAC-based authentication for RPC/CLI Implements time-based HMAC-SHA256 authentication using a shared API key. Each request includes a timestamp and HMAC in the X-Auth header, preventing replay attacks with a 60-second tolerance window. Co-Authored-By: Claude Opus 4.5 --- Cargo.lock | 2 + README.md | 4 +- ldk-server-cli/src/main.rs | 5 +- ldk-server-client/Cargo.toml | 1 + ldk-server-client/src/client.rs | 28 +- ldk-server/Cargo.toml | 1 + ldk-server/ldk-server-config.toml | 4 + ldk-server/src/api/error.rs | 4 + ldk-server/src/main.rs | 2 +- ldk-server/src/service.rs | 388 +++++++++++++++++++++++---- ldk-server/src/util/config.rs | 65 ++++- ldk-server/src/util/proto_adapter.rs | 4 +- 12 files changed, 439 insertions(+), 69 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ba02fb5..69dbd04 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1717,6 +1717,7 @@ name = "ldk-server" version = "0.1.0" dependencies = [ "async-trait", + "base64 0.21.7", "bytes", "chrono", "futures-util", @@ -1751,6 +1752,7 @@ dependencies = [ name = "ldk-server-client" version = "0.1.0" dependencies = [ + "bitcoin_hashes 0.14.0", "ldk-server-protos", "prost", "reqwest 0.11.27", diff --git a/README.md b/README.md index 8ea7071..e3358c4 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,6 @@ cargo run --bin ldk-server ./ldk-server/ldk-server-config.toml Interact with the node using CLI: ``` -./target/debug/ldk-server-cli -b localhost:3002 onchain-receive # To generate onchain-receive address. -./target/debug/ldk-server-cli -b localhost:3002 help # To print help/available commands. +./target/debug/ldk-server-cli -b localhost:3002 --api-key your-secret-api-key onchain-receive # To generate onchain-receive address. +./target/debug/ldk-server-cli -b localhost:3002 --api-key your-secret-api-key help # To print help/available commands. ``` diff --git a/ldk-server-cli/src/main.rs b/ldk-server-cli/src/main.rs index 6cbfcd8..bab6408 100644 --- a/ldk-server-cli/src/main.rs +++ b/ldk-server-cli/src/main.rs @@ -46,6 +46,9 @@ struct Cli { #[arg(short, long, default_value = "localhost:3000")] base_url: String, + #[arg(short, long)] + api_key: String, + #[command(subcommand)] command: Commands, } @@ -214,7 +217,7 @@ enum Commands { #[tokio::main] async fn main() { let cli = Cli::parse(); - let client = LdkServerClient::new(cli.base_url); + let client = LdkServerClient::new(cli.base_url, cli.api_key); match cli.command { Commands::GetNodeInfo => { diff --git a/ldk-server-client/Cargo.toml b/ldk-server-client/Cargo.toml index ca0ffad..13916fa 100644 --- a/ldk-server-client/Cargo.toml +++ b/ldk-server-client/Cargo.toml @@ -11,3 +11,4 @@ serde = ["ldk-server-protos/serde"] ldk-server-protos = { path = "../ldk-server-protos" } reqwest = { version = "0.11.13", default-features = false, features = ["rustls-tls"] } prost = { version = "0.11.6", default-features = false, features = ["std", "prost-derive"] } +bitcoin_hashes = "0.14" diff --git a/ldk-server-client/src/client.rs b/ldk-server-client/src/client.rs index 9983151..3c76060 100644 --- a/ldk-server-client/src/client.rs +++ b/ldk-server-client/src/client.rs @@ -13,6 +13,8 @@ use crate::error::LdkServerError; use crate::error::LdkServerErrorCode::{ AuthError, InternalError, InternalServerError, InvalidRequestError, LightningError, }; +use bitcoin_hashes::hmac::{Hmac, HmacEngine}; +use bitcoin_hashes::{sha256, Hash, HashEngine}; use ldk_server_protos::api::{ Bolt11ReceiveRequest, Bolt11ReceiveResponse, Bolt11SendRequest, Bolt11SendResponse, Bolt12ReceiveRequest, Bolt12ReceiveResponse, Bolt12SendRequest, Bolt12SendResponse, @@ -32,6 +34,7 @@ use ldk_server_protos::endpoints::{ use ldk_server_protos::error::{ErrorCode, ErrorResponse}; use reqwest::header::CONTENT_TYPE; use reqwest::Client; +use std::time::{SystemTime, UNIX_EPOCH}; const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; @@ -40,12 +43,31 @@ const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; pub struct LdkServerClient { base_url: String, client: Client, + api_key: String, } impl LdkServerClient { /// Constructs a [`LdkServerClient`] using `base_url` as the ldk-server endpoint. - pub fn new(base_url: String) -> Self { - Self { base_url, client: Client::new() } + /// `api_key` is used for HMAC-based authentication. + pub fn new(base_url: String, api_key: String) -> Self { + Self { base_url, client: Client::new(), api_key } + } + + /// Computes the HMAC-SHA256 authentication header value. + /// Format: "HMAC :" + fn compute_auth_header(&self, body: &[u8]) -> String { + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("System time should be after Unix epoch") + .as_secs(); + + // Compute HMAC-SHA256(api_key, timestamp_bytes || body) + let mut hmac_engine: HmacEngine = HmacEngine::new(self.api_key.as_bytes()); + hmac_engine.input(×tamp.to_be_bytes()); + hmac_engine.input(body); + let hmac_result = Hmac::::from_engine(hmac_engine); + + format!("HMAC {}:{}", timestamp, hmac_result) } /// Retrieve the latest node info like `node_id`, `current_best_block` etc. @@ -196,10 +218,12 @@ impl LdkServerClient { &self, request: &Rq, url: &str, ) -> Result { let request_body = request.encode_to_vec(); + let auth_header = self.compute_auth_header(&request_body); let response_raw = self .client .post(url) .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) + .header("X-Auth", auth_header) .body(request_body) .send() .await diff --git a/ldk-server/Cargo.toml b/ldk-server/Cargo.toml index e1053f7..62f82d3 100644 --- a/ldk-server/Cargo.toml +++ b/ldk-server/Cargo.toml @@ -20,6 +20,7 @@ async-trait = { version = "0.1.85", default-features = false } toml = { version = "0.8.9", default-features = false, features = ["parse"] } chrono = { version = "0.4", default-features = false, features = ["clock"] } log = "0.4.28" +base64 = { version = "0.21", default-features = false, features = ["std"] } # Required for RabittMQ based EventPublisher. Only enabled for `events-rabbitmq` feature. lapin = { version = "2.4.0", features = ["rustls"], default-features = false, optional = true } diff --git a/ldk-server/ldk-server-config.toml b/ldk-server/ldk-server-config.toml index e4343b6..1f3d651 100644 --- a/ldk-server/ldk-server-config.toml +++ b/ldk-server/ldk-server-config.toml @@ -12,6 +12,10 @@ dir_path = "/tmp/ldk-server/" # Path for LDK and BDK data persis level = "Debug" # Log level (Error, Warn, Info, Debug, Trace) file_path = "/tmp/ldk-server/ldk-server.log" # Log file path +# HMAC Authentication (REQUIRED) +[auth] +api_key = "your-secret-api-key" + # Must set either bitcoind or esplora settings, but not both # Bitcoin Core settings diff --git a/ldk-server/src/api/error.rs b/ldk-server/src/api/error.rs index cacb0f0..15a5bca 100644 --- a/ldk-server/src/api/error.rs +++ b/ldk-server/src/api/error.rs @@ -43,6 +43,9 @@ pub(crate) enum LdkServerErrorCode { /// Please refer to [`protos::error::ErrorCode::InvalidRequestError`]. InvalidRequestError, + /// Please refer to [`protos::error::ErrorCode::AuthError`]. + AuthError, + /// Please refer to [`protos::error::ErrorCode::LightningError`]. LightningError, @@ -54,6 +57,7 @@ impl fmt::Display for LdkServerErrorCode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { LdkServerErrorCode::InvalidRequestError => write!(f, "InvalidRequestError"), + LdkServerErrorCode::AuthError => write!(f, "AuthError"), LdkServerErrorCode::LightningError => write!(f, "LightningError"), LdkServerErrorCode::InternalServerError => write!(f, "InternalServerError"), } diff --git a/ldk-server/src/main.rs b/ldk-server/src/main.rs index f5c86ae..f635994 100644 --- a/ldk-server/src/main.rs +++ b/ldk-server/src/main.rs @@ -356,7 +356,7 @@ fn main() { match res { Ok((stream, _)) => { let io_stream = TokioIo::new(stream); - let node_service = NodeService::new(Arc::clone(&node), Arc::clone(&paginated_store)); + let node_service = NodeService::new(Arc::clone(&node), Arc::clone(&paginated_store), config_file.auth_config.clone()); runtime.spawn(async move { if let Err(err) = http1::Builder::new().serve_connection(io_stream, node_service).await { error!("Failed to serve connection: {}", err); diff --git a/ldk-server/src/service.rs b/ldk-server/src/service.rs index 6048a10..3c089a6 100644 --- a/ldk-server/src/service.rs +++ b/ldk-server/src/service.rs @@ -7,6 +7,8 @@ // You may not use this file except in accordance with one or both of these // licenses. +use ldk_node::bitcoin::hashes::hmac::{Hmac, HmacEngine}; +use ldk_node::bitcoin::hashes::{sha256, Hash, HashEngine}; use ldk_node::Node; use http_body_util::{BodyExt, Full, Limited}; @@ -30,7 +32,7 @@ use crate::api::bolt12_receive::handle_bolt12_receive_request; use crate::api::bolt12_send::handle_bolt12_send_request; use crate::api::close_channel::{handle_close_channel_request, handle_force_close_channel_request}; use crate::api::error::LdkServerError; -use crate::api::error::LdkServerErrorCode::InvalidRequestError; +use crate::api::error::LdkServerErrorCode::{AuthError, InvalidRequestError}; use crate::api::get_balances::handle_get_balances_request; use crate::api::get_node_info::handle_get_node_info_request; use crate::api::get_payment_details::handle_get_payment_details_request; @@ -43,6 +45,7 @@ use crate::api::open_channel::handle_open_channel; use crate::api::splice_channel::{handle_splice_in_request, handle_splice_out_request}; use crate::api::update_channel_config::handle_update_channel_config_request; use crate::io::persist::paginated_kv_store::PaginatedKVStore; +use crate::util::config::AuthConfig; use crate::util::proto_adapter::to_error_response; use std::future::Future; use std::pin::Pin; @@ -56,14 +59,80 @@ const MAX_BODY_SIZE: usize = 10 * 1024 * 1024; pub struct NodeService { node: Arc, paginated_kv_store: Arc, + auth_config: AuthConfig, } impl NodeService { - pub(crate) fn new(node: Arc, paginated_kv_store: Arc) -> Self { - Self { node, paginated_kv_store } + pub(crate) fn new( + node: Arc, paginated_kv_store: Arc, auth_config: AuthConfig, + ) -> Self { + Self { node, paginated_kv_store, auth_config } } } +// Maximum allowed time difference between client timestamp and server time (1 minute) +const AUTH_TIMESTAMP_TOLERANCE_SECS: u64 = 60; + +/// Extracts authentication parameters from request headers. +/// Returns (timestamp, hmac_hex) if valid format, or error. +fn extract_auth_params(req: &Request) -> Result<(u64, String), LdkServerError> { + let auth_header = req + .headers() + .get("X-Auth") + .and_then(|v| v.to_str().ok()) + .ok_or_else(|| LdkServerError::new(AuthError, "Missing X-Auth header"))?; + + // Format: "HMAC :" + let auth_data = auth_header + .strip_prefix("HMAC ") + .ok_or_else(|| LdkServerError::new(AuthError, "Invalid X-Auth header format"))?; + + let (timestamp_str, hmac_hex) = auth_data + .split_once(':') + .ok_or_else(|| LdkServerError::new(AuthError, "Invalid X-Auth header format"))?; + + let timestamp = timestamp_str + .parse::() + .map_err(|_| LdkServerError::new(AuthError, "Invalid timestamp in X-Auth header"))?; + + // validate hmac_hex is valid hex + if hmac_hex.len() != 64 || !hmac_hex.chars().all(|c| c.is_ascii_hexdigit()) { + return Err(LdkServerError::new(AuthError, "Invalid HMAC in X-Auth header")); + } + + Ok((timestamp, hmac_hex.to_string())) +} + +/// Validates the HMAC authentication after the request body has been read. +fn validate_hmac_auth( + timestamp: u64, provided_hmac_hex: &str, body: &[u8], auth_config: &AuthConfig, +) -> Result<(), LdkServerError> { + // Validate timestamp is within acceptable window + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_err(|_| LdkServerError::new(AuthError, "System time error"))? + .as_secs(); + + let time_diff = now.abs_diff(timestamp); + if time_diff > AUTH_TIMESTAMP_TOLERANCE_SECS { + return Err(LdkServerError::new(AuthError, "Request timestamp expired")); + } + + // Compute expected HMAC: HMAC-SHA256(api_key, timestamp_bytes || body) + let mut hmac_engine: HmacEngine = HmacEngine::new(auth_config.api_key.as_bytes()); + hmac_engine.input(×tamp.to_be_bytes()); + hmac_engine.input(body); + let expected_hmac = Hmac::::from_engine(hmac_engine); + + // Compare HMACs (constant-time comparison via Hash equality) + let expected_hex = expected_hmac.to_string(); + if expected_hex != provided_hmac_hex { + return Err(LdkServerError::new(AuthError, "Invalid credentials")); + } + + Ok(()) +} + pub(crate) struct Context { pub(crate) node: Arc, pub(crate) paginated_kv_store: Arc, @@ -75,56 +144,155 @@ impl Service> for NodeService { type Future = Pin> + Send>>; fn call(&self, req: Request) -> Self::Future { + // Extract auth params from headers (validation happens after body is read) + let auth_params = match extract_auth_params(&req) { + Ok(params) => params, + Err(e) => { + let (error_response, status_code) = to_error_response(e); + return Box::pin(async move { + Ok(Response::builder() + .status(status_code) + .body(Full::new(Bytes::from(error_response.encode_to_vec()))) + // unwrap safety: body only errors when previous chained calls failed. + .unwrap()) + }); + }, + }; + let context = Context { node: Arc::clone(&self.node), paginated_kv_store: Arc::clone(&self.paginated_kv_store), }; + let auth_config = self.auth_config.clone(); + // Exclude '/' from path pattern matching. match &req.uri().path()[1..] { - GET_NODE_INFO_PATH => { - Box::pin(handle_request(context, req, handle_get_node_info_request)) - }, - GET_BALANCES_PATH => { - Box::pin(handle_request(context, req, handle_get_balances_request)) - }, - ONCHAIN_RECEIVE_PATH => { - Box::pin(handle_request(context, req, handle_onchain_receive_request)) - }, - ONCHAIN_SEND_PATH => { - Box::pin(handle_request(context, req, handle_onchain_send_request)) - }, - BOLT11_RECEIVE_PATH => { - Box::pin(handle_request(context, req, handle_bolt11_receive_request)) - }, - BOLT11_SEND_PATH => Box::pin(handle_request(context, req, handle_bolt11_send_request)), - BOLT12_RECEIVE_PATH => { - Box::pin(handle_request(context, req, handle_bolt12_receive_request)) - }, - BOLT12_SEND_PATH => Box::pin(handle_request(context, req, handle_bolt12_send_request)), - OPEN_CHANNEL_PATH => Box::pin(handle_request(context, req, handle_open_channel)), - SPLICE_IN_PATH => Box::pin(handle_request(context, req, handle_splice_in_request)), - SPLICE_OUT_PATH => Box::pin(handle_request(context, req, handle_splice_out_request)), - CLOSE_CHANNEL_PATH => { - Box::pin(handle_request(context, req, handle_close_channel_request)) - }, - FORCE_CLOSE_CHANNEL_PATH => { - Box::pin(handle_request(context, req, handle_force_close_channel_request)) - }, - LIST_CHANNELS_PATH => { - Box::pin(handle_request(context, req, handle_list_channels_request)) - }, - UPDATE_CHANNEL_CONFIG_PATH => { - Box::pin(handle_request(context, req, handle_update_channel_config_request)) - }, - GET_PAYMENT_DETAILS_PATH => { - Box::pin(handle_request(context, req, handle_get_payment_details_request)) - }, - LIST_PAYMENTS_PATH => { - Box::pin(handle_request(context, req, handle_list_payments_request)) - }, - LIST_FORWARDED_PAYMENTS_PATH => { - Box::pin(handle_request(context, req, handle_list_forwarded_payments_request)) - }, + GET_NODE_INFO_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_get_node_info_request, + )), + GET_BALANCES_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_get_balances_request, + )), + ONCHAIN_RECEIVE_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_onchain_receive_request, + )), + ONCHAIN_SEND_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_onchain_send_request, + )), + BOLT11_RECEIVE_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_bolt11_receive_request, + )), + BOLT11_SEND_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_bolt11_send_request, + )), + BOLT12_RECEIVE_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_bolt12_receive_request, + )), + BOLT12_SEND_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_bolt12_send_request, + )), + OPEN_CHANNEL_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_open_channel, + )), + SPLICE_IN_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_splice_in_request, + )), + SPLICE_OUT_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_splice_out_request, + )), + CLOSE_CHANNEL_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_close_channel_request, + )), + FORCE_CLOSE_CHANNEL_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_force_close_channel_request, + )), + LIST_CHANNELS_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_list_channels_request, + )), + UPDATE_CHANNEL_CONFIG_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_update_channel_config_request, + )), + GET_PAYMENT_DETAILS_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_get_payment_details_request, + )), + LIST_PAYMENTS_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_list_payments_request, + )), + LIST_FORWARDED_PAYMENTS_PATH => Box::pin(handle_request( + context, + req, + auth_params, + auth_config, + handle_list_forwarded_payments_request, + )), path => { let error = format!("Unknown request: {}", path).into_bytes(); Box::pin(async { @@ -144,7 +312,8 @@ async fn handle_request< R: Message, F: Fn(Context, T) -> Result, >( - context: Context, request: Request, handler: F, + context: Context, request: Request, auth_params: (u64, String), + auth_config: AuthConfig, handler: F, ) -> Result<>>::Response, hyper::Error> { // Limit the size of the request body to prevent abuse let limited_body = Limited::new(request.into_body(), MAX_BODY_SIZE); @@ -163,6 +332,17 @@ async fn handle_request< }, }; + // Validate HMAC authentication with the request body + let (timestamp, provided_hmac) = auth_params; + if let Err(e) = validate_hmac_auth(timestamp, &provided_hmac, &bytes, &auth_config) { + let (error_response, status_code) = to_error_response(e); + return Ok(Response::builder() + .status(status_code) + .body(Full::new(Bytes::from(error_response.encode_to_vec()))) + // unwrap safety: body only errors when previous chained calls failed. + .unwrap()); + } + match T::decode(bytes) { Ok(request) => match handler(context, request) { Ok(response) => Ok(Response::builder() @@ -189,3 +369,115 @@ async fn handle_request< }, } } + +#[cfg(test)] +mod tests { + use super::*; + + fn compute_hmac(api_key: &str, timestamp: u64, body: &[u8]) -> String { + let mut hmac_engine: HmacEngine = HmacEngine::new(api_key.as_bytes()); + hmac_engine.input(×tamp.to_be_bytes()); + hmac_engine.input(body); + Hmac::::from_engine(hmac_engine).to_string() + } + + fn create_test_request(auth_header: Option) -> Request<()> { + let mut builder = Request::builder(); + if let Some(header) = auth_header { + builder = builder.header("X-Auth", header); + } + builder.body(()).unwrap() + } + + #[test] + fn test_extract_auth_params_success() { + let timestamp = + std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(); + let hmac = "8f5a33c2c68fb253899a588308fd13dcaf162d2788966a1fb6cc3aa2e0c51a93"; + let auth_header = format!("HMAC {timestamp}:{hmac}"); + + let req = create_test_request(Some(auth_header)); + + let result = extract_auth_params(&req); + assert!(result.is_ok()); + let (ts, hmac_hex) = result.unwrap(); + assert_eq!(ts, timestamp); + assert_eq!(hmac_hex, hmac); + } + + #[test] + fn test_extract_auth_params_missing_header() { + let req = create_test_request(None); + + let result = extract_auth_params(&req); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().error_code, AuthError); + } + + #[test] + fn test_extract_auth_params_invalid_format() { + // Missing "HMAC " prefix + let req = create_test_request(Some("12345:deadbeef".to_string())); + + let result = extract_auth_params(&req); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().error_code, AuthError); + } + + #[test] + fn test_validate_hmac_auth_success() { + let auth_config = AuthConfig { api_key: "test_api_key".to_string() }; + let body = b"test request body"; + let timestamp = + std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(); + let hmac = compute_hmac(&auth_config.api_key, timestamp, body); + + let result = validate_hmac_auth(timestamp, &hmac, body, &auth_config); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_hmac_auth_wrong_key() { + let auth_config = AuthConfig { api_key: "test_api_key".to_string() }; + let body = b"test request body"; + let timestamp = + std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(); + // Compute HMAC with wrong key + let hmac = compute_hmac("wrong_key", timestamp, body); + + let result = validate_hmac_auth(timestamp, &hmac, body, &auth_config); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().error_code, AuthError); + } + + #[test] + fn test_validate_hmac_auth_expired_timestamp() { + let auth_config = AuthConfig { api_key: "test_api_key".to_string() }; + let body = b"test request body"; + // Use a timestamp from 10 minutes ago + let timestamp = + std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() + - 600; + let hmac = compute_hmac(&auth_config.api_key, timestamp, body); + + let result = validate_hmac_auth(timestamp, &hmac, body, &auth_config); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().error_code, AuthError); + } + + #[test] + fn test_validate_hmac_auth_tampered_body() { + let auth_config = AuthConfig { api_key: "test_api_key".to_string() }; + let original_body = b"test request body"; + let tampered_body = b"tampered body"; + let timestamp = + std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(); + // Compute HMAC with original body + let hmac = compute_hmac(&auth_config.api_key, timestamp, original_body); + + // Try to validate with tampered body + let result = validate_hmac_auth(timestamp, &hmac, tampered_body, &auth_config); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().error_code, AuthError); + } +} diff --git a/ldk-server/src/util/config.rs b/ldk-server/src/util/config.rs index f09ca86..bb77619 100644 --- a/ldk-server/src/util/config.rs +++ b/ldk-server/src/util/config.rs @@ -24,6 +24,7 @@ pub struct Config { pub listening_addr: SocketAddress, pub alias: Option, pub network: Network, + pub auth_config: AuthConfig, pub rest_service_addr: SocketAddr, pub storage_dir_path: String, pub chain_source: ChainSource, @@ -34,6 +35,11 @@ pub struct Config { pub log_file_path: Option, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AuthConfig { + pub api_key: String, +} + #[derive(Debug)] pub enum ChainSource { Rpc { rpc_address: SocketAddr, rpc_user: String, rpc_password: String }, @@ -144,12 +150,23 @@ impl TryFrom for Config { ))? .into()); + let auth_config = toml_config + .auth + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "`auth` section with `api_key` is required in config file", + ) + }) + .map(|auth| AuthConfig { api_key: auth.api_key })?; + Ok(Config { listening_addr, network: toml_config.node.network, alias, rest_service_addr, storage_dir_path: toml_config.storage.disk.dir_path, + auth_config, chain_source, rabbitmq_connection_string, rabbitmq_exchange_name, @@ -171,6 +188,7 @@ pub struct TomlConfig { rabbitmq: Option, liquidity: Option, log: Option, + auth: Option, } #[derive(Deserialize, Serialize)] @@ -220,6 +238,11 @@ struct RabbitmqConfig { exchange_name: String, } +#[derive(Deserialize, Serialize)] +struct TomlAuthConfig { + api_key: String, +} + #[derive(Deserialize, Serialize)] struct LiquidityConfig { lsps2_service: Option, @@ -304,17 +327,20 @@ mod tests { listening_address = "localhost:3001" rest_service_address = "127.0.0.1:3002" alias = "LDK Server" - + [storage.disk] dir_path = "/tmp" [log] level = "Trace" file = "/var/log/ldk-server.log" - + + [auth] + api_key = "test_api_key" + [esplora] server_url = "https://mempool.space/api" - + [rabbitmq] connection_string = "rabbitmq_connection_string" exchange_name = "rabbitmq_exchange_name" @@ -344,6 +370,7 @@ mod tests { network: Network::Regtest, rest_service_addr: SocketAddr::from_str("127.0.0.1:3002").unwrap(), storage_dir_path: "/tmp".to_string(), + auth_config: AuthConfig { api_key: "test_api_key".to_string() }, chain_source: ChainSource::Esplora { server_url: String::from("https://mempool.space/api"), }, @@ -369,6 +396,7 @@ mod tests { assert_eq!(config.network, expected.network); assert_eq!(config.rest_service_addr, expected.rest_service_addr); assert_eq!(config.storage_dir_path, expected.storage_dir_path); + assert_eq!(config.auth_config, expected.auth_config); let ChainSource::Esplora { server_url } = config.chain_source else { panic!("unexpected config chain source"); }; @@ -389,21 +417,24 @@ mod tests { listening_address = "localhost:3001" rest_service_address = "127.0.0.1:3002" alias = "LDK Server" - + [storage.disk] dir_path = "/tmp" [log] level = "Trace" file = "/var/log/ldk-server.log" - + + [auth] + api_key = "test_api_key" + [electrum] server_url = "ssl://electrum.blockstream.info:50002" - + [rabbitmq] connection_string = "rabbitmq_connection_string" exchange_name = "rabbitmq_exchange_name" - + [liquidity.lsps2_service] advertise_service = false channel_opening_fee_ppm = 1000 # 0.1% fee @@ -433,19 +464,22 @@ mod tests { listening_address = "localhost:3001" rest_service_address = "127.0.0.1:3002" alias = "LDK Server" - + [storage.disk] dir_path = "/tmp" [log] level = "Trace" file = "/var/log/ldk-server.log" - + + [auth] + api_key = "test_api_key" + [bitcoind] rpc_address = "127.0.0.1:8332" # RPC endpoint rpc_user = "bitcoind-testuser" rpc_password = "bitcoind-testpassword" - + [rabbitmq] connection_string = "rabbitmq_connection_string" exchange_name = "rabbitmq_exchange_name" @@ -481,22 +515,25 @@ mod tests { listening_address = "localhost:3001" rest_service_address = "127.0.0.1:3002" alias = "LDK Server" - + [storage.disk] dir_path = "/tmp" [log] level = "Trace" file = "/var/log/ldk-server.log" - + + [auth] + api_key = "test_api_key" + [bitcoind] rpc_address = "127.0.0.1:8332" # RPC endpoint rpc_user = "bitcoind-testuser" rpc_password = "bitcoind-testpassword" - + [esplora] server_url = "https://mempool.space/api" - + [rabbitmq] connection_string = "rabbitmq_connection_string" exchange_name = "rabbitmq_exchange_name" diff --git a/ldk-server/src/util/proto_adapter.rs b/ldk-server/src/util/proto_adapter.rs index 7fb7255..6e42a88 100644 --- a/ldk-server/src/util/proto_adapter.rs +++ b/ldk-server/src/util/proto_adapter.rs @@ -9,7 +9,7 @@ use crate::api::error::LdkServerError; use crate::api::error::LdkServerErrorCode::{ - InternalServerError, InvalidRequestError, LightningError, + AuthError, InternalServerError, InvalidRequestError, LightningError, }; use bytes::Bytes; use hex::prelude::*; @@ -443,12 +443,14 @@ pub(crate) fn proto_to_bolt11_description( pub(crate) fn to_error_response(ldk_error: LdkServerError) -> (ErrorResponse, StatusCode) { let error_code = match ldk_error.error_code { InvalidRequestError => ErrorCode::InvalidRequestError, + AuthError => ErrorCode::AuthError, LightningError => ErrorCode::LightningError, InternalServerError => ErrorCode::InternalServerError, } as i32; let status = match ldk_error.error_code { InvalidRequestError => StatusCode::BAD_REQUEST, + AuthError => StatusCode::UNAUTHORIZED, LightningError => StatusCode::INTERNAL_SERVER_ERROR, InternalServerError => StatusCode::INTERNAL_SERVER_ERROR, }; From 84661c939ba2efe7b24e9a1bac5118c86122fff3 Mon Sep 17 00:00:00 2001 From: benthecarman Date: Fri, 9 Jan 2026 16:54:21 -0600 Subject: [PATCH 2/2] Add TLS support for RPC/CLI With TLS we can now have encrypted communication with the client and server. We auto generate a self signed cert on startup. Co-Authored-By: Claude Opus 4.5 --- Cargo.lock | 23 ++++ ldk-server-cli/src/main.rs | 20 ++- ldk-server-client/src/client.rs | 54 +++++--- ldk-server/Cargo.toml | 2 + ldk-server/src/main.rs | 34 ++++- ldk-server/src/util/config.rs | 28 +++++ ldk-server/src/util/mod.rs | 1 + ldk-server/src/util/tls.rs | 217 ++++++++++++++++++++++++++++++++ 8 files changed, 353 insertions(+), 26 deletions(-) create mode 100644 ldk-server/src/util/tls.rs diff --git a/Cargo.lock b/Cargo.lock index 69dbd04..917d680 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1731,9 +1731,11 @@ dependencies = [ "log", "prost", "rand 0.8.5", + "rcgen", "rusqlite", "serde", "tokio", + "tokio-rustls 0.26.4", "toml", ] @@ -2526,6 +2528,18 @@ dependencies = [ "cipher", ] +[[package]] +name = "rcgen" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "ring", + "rustls-pki-types", + "time", + "yasna", +] + [[package]] name = "reactor-trait" version = "1.1.0" @@ -4040,6 +4054,15 @@ dependencies = [ "time", ] +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "yoke" version = "0.8.1" diff --git a/ldk-server-cli/src/main.rs b/ldk-server-cli/src/main.rs index bab6408..69a2279 100644 --- a/ldk-server-cli/src/main.rs +++ b/ldk-server-cli/src/main.rs @@ -49,6 +49,13 @@ struct Cli { #[arg(short, long)] api_key: String, + #[arg( + short, + long, + help = "Path to the server's TLS certificate file (PEM format). Found at /tls_cert.pem" + )] + tls_cert: String, + #[command(subcommand)] command: Commands, } @@ -217,7 +224,18 @@ enum Commands { #[tokio::main] async fn main() { let cli = Cli::parse(); - let client = LdkServerClient::new(cli.base_url, cli.api_key); + + // Load server certificate for TLS verification + let server_cert_pem = std::fs::read(&cli.tls_cert).unwrap_or_else(|e| { + eprintln!("Failed to read server certificate file '{}': {}", cli.tls_cert, e); + std::process::exit(1); + }); + + let client = + LdkServerClient::new(cli.base_url, cli.api_key, &server_cert_pem).unwrap_or_else(|e| { + eprintln!("Failed to create client: {e}"); + std::process::exit(1); + }); match cli.command { Commands::GetNodeInfo => { diff --git a/ldk-server-client/src/client.rs b/ldk-server-client/src/client.rs index 3c76060..060f9bd 100644 --- a/ldk-server-client/src/client.rs +++ b/ldk-server-client/src/client.rs @@ -33,12 +33,16 @@ use ldk_server_protos::endpoints::{ }; use ldk_server_protos::error::{ErrorCode, ErrorResponse}; use reqwest::header::CONTENT_TYPE; -use reqwest::Client; +use reqwest::{Certificate, Client}; use std::time::{SystemTime, UNIX_EPOCH}; const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; /// Client to access a hosted instance of LDK Server. +/// +/// The client requires the server's TLS certificate to be provided for verification. +/// This certificate can be found at `/tls_cert.pem` after the +/// server generates it on first startup. #[derive(Clone)] pub struct LdkServerClient { base_url: String, @@ -48,9 +52,21 @@ pub struct LdkServerClient { impl LdkServerClient { /// Constructs a [`LdkServerClient`] using `base_url` as the ldk-server endpoint. + /// + /// `base_url` should not include the scheme, e.g., `localhost:3000`. /// `api_key` is used for HMAC-based authentication. - pub fn new(base_url: String, api_key: String) -> Self { - Self { base_url, client: Client::new(), api_key } + /// `server_cert_pem` is the server's TLS certificate in PEM format. This can be + /// found at `/tls_cert.pem` after the server starts. + pub fn new(base_url: String, api_key: String, server_cert_pem: &[u8]) -> Result { + let cert = Certificate::from_pem(server_cert_pem) + .map_err(|e| format!("Failed to parse server certificate: {e}"))?; + + let client = Client::builder() + .add_root_certificate(cert) + .build() + .map_err(|e| format!("Failed to build HTTP client: {e}"))?; + + Ok(Self { base_url, client, api_key }) } /// Computes the HMAC-SHA256 authentication header value. @@ -75,7 +91,7 @@ impl LdkServerClient { pub async fn get_node_info( &self, request: GetNodeInfoRequest, ) -> Result { - let url = format!("http://{}/{GET_NODE_INFO_PATH}", self.base_url); + let url = format!("https://{}/{GET_NODE_INFO_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -84,7 +100,7 @@ impl LdkServerClient { pub async fn get_balances( &self, request: GetBalancesRequest, ) -> Result { - let url = format!("http://{}/{GET_BALANCES_PATH}", self.base_url); + let url = format!("https://{}/{GET_BALANCES_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -93,7 +109,7 @@ impl LdkServerClient { pub async fn onchain_receive( &self, request: OnchainReceiveRequest, ) -> Result { - let url = format!("http://{}/{ONCHAIN_RECEIVE_PATH}", self.base_url); + let url = format!("https://{}/{ONCHAIN_RECEIVE_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -102,7 +118,7 @@ impl LdkServerClient { pub async fn onchain_send( &self, request: OnchainSendRequest, ) -> Result { - let url = format!("http://{}/{ONCHAIN_SEND_PATH}", self.base_url); + let url = format!("https://{}/{ONCHAIN_SEND_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -111,7 +127,7 @@ impl LdkServerClient { pub async fn bolt11_receive( &self, request: Bolt11ReceiveRequest, ) -> Result { - let url = format!("http://{}/{BOLT11_RECEIVE_PATH}", self.base_url); + let url = format!("https://{}/{BOLT11_RECEIVE_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -120,7 +136,7 @@ impl LdkServerClient { pub async fn bolt11_send( &self, request: Bolt11SendRequest, ) -> Result { - let url = format!("http://{}/{BOLT11_SEND_PATH}", self.base_url); + let url = format!("https://{}/{BOLT11_SEND_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -129,7 +145,7 @@ impl LdkServerClient { pub async fn bolt12_receive( &self, request: Bolt12ReceiveRequest, ) -> Result { - let url = format!("http://{}/{BOLT12_RECEIVE_PATH}", self.base_url); + let url = format!("https://{}/{BOLT12_RECEIVE_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -138,7 +154,7 @@ impl LdkServerClient { pub async fn bolt12_send( &self, request: Bolt12SendRequest, ) -> Result { - let url = format!("http://{}/{BOLT12_SEND_PATH}", self.base_url); + let url = format!("https://{}/{BOLT12_SEND_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -147,7 +163,7 @@ impl LdkServerClient { pub async fn open_channel( &self, request: OpenChannelRequest, ) -> Result { - let url = format!("http://{}/{OPEN_CHANNEL_PATH}", self.base_url); + let url = format!("https://{}/{OPEN_CHANNEL_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -156,7 +172,7 @@ impl LdkServerClient { pub async fn splice_in( &self, request: SpliceInRequest, ) -> Result { - let url = format!("http://{}/{SPLICE_IN_PATH}", self.base_url); + let url = format!("https://{}/{SPLICE_IN_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -165,7 +181,7 @@ impl LdkServerClient { pub async fn splice_out( &self, request: SpliceOutRequest, ) -> Result { - let url = format!("http://{}/{SPLICE_OUT_PATH}", self.base_url); + let url = format!("https://{}/{SPLICE_OUT_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -174,7 +190,7 @@ impl LdkServerClient { pub async fn close_channel( &self, request: CloseChannelRequest, ) -> Result { - let url = format!("http://{}/{CLOSE_CHANNEL_PATH}", self.base_url); + let url = format!("https://{}/{CLOSE_CHANNEL_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -183,7 +199,7 @@ impl LdkServerClient { pub async fn force_close_channel( &self, request: ForceCloseChannelRequest, ) -> Result { - let url = format!("http://{}/{FORCE_CLOSE_CHANNEL_PATH}", self.base_url); + let url = format!("https://{}/{FORCE_CLOSE_CHANNEL_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -192,7 +208,7 @@ impl LdkServerClient { pub async fn list_channels( &self, request: ListChannelsRequest, ) -> Result { - let url = format!("http://{}/{LIST_CHANNELS_PATH}", self.base_url); + let url = format!("https://{}/{LIST_CHANNELS_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -201,7 +217,7 @@ impl LdkServerClient { pub async fn list_payments( &self, request: ListPaymentsRequest, ) -> Result { - let url = format!("http://{}/{LIST_PAYMENTS_PATH}", self.base_url); + let url = format!("https://{}/{LIST_PAYMENTS_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -210,7 +226,7 @@ impl LdkServerClient { pub async fn update_channel_config( &self, request: UpdateChannelConfigRequest, ) -> Result { - let url = format!("http://{}/{UPDATE_CHANNEL_CONFIG_PATH}", self.base_url); + let url = format!("https://{}/{UPDATE_CHANNEL_CONFIG_PATH}", self.base_url); self.post_request(&request, &url).await } diff --git a/ldk-server/Cargo.toml b/ldk-server/Cargo.toml index 62f82d3..3ec19fe 100644 --- a/ldk-server/Cargo.toml +++ b/ldk-server/Cargo.toml @@ -10,6 +10,8 @@ hyper = { version = "1", default-features = false, features = ["server", "http1" http-body-util = { version = "0.1", default-features = false } hyper-util = { version = "0.1", default-features = false, features = ["server-graceful"] } tokio = { version = "1.38.0", default-features = false, features = ["time", "signal", "rt-multi-thread"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +rcgen = { version = "0.13", default-features = false, features = ["ring"] } prost = { version = "0.11.6", default-features = false, features = ["std"] } ldk-server-protos = { path = "../ldk-server-protos" } bytes = { version = "1.4.0", default-features = false } diff --git a/ldk-server/src/main.rs b/ldk-server/src/main.rs index f635994..2938814 100644 --- a/ldk-server/src/main.rs +++ b/ldk-server/src/main.rs @@ -36,6 +36,7 @@ use crate::io::persist::{ use crate::util::config::{load_config, ChainSource}; use crate::util::logger::ServerLogger; use crate::util::proto_adapter::{forwarded_payment_to_proto, payment_to_proto}; +use crate::util::tls::get_or_generate_tls_config; use hex::DisplayHex; use ldk_node::config::Config; use ldk_node::lightning::ln::channelmanager::PaymentId; @@ -155,14 +156,15 @@ fn main() { }, }; - let paginated_store: Arc = - Arc::new(match SqliteStore::new(PathBuf::from(config_file.storage_dir_path), None, None) { + let paginated_store: Arc = Arc::new( + match SqliteStore::new(PathBuf::from(&config_file.storage_dir_path), None, None) { Ok(store) => store, Err(e) => { error!("Failed to create SqliteStore: {e:?}"); std::process::exit(-1); }, - }); + }, + ); #[cfg(not(feature = "events-rabbitmq"))] let event_publisher: Arc = @@ -213,6 +215,20 @@ fn main() { let rest_svc_listener = TcpListener::bind(config_file.rest_service_addr) .await .expect("Failed to bind listening port"); + + let server_config = match get_or_generate_tls_config( + config_file.tls_config.as_ref(), + &config_file.storage_dir_path, + ) { + Ok(config) => config, + Err(e) => { + error!("Failed to set up TLS: {e}"); + std::process::exit(-1); + } + }; + let tls_acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config)); + info!("TLS enabled for REST service on {}", config_file.rest_service_addr); + loop { select! { event = event_node.next_event_async() => { @@ -355,11 +371,17 @@ fn main() { res = rest_svc_listener.accept() => { match res { Ok((stream, _)) => { - let io_stream = TokioIo::new(stream); let node_service = NodeService::new(Arc::clone(&node), Arc::clone(&paginated_store), config_file.auth_config.clone()); + let acceptor = tls_acceptor.clone(); runtime.spawn(async move { - if let Err(err) = http1::Builder::new().serve_connection(io_stream, node_service).await { - error!("Failed to serve connection: {}", err); + match acceptor.accept(stream).await { + Ok(tls_stream) => { + let io_stream = TokioIo::new(tls_stream); + if let Err(err) = http1::Builder::new().serve_connection(io_stream, node_service).await { + error!("Failed to serve TLS connection: {err}"); + } + }, + Err(e) => error!("TLS handshake failed: {e}"), } }); }, diff --git a/ldk-server/src/util/config.rs b/ldk-server/src/util/config.rs index bb77619..9e2c3c4 100644 --- a/ldk-server/src/util/config.rs +++ b/ldk-server/src/util/config.rs @@ -25,6 +25,7 @@ pub struct Config { pub alias: Option, pub network: Network, pub auth_config: AuthConfig, + pub tls_config: Option, pub rest_service_addr: SocketAddr, pub storage_dir_path: String, pub chain_source: ChainSource, @@ -40,6 +41,12 @@ pub struct AuthConfig { pub api_key: String, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TlsConfig { + pub cert_path: String, + pub key_path: String, +} + #[derive(Debug)] pub enum ChainSource { Rpc { rpc_address: SocketAddr, rpc_user: String, rpc_password: String }, @@ -160,6 +167,10 @@ impl TryFrom for Config { }) .map(|auth| AuthConfig { api_key: auth.api_key })?; + let tls_config = toml_config + .tls + .map(|tls| TlsConfig { cert_path: tls.cert_path, key_path: tls.key_path }); + Ok(Config { listening_addr, network: toml_config.node.network, @@ -173,6 +184,7 @@ impl TryFrom for Config { lsps2_service_config, log_level, log_file_path: toml_config.log.and_then(|l| l.file), + tls_config, }) } } @@ -189,6 +201,7 @@ pub struct TomlConfig { liquidity: Option, log: Option, auth: Option, + tls: Option, } #[derive(Deserialize, Serialize)] @@ -243,6 +256,12 @@ struct TomlAuthConfig { api_key: String, } +#[derive(Deserialize, Serialize)] +struct TomlTlsConfig { + cert_path: String, + key_path: String, +} + #[derive(Deserialize, Serialize)] struct LiquidityConfig { lsps2_service: Option, @@ -338,6 +357,10 @@ mod tests { [auth] api_key = "test_api_key" + [tls] + cert_path = "/path/to/cert.pem" + key_path = "/path/to/key.pem" + [esplora] server_url = "https://mempool.space/api" @@ -371,6 +394,10 @@ mod tests { rest_service_addr: SocketAddr::from_str("127.0.0.1:3002").unwrap(), storage_dir_path: "/tmp".to_string(), auth_config: AuthConfig { api_key: "test_api_key".to_string() }, + tls_config: Some(TlsConfig { + cert_path: "/path/to/cert.pem".to_string(), + key_path: "/path/to/key.pem".to_string(), + }), chain_source: ChainSource::Esplora { server_url: String::from("https://mempool.space/api"), }, @@ -397,6 +424,7 @@ mod tests { assert_eq!(config.rest_service_addr, expected.rest_service_addr); assert_eq!(config.storage_dir_path, expected.storage_dir_path); assert_eq!(config.auth_config, expected.auth_config); + assert_eq!(config.tls_config, expected.tls_config); let ChainSource::Esplora { server_url } = config.chain_source else { panic!("unexpected config chain source"); }; diff --git a/ldk-server/src/util/mod.rs b/ldk-server/src/util/mod.rs index 8bcf1c1..3662b12 100644 --- a/ldk-server/src/util/mod.rs +++ b/ldk-server/src/util/mod.rs @@ -10,3 +10,4 @@ pub(crate) mod config; pub(crate) mod logger; pub(crate) mod proto_adapter; +pub(crate) mod tls; diff --git a/ldk-server/src/util/tls.rs b/ldk-server/src/util/tls.rs new file mode 100644 index 0000000..e282856 --- /dev/null +++ b/ldk-server/src/util/tls.rs @@ -0,0 +1,217 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +use crate::util::config::TlsConfig; +use base64::Engine; +use rcgen::{generate_simple_self_signed, CertifiedKey}; +use std::fs; +use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use tokio_rustls::rustls::ServerConfig; + +// PEM markers +const PEM_CERT_BEGIN: &str = "-----BEGIN CERTIFICATE-----"; +const PEM_CERT_END: &str = "-----END CERTIFICATE-----"; +const PEM_KEY_BEGIN: &str = "-----BEGIN PRIVATE KEY-----"; +const PEM_KEY_END: &str = "-----END PRIVATE KEY-----"; + +/// Gets or generates TLS configuration. If custom paths are provided, uses those. +/// Otherwise, generates a self-signed certificate in the storage directory. +pub fn get_or_generate_tls_config( + tls_config: Option<&TlsConfig>, storage_dir: &str, +) -> Result { + if let Some(tls) = tls_config { + load_tls_config(&tls.cert_path, &tls.key_path) + } else { + // Check if we already have generated certs, if we don't, generate new ones + let cert_path = format!("{storage_dir}/tls_cert.pem"); + let key_path = format!("{storage_dir}/tls_key.pem"); + if !fs::exists(&cert_path).unwrap_or(false) || !fs::exists(&key_path).unwrap_or(false) { + generate_self_signed_cert(&cert_path, &key_path)?; + } + + load_tls_config(&cert_path, &key_path) + } +} + +/// Parses a PEM-encoded certificate file and returns the DER-encoded certificates. +fn parse_pem_certs(pem_data: &str) -> Result>, String> { + let mut certs = Vec::new(); + + for block in pem_data.split(PEM_CERT_END) { + if let Some(start) = block.find(PEM_CERT_BEGIN) { + let base64_content: String = block[start + PEM_CERT_BEGIN.len()..] + .lines() + .filter(|line| !line.starts_with("-----") && !line.is_empty()) + .collect(); + + let der = base64::engine::general_purpose::STANDARD + .decode(&base64_content) + .map_err(|e| format!("Failed to decode certificate base64: {e}"))?; + + certs.push(CertificateDer::from(der)); + } + } + + Ok(certs) +} + +/// Parses a PEM-encoded PKCS#8 private key file and returns the DER-encoded key. +fn parse_pem_private_key(pem_data: &str) -> Result, String> { + let start = pem_data.find(PEM_KEY_BEGIN).ok_or("Missing BEGIN PRIVATE KEY marker")?; + let end = pem_data.find(PEM_KEY_END).ok_or("Missing END PRIVATE KEY marker")?; + + let base64_content: String = pem_data[start + PEM_KEY_BEGIN.len()..end] + .lines() + .filter(|line| !line.starts_with("-----") && !line.is_empty()) + .collect(); + + let der = base64::engine::general_purpose::STANDARD + .decode(&base64_content) + .map_err(|e| format!("Failed to decode private key base64: {e}"))?; + + Ok(PrivateKeyDer::Pkcs8(der.into())) +} + +/// Generates a self-signed TLS certificate and saves it to the storage directory. +/// Returns the paths to the generated cert and key files. +fn generate_self_signed_cert(cert_path: &str, key_path: &str) -> Result<(), String> { + let subject_alt_names = vec!["localhost".to_string(), "127.0.0.1".to_string()]; + + let CertifiedKey { cert, key_pair } = generate_simple_self_signed(subject_alt_names) + .map_err(|e| format!("Failed to generate self-signed certificate: {e}"))?; + + // Convert DER to PEM format + let cert_der = cert.der(); + let key_der = key_pair.serialize_der(); + + let cert_pem = format!( + "{PEM_CERT_BEGIN}\n{}\n{PEM_CERT_END}\n", + base64::engine::general_purpose::STANDARD + .encode(cert_der) + .as_bytes() + .chunks(64) + .map(|chunk| std::str::from_utf8(chunk).unwrap()) + .collect::>() + .join("\n") + ); + + let key_pem = format!( + "{PEM_KEY_BEGIN}\n{}\n{PEM_KEY_END}\n", + base64::engine::general_purpose::STANDARD + .encode(&key_der) + .as_bytes() + .chunks(64) + .map(|chunk| std::str::from_utf8(chunk).unwrap()) + .collect::>() + .join("\n") + ); + + fs::write(cert_path, &cert_pem) + .map_err(|e| format!("Failed to write TLS certificate to '{cert_path}': {e}"))?; + fs::write(key_path, &key_pem) + .map_err(|e| format!("Failed to write TLS key to '{key_path}': {e}"))?; + + Ok(()) +} + +/// Loads TLS configuration from provided paths. +fn load_tls_config(cert_path: &str, key_path: &str) -> Result { + let cert_pem = fs::read_to_string(cert_path) + .map_err(|e| format!("Failed to read TLS certificate file '{cert_path}': {e}"))?; + let key_pem = fs::read_to_string(key_path) + .map_err(|e| format!("Failed to read TLS key file '{key_path}': {e}"))?; + + let certs = parse_pem_certs(&cert_pem)?; + + if certs.is_empty() { + return Err("No certificates found in certificate file".to_string()); + } + + let key = parse_pem_private_key(&key_pem)?; + + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| format!("Failed to build TLS server config: {e}")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_pem_certs() { + let pem = "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAKHBfpegPjMCMA0GCSqGSIb3DQEBCwUAMBExDzANBgNVBAMMBnVu\ndXNlZDAeFw0yMzAxMDEwMDAwMDBaFw0yNDAxMDEwMDAwMDBaMBExDzANBgNVBAMM\nBnVudXNlZDBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQC7o96FCEcJsggt0c0dSfEB\nmm6vv1LdCoxXnhOSCutoJgJgmCPBjU1doFFKwAtXjfOv0eSLZ3NHLu0LRKmVvOsP\nAgMBAAGjUzBRMB0GA1UdDgQWBBQK3fc0myO0psd71FJd8v7VCmDJOzAfBgNVHSME\nGDAWgBQK3fc0myO0psd71FJd8v7VCmDJOzAPBgNVHRMBAf8EBTADAQH/MA0GCSqG\nSIb3DQEBCwUAA0EAhJg0cx2pFfVfGBfbJQNFa+A4ynJBMqKYlbUnJBfWPwg13RhC\nivLjYyhKzEbnOug0TuFfVaUBGfBYbPgaJQ4BAg==\n-----END CERTIFICATE-----\n"; + + let certs = parse_pem_certs(pem).unwrap(); + assert_eq!(certs.len(), 1); + assert!(!certs[0].is_empty()); + } + + #[test] + fn test_parse_pem_certs_multiple() { + let pem = "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAKHBfpegPjMCMA0GCSqGSIb3DQEBCwUAMBExDzANBgNVBAMMBnVu\ndXNlZDAeFw0yMzAxMDEwMDAwMDBaFw0yNDAxMDEwMDAwMDBaMBExDzANBgNVBAMM\nBnVudXNlZDBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQC7o96FCEcJsggt0c0dSfEB\nmm6vv1LdCoxXnhOSCutoJgJgmCPBjU1doFFKwAtXjfOv0eSLZ3NHLu0LRKmVvOsP\nAgMBAAGjUzBRMB0GA1UdDgQWBBQK3fc0myO0psd71FJd8v7VCmDJOzAfBgNVHSME\nGDAWgBQK3fc0myO0psd71FJd8v7VCmDJOzAPBgNVHRMBAf8EBTADAQH/MA0GCSqG\nSIb3DQEBCwUAA0EAhJg0cx2pFfVfGBfbJQNFa+A4ynJBMqKYlbUnJBfWPwg13RhC\nivLjYyhKzEbnOug0TuFfVaUBGfBYbPgaJQ4BAg==\n-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAKHBfpegPjMCMA0GCSqGSIb3DQEBCwUAMBExDzANBgNVBAMMBnVu\ndXNlZDAeFw0yMzAxMDEwMDAwMDBaFw0yNDAxMDEwMDAwMDBaMBExDzANBgNVBAMM\nBnVudXNlZDBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQC7o96FCEcJsggt0c0dSfEB\nmm6vv1LdCoxXnhOSCutoJgJgmCPBjU1doFFKwAtXjfOv0eSLZ3NHLu0LRKmVvOsP\nAgMBAAGjUzBRMB0GA1UdDgQWBBQK3fc0myO0psd71FJd8v7VCmDJOzAfBgNVHSME\nGDAWgBQK3fc0myO0psd71FJd8v7VCmDJOzAPBgNVHRMBAf8EBTADAQH/MA0GCSqG\nSIb3DQEBCwUAA0EAhJg0cx2pFfVfGBfbJQNFa+A4ynJBMqKYlbUnJBfWPwg13RhC\nivLjYyhKzEbnOug0TuFfVaUBGfBYbPgaJQ4BAg==\n-----END CERTIFICATE-----\n"; + + let certs = parse_pem_certs(pem).unwrap(); + assert_eq!(certs.len(), 2); + } + + #[test] + fn test_parse_pem_certs_empty() { + let certs = parse_pem_certs("").unwrap(); + assert!(certs.is_empty()); + + let certs = parse_pem_certs("not a cert").unwrap(); + assert!(certs.is_empty()); + } + + #[test] + fn test_parse_pem_private_key_pkcs8() { + let pem = "-----BEGIN PRIVATE KEY-----\nMIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQg2a2rwplBQLzHPDvn\nsaw8HKDP6WYBSF684gcz+D7zeVShRANCAAQq8R/E45tTNWMEpK8abYM7VzuJxpPS\nhJCi6bzjOPGHawEO8safLOWFaV7GqLJM0OdM3eu/qcz8HwgI3T8EVHQK\n-----END PRIVATE KEY-----\n"; + + let key = parse_pem_private_key(pem).unwrap(); + assert!(matches!(key, PrivateKeyDer::Pkcs8(_))); + } + + #[test] + fn test_parse_pem_private_key_invalid() { + let result = parse_pem_private_key(""); + assert!(result.is_err()); + + let result = parse_pem_private_key("not a key"); + assert!(result.is_err()); + } + + #[test] + fn test_generate_and_load_roundtrip() { + let temp_dir = std::env::temp_dir(); + let suffix: u64 = rand::random(); + let cert_path = temp_dir.join(format!("test_tls_cert_{suffix}.pem")); + let key_path = temp_dir.join(format!("test_tls_key_{suffix}.pem")); + + // Clean up any existing files to be safe + let _ = fs::remove_file(&cert_path); + let _ = fs::remove_file(&key_path); + + // Generate cert + generate_self_signed_cert(cert_path.to_str().unwrap(), key_path.to_str().unwrap()).unwrap(); + + // Verify files exist + assert!(cert_path.exists()); + assert!(key_path.exists()); + + // Load config + let res = load_tls_config(cert_path.to_str().unwrap(), key_path.to_str().unwrap()); + assert!(res.is_ok()); + + // Clean up + let _ = fs::remove_file(&cert_path); + let _ = fs::remove_file(&key_path); + } +}