diff --git a/Cargo-minimal.lock b/Cargo-minimal.lock index 22fabf0a3..d0c9929bd 100644 --- a/Cargo-minimal.lock +++ b/Cargo-minimal.lock @@ -146,6 +146,15 @@ version = "1.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" +[[package]] +name = "arc-swap" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d03449bb8ca2cc2ef70869af31463d1ae5ccc8fa3e334b307203fbf815207e" +dependencies = [ + "rustversion", +] + [[package]] name = "arraydeque" version = "0.5.1" @@ -301,6 +310,80 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-server" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1ab4a3ec9ea8a657c72d99a03a824af695bd0fb5ec639ccbd9cd3543b41a5f9" +dependencies = [ + "arc-swap", + "bytes", + "fs-err 3.2.2", + "http", + "http-body", + "hyper", + "hyper-util", + "pin-project-lite", + "rustls 0.23.31", + "rustls-pemfile", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + [[package]] name = "base58ck" version = "0.1.0" @@ -1144,7 +1227,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -1304,6 +1387,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "fs-err" +version = "3.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf68cef89750956493a66a10f512b9e58d9db21f2a573c079c0bdf1207a54a7" +dependencies = [ + "autocfg", + "tokio", +] + [[package]] name = "fs2" version = "0.4.3" @@ -1712,21 +1805,6 @@ dependencies = [ "webpki-roots 1.0.2", ] -[[package]] -name = "hyper-tungstenite" -version = "0.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ec0b60f8f6371eb04e4b19361b39fa9a1a88bf344d50c31347824599ca150e" -dependencies = [ - "http-body-util", - "hyper", - "hyper-util", - "pin-project-lite", - "tokio", - "tokio-tungstenite", - "tungstenite", -] - [[package]] name = "hyper-util" version = "0.1.16" @@ -2076,6 +2154,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "memchr" version = "2.7.4" @@ -2194,7 +2278,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -2243,7 +2327,6 @@ dependencies = [ "http-body-util", "hyper", "hyper-rustls", - "hyper-tungstenite", "hyper-util", "mockito", "rcgen 0.12.1", @@ -2254,6 +2337,7 @@ dependencies = [ "tokio-rustls", "tokio-tungstenite", "tokio-util", + "tower", "tracing", "tracing-subscriber", "uuid", @@ -2434,6 +2518,7 @@ dependencies = [ "http-body-util", "hyper", "hyper-util", + "ohttp-relay", "payjoin", "prometheus", "rand 0.8.5", @@ -2443,6 +2528,7 @@ dependencies = [ "tokio-rustls", "tokio-rustls-acme", "tokio-stream", + "tower", "tracing", "tracing-subscriber", ] @@ -2466,18 +2552,38 @@ dependencies = [ "url", ] +[[package]] +name = "payjoin-service" +version = "0.0.1" +dependencies = [ + "anyhow", + "axum", + "axum-server", + "clap", + "config", + "ohttp-relay", + "payjoin-directory", + "rand 0.8.5", + "rustls 0.23.31", + "serde", + "tokio", + "tower", + "tracing", + "tracing-subscriber", +] + [[package]] name = "payjoin-test-utils" version = "0.0.1" dependencies = [ + "axum-server", "bitcoin 0.32.8", "bitcoin-ohttp", "corepc-node", "http", - "ohttp-relay", "once_cell", "payjoin", - "payjoin-directory", + "payjoin-service", "rcgen 0.14.3", "reqwest", "rustls 0.23.31", @@ -3095,6 +3201,15 @@ dependencies = [ "security-framework", ] +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" version = "1.12.0" @@ -3320,6 +3435,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_spanned" version = "1.0.0" @@ -3869,6 +3995,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -3912,6 +4039,7 @@ version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -4082,7 +4210,7 @@ dependencies = [ "askama", "camino", "cargo_metadata", - "fs-err", + "fs-err 2.11.0", "glob", "goblin", "heck", @@ -4155,7 +4283,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64c6309fc36c7992afc03bc0c5b059c656bccbef3f2a4bc362980017f8936141" dependencies = [ "camino", - "fs-err", + "fs-err 2.11.0", "once_cell", "proc-macro2", "quote", diff --git a/Cargo-recent.lock b/Cargo-recent.lock index 22fabf0a3..d0c9929bd 100644 --- a/Cargo-recent.lock +++ b/Cargo-recent.lock @@ -146,6 +146,15 @@ version = "1.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" +[[package]] +name = "arc-swap" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d03449bb8ca2cc2ef70869af31463d1ae5ccc8fa3e334b307203fbf815207e" +dependencies = [ + "rustversion", +] + [[package]] name = "arraydeque" version = "0.5.1" @@ -301,6 +310,80 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-server" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1ab4a3ec9ea8a657c72d99a03a824af695bd0fb5ec639ccbd9cd3543b41a5f9" +dependencies = [ + "arc-swap", + "bytes", + "fs-err 3.2.2", + "http", + "http-body", + "hyper", + "hyper-util", + "pin-project-lite", + "rustls 0.23.31", + "rustls-pemfile", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + [[package]] name = "base58ck" version = "0.1.0" @@ -1144,7 +1227,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -1304,6 +1387,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "fs-err" +version = "3.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf68cef89750956493a66a10f512b9e58d9db21f2a573c079c0bdf1207a54a7" +dependencies = [ + "autocfg", + "tokio", +] + [[package]] name = "fs2" version = "0.4.3" @@ -1712,21 +1805,6 @@ dependencies = [ "webpki-roots 1.0.2", ] -[[package]] -name = "hyper-tungstenite" -version = "0.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ec0b60f8f6371eb04e4b19361b39fa9a1a88bf344d50c31347824599ca150e" -dependencies = [ - "http-body-util", - "hyper", - "hyper-util", - "pin-project-lite", - "tokio", - "tokio-tungstenite", - "tungstenite", -] - [[package]] name = "hyper-util" version = "0.1.16" @@ -2076,6 +2154,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "memchr" version = "2.7.4" @@ -2194,7 +2278,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -2243,7 +2327,6 @@ dependencies = [ "http-body-util", "hyper", "hyper-rustls", - "hyper-tungstenite", "hyper-util", "mockito", "rcgen 0.12.1", @@ -2254,6 +2337,7 @@ dependencies = [ "tokio-rustls", "tokio-tungstenite", "tokio-util", + "tower", "tracing", "tracing-subscriber", "uuid", @@ -2434,6 +2518,7 @@ dependencies = [ "http-body-util", "hyper", "hyper-util", + "ohttp-relay", "payjoin", "prometheus", "rand 0.8.5", @@ -2443,6 +2528,7 @@ dependencies = [ "tokio-rustls", "tokio-rustls-acme", "tokio-stream", + "tower", "tracing", "tracing-subscriber", ] @@ -2466,18 +2552,38 @@ dependencies = [ "url", ] +[[package]] +name = "payjoin-service" +version = "0.0.1" +dependencies = [ + "anyhow", + "axum", + "axum-server", + "clap", + "config", + "ohttp-relay", + "payjoin-directory", + "rand 0.8.5", + "rustls 0.23.31", + "serde", + "tokio", + "tower", + "tracing", + "tracing-subscriber", +] + [[package]] name = "payjoin-test-utils" version = "0.0.1" dependencies = [ + "axum-server", "bitcoin 0.32.8", "bitcoin-ohttp", "corepc-node", "http", - "ohttp-relay", "once_cell", "payjoin", - "payjoin-directory", + "payjoin-service", "rcgen 0.14.3", "reqwest", "rustls 0.23.31", @@ -3095,6 +3201,15 @@ dependencies = [ "security-framework", ] +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" version = "1.12.0" @@ -3320,6 +3435,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_spanned" version = "1.0.0" @@ -3869,6 +3995,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -3912,6 +4039,7 @@ version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -4082,7 +4210,7 @@ dependencies = [ "askama", "camino", "cargo_metadata", - "fs-err", + "fs-err 2.11.0", "glob", "goblin", "heck", @@ -4155,7 +4283,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64c6309fc36c7992afc03bc0c5b059c656bccbef3f2a4bc362980017f8936141" dependencies = [ "camino", - "fs-err", + "fs-err 2.11.0", "once_cell", "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 2fecb8718..045576380 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "payjoin-directory", "payjoin-test-utils", "payjoin-ffi", + "payjoin-service", ] resolver = "2" @@ -13,6 +14,7 @@ resolver = "2" ohttp-relay = { path = "ohttp-relay" } payjoin = { path = "payjoin" } payjoin-directory = { path = "payjoin-directory" } +payjoin-service = { path = "payjoin-service" } payjoin-test-utils = { path = "payjoin-test-utils" } [profile.crane] diff --git a/ohttp-relay/Cargo.toml b/ohttp-relay/Cargo.toml index 06e50b0b5..98f982745 100644 --- a/ohttp-relay/Cargo.toml +++ b/ohttp-relay/Cargo.toml @@ -16,13 +16,14 @@ exclude = ["tests"] default = ["bootstrap"] bootstrap = ["connect-bootstrap", "ws-bootstrap"] connect-bootstrap = [] -ws-bootstrap = ["futures", "hyper-tungstenite", "rustls", "tokio-tungstenite"] +ws-bootstrap = ["futures", "rustls", "tokio-tungstenite"] _test-util = [] [dependencies] byteorder = "1.5.0" bytes = "1.10.1" futures = { version = "0.3.31", optional = true } +hex = { package = "hex-conservative", version = "0.1.1" } http = "1.3.1" http-body-util = "0.1.3" hyper = { version = "1.6.0", features = ["http1", "server"] } @@ -31,8 +32,7 @@ hyper-rustls = { version = "0.27.7", default-features = false, features = [ "http1", "ring", ] } -hyper-tungstenite = { version = "0.18.0", optional = true } -hyper-util = { version = "0.1.16", features = ["client-legacy"] } +hyper-util = { version = "0.1.16", features = ["client-legacy", "service"] } rustls = { version = "0.23.31", optional = true, default-features = false, features = [ "ring", ] } @@ -44,11 +44,11 @@ tokio = { version = "1.47.1", features = [ ] } tokio-tungstenite = { version = "0.27.0", optional = true } tokio-util = { version = "0.7.16", features = ["net", "codec"] } +tower = "0.5" tracing = "0.1.41" tracing-subscriber = { version = "0.3.20", features = ["env-filter"] } [dev-dependencies] -hex = { package = "hex-conservative", version = "0.1.1" } mockito = "1.7.0" rcgen = "0.12" reqwest = { version = "0.12.23", default-features = false, features = [ diff --git a/ohttp-relay/src/bootstrap/connect.rs b/ohttp-relay/src/bootstrap/connect.rs index 3e57e45cf..48fe2cb6c 100644 --- a/ohttp-relay/src/bootstrap/connect.rs +++ b/ohttp-relay/src/bootstrap/connect.rs @@ -1,7 +1,8 @@ +use std::fmt::Debug; use std::net::SocketAddr; use http_body_util::combinators::BoxBody; -use hyper::body::{Bytes, Incoming}; +use hyper::body::Bytes; use hyper::upgrade::Upgraded; use hyper::{Method, Request, Response}; use hyper_util::rt::TokioIo; @@ -11,15 +12,16 @@ use tracing::{error, instrument}; use crate::error::Error; use crate::{empty, GatewayUri}; -pub(crate) fn is_connect_request(req: &Request) -> bool { - Method::CONNECT == req.method() -} +pub(crate) fn is_connect_request(req: &Request) -> bool { Method::CONNECT == req.method() } #[instrument] -pub(crate) async fn try_upgrade( - req: Request, +pub(crate) async fn try_upgrade( + req: Request, gateway_origin: GatewayUri, -) -> Result>, Error> { +) -> Result>, Error> +where + B: Send + Debug + 'static, +{ let addr = gateway_origin .to_socket_addr() .await diff --git a/ohttp-relay/src/bootstrap/mod.rs b/ohttp-relay/src/bootstrap/mod.rs index f3ae187c2..22205b689 100644 --- a/ohttp-relay/src/bootstrap/mod.rs +++ b/ohttp-relay/src/bootstrap/mod.rs @@ -1,5 +1,7 @@ +use std::fmt::Debug; + use http_body_util::combinators::BoxBody; -use hyper::body::{Bytes, Incoming}; +use hyper::body::Bytes; use hyper::{Request, Response}; use tracing::instrument; @@ -13,10 +15,13 @@ pub mod connect; pub mod ws; #[instrument] -pub(crate) async fn handle_ohttp_keys( - mut req: Request, +pub(crate) async fn handle_ohttp_keys( + req: Request, gateway_origin: GatewayUri, -) -> Result>, Error> { +) -> Result>, Error> +where + B: Send + Debug + 'static, +{ #[cfg(feature = "connect-bootstrap")] if connect::is_connect_request(&req) { return connect::try_upgrade(req, gateway_origin).await; @@ -24,7 +29,7 @@ pub(crate) async fn handle_ohttp_keys( #[cfg(feature = "ws-bootstrap")] if ws::is_websocket_request(&req) { - return ws::try_upgrade(&mut req, gateway_origin).await; + return ws::try_upgrade(req, gateway_origin).await; } Err(Error::BadRequest("Not a supported proxy upgrade request".to_string())) diff --git a/ohttp-relay/src/bootstrap/ws.rs b/ohttp-relay/src/bootstrap/ws.rs index feffa0fd1..3a23e0c27 100644 --- a/ohttp-relay/src/bootstrap/ws.rs +++ b/ohttp-relay/src/bootstrap/ws.rs @@ -1,3 +1,4 @@ +use std::fmt::Debug; use std::io; use std::net::SocketAddr; use std::pin::Pin; @@ -5,54 +6,111 @@ use std::task::{Context, Poll}; use futures::{Sink, SinkExt, StreamExt}; use http_body_util::combinators::BoxBody; -use http_body_util::BodyExt; -use hyper::body::{Bytes, Incoming}; -use hyper::{Request, Response}; -use hyper_tungstenite::HyperWebsocket; +use hyper::body::Bytes; +use hyper::header::{CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, UPGRADE}; +use hyper::{Request, Response, StatusCode}; +use hyper_util::rt::TokioIo; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_tungstenite::tungstenite::handshake::derive_accept_key; use tokio_tungstenite::tungstenite::protocol::Message; use tokio_tungstenite::{tungstenite, WebSocketStream}; use tracing::{error, instrument}; +use crate::empty; use crate::error::Error; use crate::gateway_uri::GatewayUri; -pub(crate) fn is_websocket_request(req: &Request) -> bool { - hyper_tungstenite::is_upgrade_request(req) +/// Check if the request is a WebSocket upgrade request. +/// +/// This is done manually to support generic body types. +/// When bootstrapping moves to axum, this can be replaced with +/// `axum::extract::ws::WebSocketUpgrade`. +pub(crate) fn is_websocket_request(req: &Request) -> bool { + let dominated_by_upgrade = req + .headers() + .get(CONNECTION) + .and_then(|v| v.to_str().ok()) + .map(|v| v.to_ascii_lowercase().contains("upgrade")) + .unwrap_or(false); + + let upgrade_to_websocket = req + .headers() + .get(UPGRADE) + .and_then(|v| v.to_str().ok()) + .map(|v| v.eq_ignore_ascii_case("websocket")) + .unwrap_or(false); + + dominated_by_upgrade && upgrade_to_websocket && req.headers().contains_key(SEC_WEBSOCKET_KEY) } +/// Upgrade the request to a WebSocket connection and proxy to the gateway. +/// +/// This performs the WebSocket handshake to support generic body types. +/// When bootstrapping moves to axum, this can be replaced with +/// `axum::extract::ws::WebSocketUpgrade`. #[instrument] -pub(crate) async fn try_upgrade( - req: &mut Request, +pub(crate) async fn try_upgrade( + req: Request, gateway_origin: GatewayUri, -) -> Result>, Error> { +) -> Result>, Error> +where + B: Send + Debug + 'static, +{ let gateway_addr = gateway_origin .to_socket_addr() .await .map_err(|e| Error::InternalServerError(Box::new(e)))? .ok_or_else(|| Error::NotFound)?; - let (res, websocket) = hyper_tungstenite::upgrade(req, None) - .map_err(|e| Error::BadRequest(format!("Error upgrading to websocket: {}", e)))?; + let key = req + .headers() + .get(SEC_WEBSOCKET_KEY) + .ok_or_else(|| Error::BadRequest("Missing Sec-WebSocket-Key header".to_string()))? + .to_str() + .map_err(|_| Error::BadRequest("Invalid Sec-WebSocket-Key header".to_string()))? + .to_string(); + + let accept_key = derive_accept_key(key.as_bytes()); tokio::spawn(async move { - if let Err(e) = serve_websocket(websocket, gateway_addr).await { - error!("Error in websocket connection: {e}"); + match hyper::upgrade::on(req).await { + Ok(upgraded) => { + let ws_stream = WebSocketStream::from_raw_socket( + TokioIo::new(upgraded), + tungstenite::protocol::Role::Server, + None, + ) + .await; + if let Err(e) = serve_websocket(ws_stream, gateway_addr).await { + error!("Error in websocket connection: {e}"); + } + } + Err(e) => error!("WebSocket upgrade error: {}", e), } }); - let (parts, body) = res.into_parts(); - let boxbody = body.map_err(|never| match never {}).boxed(); - Ok(Response::from_parts(parts, boxbody)) + + let res = Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(UPGRADE, "websocket") + .header(CONNECTION, "Upgrade") + .header(SEC_WEBSOCKET_ACCEPT, accept_key) + .body(empty()) + .map_err(|e| Error::InternalServerError(Box::new(e)))?; + + Ok(res) } /// Stream WebSocket frames from the client to the gateway server's TCP socket and vice versa. -#[instrument] -async fn serve_websocket( - websocket: HyperWebsocket, +#[instrument(skip(ws_stream))] +async fn serve_websocket( + ws_stream: WebSocketStream, gateway_addr: SocketAddr, -) -> Result<(), Box> { +) -> Result<(), Box> +where + S: AsyncRead + AsyncWrite + Unpin, +{ let mut tcp_stream = tokio::net::TcpStream::connect(gateway_addr).await?; - let mut ws_io = WsIo::new(websocket.await?); + let mut ws_io = WsIo::new(ws_stream); let (_, _) = tokio::io::copy_bidirectional(&mut ws_io, &mut tcp_stream).await?; Ok(()) } diff --git a/ohttp-relay/src/lib.rs b/ohttp-relay/src/lib.rs index 1af4828f7..431ee7acb 100644 --- a/ohttp-relay/src/lib.rs +++ b/ohttp-relay/src/lib.rs @@ -1,6 +1,9 @@ +use std::fmt::Debug; use std::net::SocketAddr; +use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; +use std::task::{Context, Poll}; pub(crate) use gateway_prober::Prober; pub use gateway_uri::GatewayUri; @@ -13,13 +16,13 @@ use hyper::header::{ ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_LENGTH, CONTENT_TYPE, }; use hyper::server::conn::http1; -use hyper::service::service_fn; use hyper::{Method, Request, Response}; use hyper_rustls::builderstates::WantsSchemes; use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use hyper_util::client::legacy::connect::HttpConnector; use hyper_util::client::legacy::Client; use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::service::TowerToHyperService; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, UnixListener}; use tokio_util::net::Listener; @@ -31,6 +34,9 @@ mod gateway_prober; #[cfg(feature = "_test-util")] pub mod gateway_prober; mod gateway_uri; +pub mod sentinel; +pub use sentinel::{InvalidHeader, SentinelTag}; + use crate::error::{BoxError, Error}; #[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))] @@ -39,6 +45,7 @@ pub mod bootstrap; pub const DEFAULT_PORT: u16 = 3000; pub const OHTTP_RELAY_HOST: HeaderValue = HeaderValue::from_static("0.0.0.0"); pub const EXPECTED_MEDIA_TYPE: HeaderValue = HeaderValue::from_static("message/ohttp-req"); +pub const DEFAULT_GATEWAY: &str = "https://payjo.in"; #[instrument] pub async fn listen_tcp( @@ -48,7 +55,7 @@ pub async fn listen_tcp( let addr = SocketAddr::from(([0, 0, 0, 0], port)); let listener = TcpListener::bind(addr).await?; println!("OHTTP relay listening on tcp://{}", addr); - ohttp_relay(listener, RelayConfig::new_with_default_client(gateway_origin)).await + ohttp_relay(listener, RelayConfig::new_with_default_client(gateway_origin, None)).await } #[instrument] @@ -58,7 +65,7 @@ pub async fn listen_socket( ) -> Result>, BoxError> { let listener = UnixListener::bind(socket_path)?; info!("OHTTP relay listening on socket: {}", socket_path); - ohttp_relay(listener, RelayConfig::new_with_default_client(gateway_origin)).await + ohttp_relay(listener, RelayConfig::new_with_default_client(gateway_origin, None)).await } #[cfg(feature = "_test-util")] @@ -69,7 +76,7 @@ pub async fn listen_tcp_on_free_port( let listener = tokio::net::TcpListener::bind("[::]:0").await?; let port = listener.local_addr()?.port(); println!("OHTTP relay binding to port {}", listener.local_addr()?); - let config = RelayConfig::new(default_gateway, root_store); + let config = RelayConfig::new(default_gateway, root_store, None); let handle = ohttp_relay(listener, config).await?; Ok((port, handle)) } @@ -79,17 +86,76 @@ struct RelayConfig { default_gateway: GatewayUri, client: HttpClient, prober: Prober, + sentinel_tag: Option, } impl RelayConfig { - fn new_with_default_client(default_gateway: GatewayUri) -> Self { - Self::new(default_gateway, HttpClient::default()) + fn new_with_default_client( + default_gateway: GatewayUri, + sentinel_tag: Option, + ) -> Self { + Self::new(default_gateway, HttpClient::default(), sentinel_tag) } - fn new(default_gateway: GatewayUri, into_client: impl Into) -> Self { + fn new( + default_gateway: GatewayUri, + into_client: impl Into, + sentinel_tag: Option, + ) -> Self { let client = into_client.into(); let prober = Prober::new_with_client(client.clone()); - RelayConfig { default_gateway, client, prober } + RelayConfig { default_gateway, client, prober, sentinel_tag } + } +} + +#[derive(Clone)] +pub struct Service { + config: Arc, +} + +impl Service { + fn from_config(config: Arc) -> Self { Self { config } } + + pub async fn new(sentinel_tag: SentinelTag) -> Self { + // The default gateway is hardcoded because it is obsolete and required only for backwards + // compatibility. + // The new mechanism for specifying a custom gateway is via RFC 9540 using + // `/.well-known/ohttp-gateway` request paths. + let gateway_origin = GatewayUri::from_str(DEFAULT_GATEWAY).expect("valid gateway uri"); + let config = RelayConfig::new_with_default_client(gateway_origin, Some(sentinel_tag)); + config.prober.assert_opt_in(&config.default_gateway).await; + Self { config: Arc::new(config) } + } + + #[cfg(feature = "_test-util")] + pub async fn new_with_roots( + root_store: rustls::RootCertStore, + sentinel_tag: SentinelTag, + ) -> Self { + let gateway_origin = GatewayUri::from_str(DEFAULT_GATEWAY).expect("valid gateway uri"); + let config = RelayConfig::new(gateway_origin, root_store, Some(sentinel_tag)); + config.prober.assert_opt_in(&config.default_gateway).await; + Self { config: Arc::new(config) } + } +} + +impl tower::Service> for Service +where + B: hyper::body::Body + Send + Debug + 'static, + B::Error: Into, +{ + type Response = Response>; + type Error = hyper::Error; + type Future = + Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let config = self.config.clone(); + Box::pin(async move { serve_ohttp_relay(req, &config).await }) } } @@ -144,13 +210,12 @@ where let handle = tokio::spawn(async move { while let Ok((stream, _)) = listener.accept().await { - let config = config.clone(); + let service = Service::from_config(config.clone()); let io = TokioIo::new(stream); tokio::spawn(async move { - if let Err(err) = http1::Builder::new() - .serve_connection(io, service_fn(|req| serve_ohttp_relay(req, &config))) - .with_upgrades() - .await + let hyper_service = TowerToHyperService::new(service); + if let Err(err) = + http1::Builder::new().serve_connection(io, hyper_service).with_upgrades().await { error!("Error serving connection: {:?}", err); } @@ -163,22 +228,32 @@ where } #[instrument] -async fn serve_ohttp_relay( - req: Request, +async fn serve_ohttp_relay( + req: Request, config: &RelayConfig, -) -> Result>, hyper::Error> { - let mut res = match (req.method(), req.uri().path()) { +) -> Result>, hyper::Error> +where + B: hyper::body::Body + Send + Debug + 'static, + B::Error: Into, +{ + let method = req.method().clone(); + let path = req.uri().path(); + let authority = req.uri().authority().cloned(); + + let mut res = match (&method, path) { (&Method::OPTIONS, _) => Ok(handle_preflight()), (&Method::GET, "/health") => Ok(health_check().await), - (&Method::POST, _) => match parse_gateway_uri(&req, config).await { + (&Method::POST, _) => match parse_gateway_uri(&method, path, authority, config).await { Ok(gateway_uri) => handle_ohttp_relay(req, config, gateway_uri).await, Err(e) => Err(e), }, #[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))] - (&Method::GET, _) | (&Method::CONNECT, _) => match parse_gateway_uri(&req, config).await { - Ok(gateway_uri) => crate::bootstrap::handle_ohttp_keys(req, gateway_uri).await, - Err(e) => Err(e), - }, + (&Method::GET, _) | (&Method::CONNECT, _) => { + match parse_gateway_uri(&method, path, authority, config).await { + Ok(gateway_uri) => crate::bootstrap::handle_ohttp_keys(req, gateway_uri).await, + Err(e) => Err(e), + } + } _ => Err(Error::NotFound), } .unwrap_or_else(|e| e.to_response()); @@ -187,14 +262,16 @@ async fn serve_ohttp_relay( } async fn parse_gateway_uri( - req: &Request, + method: &Method, + path: &str, + authority: Option, config: &RelayConfig, ) -> Result { // for POST and GET (websockets), the gateway URI is provided in the path // for CONNECT requests, just an authority is provided, and we assume HTTPS - let gateway_uri = match req.method() { - &Method::CONNECT => req.uri().authority().cloned().map(GatewayUri::from), - _ => parse_gateway_uri_from_path(req.uri().path(), &config.default_gateway).ok(), + let gateway_uri = match method { + &Method::CONNECT => authority.map(GatewayUri::from), + _ => parse_gateway_uri_from_path(path, &config.default_gateway).ok(), } .ok_or_else(|| Error::BadRequest("Invalid gateway".to_string()))?; @@ -247,12 +324,17 @@ fn handle_preflight() -> Response> { async fn health_check() -> Response> { Response::new(empty()) } #[instrument] -async fn handle_ohttp_relay( - req: Request, +async fn handle_ohttp_relay( + req: Request, config: &RelayConfig, gateway: GatewayUri, -) -> Result>, Error> { - let fwd_req = into_forward_req(req, gateway)?; +) -> Result>, Error> +where + B: hyper::body::Body + Send + Debug + 'static, + B::Error: Into, +{ + let fwd_req = into_forward_req(req, gateway, config.sentinel_tag.as_ref()).await?; + forward_request(fwd_req, config).await.map(|res| { let (parts, body) = res.into_parts(); let boxed_body = BoxBody::new(body); @@ -262,10 +344,15 @@ async fn handle_ohttp_relay( /// Convert an incoming request into a request to forward to the target gateway server. #[instrument] -fn into_forward_req( - req: Request, +async fn into_forward_req( + req: Request, gateway_origin: GatewayUri, -) -> Result>, Error> { + sentinel_tag: Option<&SentinelTag>, +) -> Result>, Error> +where + B: hyper::body::Body + Send + Debug + 'static, + B::Error: Into, +{ let (head, body) = req.into_parts(); if head.method != hyper::Method::POST { @@ -285,7 +372,14 @@ fn into_forward_req( builder = builder.header(CONTENT_LENGTH, content_length); } - builder.body(BoxBody::new(body)).map_err(|e| Error::InternalServerError(Box::new(e))) + let bytes = + body.collect().await.map_err(|e| Error::BadRequest(e.into().to_string()))?.to_bytes(); + + if let Some(tag) = sentinel_tag { + builder = builder.header(sentinel::HEADER_NAME, tag.to_header_value()); + } + + builder.body(full(bytes)).map_err(|e| Error::InternalServerError(Box::new(e))) } #[instrument] diff --git a/ohttp-relay/src/main.rs b/ohttp-relay/src/main.rs index ea7811e7e..d27854985 100644 --- a/ohttp-relay/src/main.rs +++ b/ohttp-relay/src/main.rs @@ -1,6 +1,6 @@ use std::str::FromStr; -use ohttp_relay::{GatewayUri, DEFAULT_PORT}; +use ohttp_relay::{GatewayUri, DEFAULT_GATEWAY, DEFAULT_PORT}; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{fmt, EnvFilter}; @@ -12,11 +12,20 @@ async fn main() -> Result<(), Box> { .expect("Failed to install default crypto provider"); init_tracing(); + + // If GATEWAY_URI is set, it must be payjo.in + if let Ok(gateway_uri) = std::env::var("GATEWAY_URI") { + if gateway_uri != DEFAULT_GATEWAY { + panic!( + "GATEWAY_URI is set to '{}' but only '{}' is supported. This environment variable is being deprecated in favor of gateway opt-in via RFC 9540.", + gateway_uri, DEFAULT_GATEWAY + ); + } + } + let port_env = std::env::var("PORT"); let unix_socket_env = std::env::var("UNIX_SOCKET"); - let gateway_origin_str = std::env::var("GATEWAY_ORIGIN").expect("GATEWAY_ORIGIN is required"); - let gateway_origin = - GatewayUri::from_str(&gateway_origin_str).expect("Invalid GATEWAY_ORIGIN URI"); + let gateway_origin = GatewayUri::from_str(DEFAULT_GATEWAY).expect("valid gateway uri"); match (port_env, unix_socket_env) { (Ok(_), Ok(_)) => panic!( diff --git a/ohttp-relay/src/sentinel.rs b/ohttp-relay/src/sentinel.rs new file mode 100644 index 000000000..1ba3107af --- /dev/null +++ b/ohttp-relay/src/sentinel.rs @@ -0,0 +1,75 @@ +use hex::{DisplayHex, FromHex}; + +/// HTTP header name for the sentinel tag. +pub const HEADER_NAME: &str = "x-pj-sentinel"; + +#[derive(Debug)] +pub struct InvalidHeader; + +impl std::fmt::Display for InvalidHeader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "malformed sentinel header") + } +} + +impl std::error::Error for InvalidHeader {} + +/// A random 32-byte tag shared between relay and gateway for same-instance detection. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct SentinelTag([u8; 32]); + +impl SentinelTag { + /// Creates a new sentinel tag from raw bytes. + pub fn new(bytes: [u8; 32]) -> Self { Self(bytes) } + + /// Returns the tag as a hex string for use in HTTP headers. + pub fn to_header_value(&self) -> String { self.0.to_lower_hex_string() } +} + +/// Verifies a sentinel header value against a tag. +/// +/// Note that incoming requests should be **rejected** when this function returns `Ok(true)`, +/// as that would indicate the relay and gateway are the same instance. +pub fn verify(tag: &SentinelTag, header_value: &str) -> Result { + let header_bytes = <[u8; 32]>::from_hex(header_value).map_err(|_| InvalidHeader)?; + Ok(tag.0 == header_bytes) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn same_tag_matches() { + let tag = SentinelTag::new([0u8; 32]); + let header = tag.to_header_value(); + assert!(verify(&tag, &header).unwrap(), "same tag should match"); + } + + #[test] + fn different_tag_does_not_match() { + let tag1 = SentinelTag::new([0u8; 32]); + let tag2 = SentinelTag::new([1u8; 32]); + let header = tag1.to_header_value(); + assert!(!verify(&tag2, &header).unwrap(), "different tag should not match"); + } + + #[test] + fn malformed_header_returns_err() { + let tag = SentinelTag::new([0u8; 32]); + assert!(verify(&tag, "invalid").is_err(), "non-hex should error"); + assert!(verify(&tag, "abcd").is_err(), "wrong length should error"); + assert!(verify(&tag, "zz").is_err(), "invalid hex chars should error"); + } + + #[test] + fn header_format() { + let tag = SentinelTag::new([0xab; 32]); + let header = tag.to_header_value(); + + // Should be 64 hex characters (32 bytes) + assert_eq!(header.len(), 64, "header should be 64 hex characters"); + assert!(header.chars().all(|c| c.is_ascii_hexdigit()), "header should be valid hex"); + assert_eq!(header, "ab".repeat(32), "header should match expected hex"); + } +} diff --git a/payjoin-directory/Cargo.toml b/payjoin-directory/Cargo.toml index 35c8704fc..9e916db86 100644 --- a/payjoin-directory/Cargo.toml +++ b/payjoin-directory/Cargo.toml @@ -27,8 +27,9 @@ config = "0.15.14" futures = "0.3.31" http-body-util = "0.1.3" hyper = { version = "1.6.0", features = ["http1", "server"] } -hyper-util = { version = "0.1.16", features = ["tokio"] } +hyper-util = { version = "0.1.16", features = ["tokio", "service"] } ohttp = { package = "bitcoin-ohttp", version = "0.6.0" } +ohttp-relay = { path = "../ohttp-relay" } payjoin = { version = "1.0.0-rc.1", features = [ "directory", ], default-features = false } @@ -41,6 +42,7 @@ tokio-rustls = { version = "0.26.2", features = [ ], default-features = false, optional = true } tokio-rustls-acme = { version = "0.7.1", optional = true } tokio-stream = { version = "0.1.17", features = ["net"] } +tower = "0.5" tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index d571e6746..ee2264056 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -1,6 +1,7 @@ use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; +use std::task::{Context, Poll}; use anyhow::Result; use futures::StreamExt; @@ -11,6 +12,7 @@ use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE}; use hyper::server::conn::http1; use hyper::{Method, Request, Response, StatusCode, Uri}; use hyper_util::rt::TokioIo; +use hyper_util::service::TowerToHyperService; use payjoin::directory::{ShortId, ShortIdError, ENCAPSULATED_MESSAGE_BYTES}; use tokio::net::TcpListener; #[cfg(feature = "acme")] @@ -22,6 +24,8 @@ use tracing::{debug, error, trace, warn}; pub use crate::db::files::Db as FilesDb; use crate::db::Db; pub mod key_config; +use ohttp_relay::SentinelTag; + pub use crate::key_config::*; use crate::metrics::Metrics; @@ -66,23 +70,37 @@ pub struct Service { db: D, ohttp: ohttp::Server, metrics: Metrics, + sentinel_tag: Option, } -impl hyper::service::Service> for Service { +impl tower::Service> for Service +where + B: Body + Send + 'static, + B::Error: Into, +{ type Response = Response>; type Error = anyhow::Error; type Future = Pin> + Send>>; - fn call(&self, req: Request) -> Self::Future { + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { let this = self.clone(); Box::pin(async move { this.serve_request(req).await }) } } impl Service { - pub fn new(db: D, ohttp: ohttp::Server, metrics: Metrics) -> Self { - Self { db, ohttp, metrics } + pub fn new( + db: D, + ohttp: ohttp::Server, + metrics: Metrics, + sentinel_tag: Option, + ) -> Self { + Self { db, ohttp, metrics, sentinel_tag } } #[cfg(feature = "_manual-tls")] @@ -106,8 +124,9 @@ impl Service { return; } }; + let hyper_service = TowerToHyperService::new(service); if let Err(err) = http1::Builder::new() - .serve_connection(TokioIo::new(tls_stream), service) + .serve_connection(TokioIo::new(tls_stream), hyper_service) .with_upgrades() .await { @@ -154,35 +173,49 @@ impl Service { } } - // TODO https://docs.rs/tower/0.4.13/tower/make/trait.MakeService.html async fn serve_connection(&self, stream: I) where I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static, { self.metrics.record_connection(); - if let Err(err) = - http1::Builder::new().serve_connection(TokioIo::new(stream), self).with_upgrades().await + let hyper_service = TowerToHyperService::new(self.clone()); + if let Err(err) = http1::Builder::new() + .serve_connection(TokioIo::new(stream), hyper_service) + .with_upgrades() + .await { error!("Error serving connection: {:?}", err); } } - async fn serve_request( + async fn serve_request( &self, - req: Request, - ) -> Result>> { + req: Request, + ) -> Result>> + where + B: Body + Send + 'static, + B::Error: Into, + { let path = req.uri().path().to_string(); let query = req.uri().query().unwrap_or_default().to_string(); let (parts, body) = req.into_parts(); let path_segments: Vec<&str> = path.split('/').collect(); debug!("Service::serve_request: {:?}", &path_segments); + + let sentinel_header = parts + .headers + .get(ohttp_relay::sentinel::HEADER_NAME) + .and_then(|v| v.to_str().ok()) + .map(String::from); + let mut response = match (parts.method, path_segments.as_slice()) { (Method::POST, ["", ".well-known", "ohttp-gateway"]) => - self.handle_ohttp_gateway(body).await, + self.handle_ohttp_gateway(body, sentinel_header.as_deref()).await, (Method::GET, ["", ".well-known", "ohttp-gateway"]) => self.handle_ohttp_gateway_get(&query).await, - (Method::POST, ["", ""]) => self.handle_ohttp_gateway(body).await, + (Method::POST, ["", ""]) => + self.handle_ohttp_gateway(body, sentinel_header.as_deref()).await, (Method::GET, ["", "ohttp-keys"]) => self.get_ohttp_keys().await, (Method::POST, ["", id]) => self.post_fallback_v1(id, query, body).await, (Method::GET, ["", "health"]) => health_check().await, @@ -197,13 +230,43 @@ impl Service { Ok(response) } - async fn handle_ohttp_gateway( + async fn handle_ohttp_gateway( &self, - body: Incoming, - ) -> Result>, HandlerError> { - // decapsulate - let ohttp_body = - body.collect().await.map_err(|e| HandlerError::BadRequest(e.into()))?.to_bytes(); + body: B, + sentinel_header: Option<&str>, + ) -> Result>, HandlerError> + where + B: Body + Send + 'static, + B::Error: Into, + { + let ohttp_body = body + .collect() + .await + .map_err(|e| HandlerError::BadRequest(anyhow::anyhow!(e.into())))? + .to_bytes(); + + // Best-effort validation that the relay and gateway aren't on the same + // payjoin-service instance + if let Some(tag) = &self.sentinel_tag { + if let Some(header_value) = sentinel_header { + match ohttp_relay::sentinel::verify(tag, header_value) { + Ok(false) => {} // Allow + Ok(true) => { + warn!("Rejected OHTTP request from same-instance relay"); + return Err(HandlerError::Forbidden(anyhow::anyhow!( + "Relay and gateway must be operated by different entities" + ))); + } + Err(_) => { + warn!("Rejected OHTTP request with malformed sentinel header"); + return Err(HandlerError::BadRequest(anyhow::anyhow!( + "Malformed sentinel header" + ))); + } + } + } + } + let (bhttp_req, res_ctx) = self .ohttp .decapsulate(&ohttp_body) @@ -327,12 +390,16 @@ impl Service { } } - async fn post_fallback_v1( + async fn post_fallback_v1( &self, id: &str, query: String, - body: impl Body, - ) -> Result>, HandlerError> { + body: B, + ) -> Result>, HandlerError> + where + B: Body + Send + 'static, + B::Error: Into, + { trace!("Post fallback v1"); let none_response = Response::builder() .status(StatusCode::SERVICE_UNAVAILABLE) @@ -549,6 +616,7 @@ enum HandlerError { SenderGone(anyhow::Error), OhttpKeyRejection(anyhow::Error), BadRequest(anyhow::Error), + Forbidden(anyhow::Error), } impl HandlerError { @@ -581,6 +649,10 @@ impl HandlerError { warn!("Bad request: {}", e); *res.status_mut() = StatusCode::BAD_REQUEST } + HandlerError::Forbidden(e) => { + warn!("Forbidden: {}", e); + *res.status_mut() = StatusCode::FORBIDDEN + } }; res diff --git a/payjoin-directory/src/main.rs b/payjoin-directory/src/main.rs index 6f00f6598..60f47095e 100644 --- a/payjoin-directory/src/main.rs +++ b/payjoin-directory/src/main.rs @@ -30,7 +30,7 @@ async fn main() -> Result<(), BoxError> { .await .expect("Failed to initialize persistent storage"); - let service = Service::new(db, ohttp.into(), metrics); + let service = Service::new(db, ohttp.into(), metrics, None); // Start metrics server in the background if let Some(addr) = config.metrics_listen_addr { diff --git a/payjoin-ffi/dart/native/Cargo.toml b/payjoin-ffi/dart/native/Cargo.toml index 206b21716..14f53fd27 100644 --- a/payjoin-ffi/dart/native/Cargo.toml +++ b/payjoin-ffi/dart/native/Cargo.toml @@ -24,6 +24,8 @@ payjoin-ffi = { git = "https://github.com/payjoin/rust-payjoin.git", branch = "m payjoin-ffi = { path = "../.." } [patch.crates-io] +ohttp-relay = { path = "../../../ohttp-relay" } payjoin = { path = "../../../payjoin" } payjoin-directory = { path = "../../../payjoin-directory" } +payjoin-service = { path = "../../../payjoin-service" } payjoin-test-utils = { path = "../../../payjoin-test-utils" } diff --git a/payjoin-ffi/javascript/test-utils/Cargo.toml b/payjoin-ffi/javascript/test-utils/Cargo.toml index eed7a70a3..bfbdd48fb 100644 --- a/payjoin-ffi/javascript/test-utils/Cargo.toml +++ b/payjoin-ffi/javascript/test-utils/Cargo.toml @@ -18,12 +18,18 @@ path = "../../../payjoin-test-utils" [build-dependencies] napi-build = "=2.2.4" +[patch.crates-io.ohttp-relay] +path = "../../../ohttp-relay" + [patch.crates-io.payjoin] path = "../../../payjoin" [patch.crates-io.payjoin-directory] path = "../../../payjoin-directory" +[patch.crates-io.payjoin-service] +path = "../../../payjoin-service" + [patch.crates-io.payjoin-test-utils] path = "../../../payjoin-test-utils" diff --git a/payjoin-service/.gitignore b/payjoin-service/.gitignore new file mode 100644 index 000000000..773a6df9b --- /dev/null +++ b/payjoin-service/.gitignore @@ -0,0 +1 @@ +*.dat diff --git a/payjoin-service/Cargo.toml b/payjoin-service/Cargo.toml new file mode 100644 index 000000000..0aa70bd96 --- /dev/null +++ b/payjoin-service/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "payjoin-service" +version = "0.0.1" +description = "Unified Payjoin Directory and OHTTP Relay service" +repository = "https://github.com/payjoin/rust-payjoin/tree/master/payjoin-service" +keywords = ["bip77", "bitcoin", "ohttp", "payjoin", "privacy"] +categories = [ + "cryptography::cryptocurrencies", + "network-programming", + "web-programming", +] +license = "MITNFA" +edition = "2021" +rust-version = "1.85.0" + +[features] +default = [] +_manual-tls = ["dep:axum-server", "dep:rustls", "ohttp-relay/_test-util"] + +[dependencies] +anyhow = "1.0" +axum = "0.8" +axum-server = { version = "0.7", features = [ + "tls-rustls-no-provider", +], optional = true } +clap = { version = "4.5", features = ["derive", "env"] } +config = "0.15" +ohttp-relay = { path = "../ohttp-relay", features = ["bootstrap"] } +payjoin-directory = { path = "../payjoin-directory" } +rand = "0.8" +rustls = { version = "0.23", default-features = false, features = [ + "ring", +], optional = true } +serde = { version = "1.0", features = ["derive"] } +tokio = { version = "1.47", features = ["full"] } +tower = "0.5" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/payjoin-service/README.md b/payjoin-service/README.md new file mode 100644 index 000000000..a30ce097a --- /dev/null +++ b/payjoin-service/README.md @@ -0,0 +1,5 @@ +# payjoin-service + +Unified Payjoin Directory and OHTTP Relay service. Combines [payjoin-directory](../payjoin-directory/README.md) and [ohttp-relay](../ohttp-relay/README.md) into a single binary. + +Note that this binary is under active development and thus the CLI and configuration file may be unstable. diff --git a/payjoin-service/src/cli.rs b/payjoin-service/src/cli.rs new file mode 100644 index 000000000..f7da25a4b --- /dev/null +++ b/payjoin-service/src/cli.rs @@ -0,0 +1,10 @@ +use std::path::PathBuf; + +use clap::Parser; + +#[derive(Debug, Parser)] +#[command(version)] +pub struct Args { + #[arg(short, long)] + pub config: Option, +} diff --git a/payjoin-service/src/config.rs b/payjoin-service/src/config.rs new file mode 100644 index 000000000..c01c065e4 --- /dev/null +++ b/payjoin-service/src/config.rs @@ -0,0 +1,41 @@ +use std::path::{Path, PathBuf}; +use std::time::Duration; + +use config::{ConfigError, File}; +use serde::Deserialize; + +#[derive(Debug, Clone, Deserialize)] +#[serde(default)] +pub struct Config { + pub port: u16, + pub storage_dir: PathBuf, + #[serde(deserialize_with = "deserialize_duration_secs")] + pub timeout: Duration, +} + +impl Default for Config { + fn default() -> Self { + Self { port: 8080, storage_dir: PathBuf::from("./data"), timeout: Duration::from_secs(30) } + } +} + +fn deserialize_duration_secs<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let secs = u64::deserialize(deserializer)?; + Ok(Duration::from_secs(secs)) +} + +impl Config { + pub fn from_file(path: &Path) -> Result { + config::Config::builder() + // Add from optional config file + .add_source(File::from(path).required(false)) + // Add from the environment (with a prefix of PJ) + // e.g. `PJ_PORT=9090` would set the `port`. + .add_source(config::Environment::with_prefix("PJ")) + .build()? + .try_deserialize() + } +} diff --git a/payjoin-service/src/lib.rs b/payjoin-service/src/lib.rs new file mode 100644 index 000000000..99e6fecaa --- /dev/null +++ b/payjoin-service/src/lib.rs @@ -0,0 +1,154 @@ +use std::net::{Ipv6Addr, SocketAddr}; + +use axum::extract::State; +use axum::http::Method; +use axum::response::{IntoResponse, Response}; +use axum::Router; +use config::Config; +use ohttp_relay::SentinelTag; +use rand::Rng; +use tower::Service; +use tracing::info; + +pub mod cli; +pub mod config; + +#[derive(Clone)] +struct Services { + directory: payjoin_directory::Service, + relay: ohttp_relay::Service, +} + +pub async fn serve(config: Config) -> anyhow::Result<()> { + let sentinel_tag = generate_sentinel_tag(); + + let services = Services { + directory: init_directory(&config, Some(sentinel_tag)).await?, + relay: ohttp_relay::Service::new(sentinel_tag).await, + }; + let app = Router::new().fallback(route_request).with_state(services); + + let addr = SocketAddr::from((Ipv6Addr::UNSPECIFIED, config.port)); + let listener = tokio::net::TcpListener::bind(addr).await?; + info!("Payjoin service listening on {}", addr); + axum::serve(listener, app).await?; + + Ok(()) +} + +/// Serves payjoin-service with manual TLS configuration. +/// +/// Binds to `config.port` (use 0 to let the OS assign a free port) and returns +/// the actual bound port along with a task handle. +/// +/// If `tls_config` is provided, the server will use TLS for incoming connections. +/// The `root_store` is used for outgoing relay connections to the gateway. +#[cfg(feature = "_manual-tls")] +pub async fn serve_manual_tls( + config: Config, + tls_config: Option, + root_store: rustls::RootCertStore, +) -> anyhow::Result<(u16, tokio::task::JoinHandle>)> { + let sentinel_tag = generate_sentinel_tag(); + + let services = Services { + directory: init_directory(&config, Some(sentinel_tag)).await?, + relay: ohttp_relay::Service::new_with_roots(root_store, sentinel_tag).await, + }; + let app = Router::new().fallback(route_request).with_state(services); + + let addr = SocketAddr::from((Ipv6Addr::UNSPECIFIED, config.port)); + let listener = tokio::net::TcpListener::bind(addr).await?; + let port = listener.local_addr()?.port(); + + let handle = match tls_config { + Some(tls) => { + info!("Payjoin service listening on port {} with TLS", port); + tokio::spawn(async move { + axum_server::from_tcp_rustls(listener.into_std()?, tls) + .serve(app.into_make_service()) + .await + .map_err(Into::into) + }) + } + None => { + info!("Payjoin service listening on port {} without TLS", port); + tokio::spawn(async move { axum::serve(listener, app).await.map_err(Into::into) }) + } + }; + + Ok((port, handle)) +} + +/// Generate random sentinel tag at startup. +/// The relay and directory share this tag in a best-effort attempt +/// at preventing collusion from the same instance. +fn generate_sentinel_tag() -> SentinelTag { SentinelTag::new(rand::thread_rng().gen()) } + +async fn init_directory( + config: &Config, + sentinel_tag: Option, +) -> anyhow::Result> { + let db = payjoin_directory::FilesDb::init(config.timeout, config.storage_dir.clone()).await?; + db.spawn_background_prune().await; + + let ohttp_keys_dir = config.storage_dir.join("ohttp-keys"); + let ohttp_config = init_ohttp_config(&ohttp_keys_dir)?; + let metrics = payjoin_directory::metrics::Metrics::new(); + + Ok(payjoin_directory::Service::new(db, ohttp_config.into(), metrics, sentinel_tag)) +} + +fn init_ohttp_config( + ohttp_keys_dir: &std::path::Path, +) -> anyhow::Result { + std::fs::create_dir_all(ohttp_keys_dir)?; + match payjoin_directory::read_server_config(ohttp_keys_dir) { + Ok(config) => Ok(config), + Err(_) => { + let config = payjoin_directory::gen_ohttp_server_config()?; + payjoin_directory::persist_new_key_config(config.clone(), ohttp_keys_dir)?; + Ok(config) + } + } +} + +async fn route_request( + State(mut services): State, + req: axum::extract::Request, +) -> Response { + if is_relay_request(&req) { + match services.relay.call(req).await { + Ok(res) => res.into_response(), + Err(e) => (axum::http::StatusCode::BAD_GATEWAY, e.to_string()).into_response(), + } + } else { + // The directory service handles all other requests (including 404) + match services.directory.call(req).await { + Ok(res) => res.into_response(), + Err(e) => + (axum::http::StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), + } + } +} + +/// Determines if a request should be routed to the OHTTP relay service. +/// +/// Routing rules: +/// - `(OPTIONS, _)` => CORS preflight handling +/// - `(CONNECT, _)` => OHTTP bootstrap tunneling +/// - `(POST, "/")` => relay to default gateway (needed for backwards-compatibility only) +/// - `(POST, /http(s)://...)` => RFC 9540 opt-in gateway specified in path +/// - `(GET, /http(s)://...)` => OHTTP bootstrap via WebSocket with opt-in gateway +fn is_relay_request(req: &axum::extract::Request) -> bool { + let method = req.method(); + let path = req.uri().path(); + + match (method, path) { + (&Method::OPTIONS, _) | (&Method::CONNECT, _) | (&Method::POST, "/") => true, + (&Method::POST, p) | (&Method::GET, p) + if p.starts_with("/http://") || p.starts_with("/https://") => + true, + _ => false, + } +} diff --git a/payjoin-service/src/main.rs b/payjoin-service/src/main.rs new file mode 100644 index 000000000..d04269210 --- /dev/null +++ b/payjoin-service/src/main.rs @@ -0,0 +1,22 @@ +use clap::Parser; +use payjoin_service::{cli, config}; +use tracing_subscriber::filter::LevelFilter; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + init_tracing(); + + let args = cli::Args::parse(); + let config_path = args.config.unwrap_or_else(|| "config.toml".into()); + let config = config::Config::from_file(&config_path)?; + + payjoin_service::serve(config).await +} + +fn init_tracing() { + let env_filter = + EnvFilter::builder().with_default_directive(LevelFilter::INFO.into()).from_env_lossy(); + + tracing_subscriber::fmt().with_target(true).with_level(true).with_env_filter(env_filter).init(); +} diff --git a/payjoin-test-utils/Cargo.toml b/payjoin-test-utils/Cargo.toml index 17f2f947b..6fdc879b7 100644 --- a/payjoin-test-utils/Cargo.toml +++ b/payjoin-test-utils/Cargo.toml @@ -9,18 +9,18 @@ rust-version = "1.85" license = "MIT" [dependencies] +axum-server = { version = "0.7", features = ["tls-rustls-no-provider"] } bitcoin = { version = "0.32.7", features = ["base64"] } corepc-node = { version = "0.10.0", features = ["download", "29_0"] } http = "1.3.1" ohttp = { package = "bitcoin-ohttp", version = "0.6.0" } -ohttp-relay = { version = "0.0.11", features = ["_test-util"] } once_cell = "1.21.3" payjoin = { version = "1.0.0-rc.1", features = [ "io", "_manual-tls", "_test-utils", ] } -payjoin-directory = { version = "0.0.3", features = ["_manual-tls"] } +payjoin-service = { path = "../payjoin-service", features = ["_manual-tls"] } rcgen = "0.14.3" reqwest = { version = "0.12.23", default-features = false, features = [ "rustls-tls", diff --git a/payjoin-test-utils/src/lib.rs b/payjoin-test-utils/src/lib.rs index c3064b9b9..2beba4b11 100644 --- a/payjoin-test-utils/src/lib.rs +++ b/payjoin-test-utils/src/lib.rs @@ -1,9 +1,9 @@ -use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::result::Result; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; +use axum_server::tls_rustls::RustlsConfig; use bitcoin::{Amount, Psbt}; pub use corepc_node; // re-export for convenience use corepc_node::AddressType; @@ -18,7 +18,6 @@ use reqwest::{Client, ClientBuilder}; use rustls::pki_types::CertificateDer; use rustls::RootCertStore; use tempfile::tempdir; -use tokio::net::TcpListener; use tokio::task::JoinHandle; use tracing::Level; use tracing_subscriber::{EnvFilter, FmtSubscriber}; @@ -61,11 +60,9 @@ impl TestServices { let mut root_store = RootCertStore::empty(); root_store.add(CertificateDer::from(cert.cert.der().to_vec())).unwrap(); - let directory = init_directory(cert_key).await?; + let directory = init_directory(cert_key, root_store.clone()).await?; + let ohttp_relay = init_ohttp_relay(root_store).await?; - let gateway_origin = - ohttp_relay::GatewayUri::from_str(&format!("https://localhost:{}", directory.0))?; - let ohttp_relay = ohttp_relay::listen_tcp_on_free_port(gateway_origin, root_store).await?; let http_agent: Arc = Arc::new(http_agent(cert_der)?); Ok(Self { @@ -114,33 +111,55 @@ impl TestServices { pub async fn init_directory( local_cert_key: (Vec, Vec), + root_store: RootCertStore, ) -> std::result::Result< (u16, tokio::task::JoinHandle>), BoxSendSyncError, > { - let timeout = Duration::from_secs(2); - let ohttp_server = payjoin_directory::gen_ohttp_server_config()?; - - let metrics = payjoin_directory::metrics::Metrics::new(); let tempdir = tempdir()?; - let db = payjoin_directory::FilesDb::init(timeout, tempdir.path().to_path_buf()).await?; + let config = payjoin_service::config::Config { + port: 0, // let OS assign a free port + storage_dir: tempdir.path().to_path_buf(), + timeout: Duration::from_secs(2), + }; - let service = payjoin_directory::Service::new(db, ohttp_server.into(), metrics); + let tls_config = RustlsConfig::from_der(vec![local_cert_key.0], local_cert_key.1).await?; - let listener = bind_free_port().await?; - let port = listener.local_addr()?.port(); + let (port, handle) = payjoin_service::serve_manual_tls(config, Some(tls_config), root_store) + .await + .map_err(|e| e.to_string())?; let handle = tokio::spawn(async move { let _tempdir = tempdir; // keep the tempdir until the directory shuts down - service.serve_tls(listener, local_cert_key).await + handle.await.map_err(|e| e.to_string())?.map_err(|e| e.to_string().into()) }); Ok((port, handle)) } -async fn bind_free_port() -> Result { - let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); - TcpListener::bind(bind_addr).await +async fn init_ohttp_relay( + root_store: RootCertStore, +) -> std::result::Result< + (u16, tokio::task::JoinHandle>), + BoxSendSyncError, +> { + let tempdir = tempdir()?; + let config = payjoin_service::config::Config { + port: 0, // let OS assign a free port + storage_dir: tempdir.path().to_path_buf(), + timeout: Duration::from_secs(2), + }; + + let (port, handle) = payjoin_service::serve_manual_tls(config, None, root_store) + .await + .map_err(|e| e.to_string())?; + + let handle = tokio::spawn(async move { + let _tempdir = tempdir; // keep the tempdir until the relay shuts down + handle.await.map_err(|e| e.to_string())?.map_err(|e| e.to_string().into()) + }); + + Ok((port, handle)) } /// generate or get a DER encoded localhost cert and key.