diff --git a/payjoin/src/receive/error.rs b/payjoin/src/receive/error.rs index 5277f2d3c..fe30450e1 100644 --- a/payjoin/src/receive/error.rs +++ b/payjoin/src/receive/error.rs @@ -36,8 +36,13 @@ impl From for Error { } #[cfg(feature = "v2")] -impl From for Error { - fn from(e: crate::v2::Error) -> Self { Error::Server(Box::new(e)) } +impl From for Error { + fn from(e: crate::v2::HpkeError) -> Self { Error::Server(Box::new(e)) } +} + +#[cfg(feature = "v2")] +impl From for Error { + fn from(e: crate::v2::OhttpEncapsulationError) -> Self { Error::Server(Box::new(e)) } } /// Error that may occur when the request from sender is malformed. diff --git a/payjoin/src/send/error.rs b/payjoin/src/send/error.rs index 9fc58b0a7..c281c7fd9 100644 --- a/payjoin/src/send/error.rs +++ b/payjoin/src/send/error.rs @@ -57,7 +57,9 @@ pub(crate) enum InternalValidationError { FeeContributionPaysOutputSizeIncrease, FeeRateBelowMinimum, #[cfg(feature = "v2")] - V2(crate::v2::Error), + HpkeError(crate::v2::HpkeError), + #[cfg(feature = "v2")] + OhttpEncapsulation(crate::v2::OhttpEncapsulationError), #[cfg(feature = "v2")] Psbt(bitcoin::psbt::Error), } @@ -103,7 +105,9 @@ impl fmt::Display for ValidationError { FeeContributionPaysOutputSizeIncrease => write!(f, "fee contribution pays for additional outputs"), FeeRateBelowMinimum => write!(f, "the fee rate of proposed transaction is below minimum"), #[cfg(feature = "v2")] - V2(e) => write!(f, "v2 error: {}", e), + HpkeError(e) => write!(f, "v2 error: {}", e), + #[cfg(feature = "v2")] + OhttpEncapsulation(e) => write!(f, "Ohttp encapsulation error: {}", e), #[cfg(feature = "v2")] Psbt(e) => write!(f, "psbt error: {}", e), } @@ -144,7 +148,9 @@ impl std::error::Error for ValidationError { FeeContributionPaysOutputSizeIncrease => None, FeeRateBelowMinimum => None, #[cfg(feature = "v2")] - V2(error) => Some(error), + HpkeError(error) => Some(error), + #[cfg(feature = "v2")] + OhttpEncapsulation(error) => Some(error), #[cfg(feature = "v2")] Psbt(error) => Some(error), } @@ -177,7 +183,9 @@ pub(crate) enum InternalCreateRequestError { PrevTxOut(crate::psbt::PrevTxOutError), InputType(crate::input_type::InputTypeError), #[cfg(feature = "v2")] - V2(crate::v2::Error), + Hpke(crate::v2::HpkeError), + #[cfg(feature = "v2")] + OhttpEncapsulation(crate::v2::OhttpEncapsulationError), #[cfg(feature = "v2")] SubdirectoryNotBase64(bitcoin::base64::DecodeError), #[cfg(feature = "v2")] @@ -207,7 +215,9 @@ impl fmt::Display for CreateRequestError { PrevTxOut(e) => write!(f, "invalid previous transaction output: {}", e), InputType(e) => write!(f, "invalid input type: {}", e), #[cfg(feature = "v2")] - V2(e) => write!(f, "v2 error: {}", e), + Hpke(e) => write!(f, "v2 error: {}", e), + #[cfg(feature = "v2")] + OhttpEncapsulation(e) => write!(f, "v2 error: {}", e), #[cfg(feature = "v2")] SubdirectoryNotBase64(e) => write!(f, "subdirectory is not valid base64 error: {}", e), #[cfg(feature = "v2")] @@ -239,7 +249,9 @@ impl std::error::Error for CreateRequestError { PrevTxOut(error) => Some(error), InputType(error) => Some(error), #[cfg(feature = "v2")] - V2(error) => Some(error), + Hpke(error) => Some(error), + #[cfg(feature = "v2")] + OhttpEncapsulation(error) => Some(error), #[cfg(feature = "v2")] SubdirectoryNotBase64(error) => Some(error), #[cfg(feature = "v2")] diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index c8adf3664..ac3280736 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -341,14 +341,14 @@ impl RequestContext { self.min_fee_rate, )?; let body = crate::v2::encrypt_message_a(body, self.e, rs) - .map_err(InternalCreateRequestError::V2)?; + .map_err(InternalCreateRequestError::Hpke)?; let (body, ohttp_res) = crate::v2::ohttp_encapsulate( self.ohttp_keys.as_mut().ok_or(InternalCreateRequestError::MissingOhttpConfig)?, "POST", url.as_str(), Some(&body), ) - .map_err(InternalCreateRequestError::V2)?; + .map_err(InternalCreateRequestError::OhttpEncapsulation)?; log::debug!("ohttp_relay_url: {:?}", ohttp_relay); Ok(( Request { url: ohttp_relay, body }, @@ -586,9 +586,9 @@ impl ContextV2 { let mut res_buf = Vec::new(); response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; let mut res_buf = crate::v2::ohttp_decapsulate(self.ohttp_res, &res_buf) - .map_err(InternalValidationError::V2)?; + .map_err(InternalValidationError::OhttpEncapsulation)?; let psbt = crate::v2::decrypt_message_b(&mut res_buf, self.e) - .map_err(InternalValidationError::V2)?; + .map_err(InternalValidationError::HpkeError)?; if psbt.is_empty() { return Ok(None); } diff --git a/payjoin/src/uri.rs b/payjoin/src/uri.rs index 72be2c7f6..f9c19da7c 100644 --- a/payjoin/src/uri.rs +++ b/payjoin/src/uri.rs @@ -270,8 +270,8 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState { let config_bytes = bitcoin::base64::decode_config(&*base64_config, bitcoin::base64::URL_SAFE) .map_err(InternalPjParseError::NotBase64)?; - let config = - OhttpKeys::decode(&config_bytes).map_err(InternalPjParseError::BadOhttp)?; + let config = OhttpKeys::decode(&config_bytes) + .map_err(InternalPjParseError::DecodeOhttpKeys)?; self.ohttp = Some(config); Ok(bip21::de::ParamKind::Known) } @@ -336,7 +336,7 @@ impl std::fmt::Display for PjParseError { InternalPjParseError::NotBase64(_) => write!(f, "ohttp config is not valid base64"), InternalPjParseError::BadEndpoint(_) => write!(f, "Endpoint is not valid"), #[cfg(feature = "v2")] - InternalPjParseError::BadOhttp(_) => write!(f, "ohttp config is not valid"), + InternalPjParseError::DecodeOhttpKeys(_) => write!(f, "ohttp config is not valid"), InternalPjParseError::UnsecureEndpoint => { write!(f, "Endpoint scheme is not secure (https or onion)") } @@ -354,7 +354,7 @@ enum InternalPjParseError { NotBase64(bitcoin::base64::DecodeError), BadEndpoint(url::ParseError), #[cfg(feature = "v2")] - BadOhttp(crate::v2::Error), + DecodeOhttpKeys(ohttp::Error), UnsecureEndpoint, } diff --git a/payjoin/src/v2.rs b/payjoin/src/v2.rs index 3abd419d7..692486a62 100644 --- a/payjoin/src/v2.rs +++ b/payjoin/src/v2.rs @@ -37,12 +37,12 @@ pub fn encrypt_message_a( mut raw_msg: Vec, e_sec: SecretKey, s: PublicKey, -) -> Result, Error> { +) -> Result, HpkeError> { let secp = Secp256k1::new(); let e_pub = e_sec.public_key(&secp); let es = SharedSecret::new(&s, &e_sec); let cipher = ChaCha20Poly1305::new_from_slice(&es.secret_bytes()) - .map_err(|_| InternalError::InvalidKeyLength)?; + .map_err(|_| HpkeError::InvalidKeyLength)?; let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng); // key es encrypts only 1 message so 0 is unique let aad = &e_pub.serialize(); let msg = pad(&mut raw_msg)?; @@ -55,14 +55,17 @@ pub fn encrypt_message_a( } #[cfg(feature = "receive")] -pub fn decrypt_message_a(message_a: &[u8], s: SecretKey) -> Result<(Vec, PublicKey), Error> { +pub fn decrypt_message_a( + message_a: &[u8], + s: SecretKey, +) -> Result<(Vec, PublicKey), HpkeError> { // let message a = [pubkey/AD][nonce][authentication tag][ciphertext] - let e = PublicKey::from_slice(&message_a[..33])?; - let nonce = Nonce::from_slice(&message_a[33..45]); + let e = PublicKey::from_slice(message_a.get(..33).ok_or(HpkeError::PayloadTooShort)?)?; + let nonce = Nonce::from_slice(message_a.get(33..45).ok_or(HpkeError::PayloadTooShort)?); let es = SharedSecret::new(&e, &s); let cipher = ChaCha20Poly1305::new_from_slice(&es.secret_bytes()) - .map_err(|_| InternalError::InvalidKeyLength)?; - let c_t = &message_a[45..]; + .map_err(|_| HpkeError::InvalidKeyLength)?; + let c_t = message_a.get(45..).ok_or(HpkeError::PayloadTooShort)?; let aad = &e.serialize(); let payload = Payload { msg: c_t, aad }; let buffer = cipher.decrypt(nonce, payload)?; @@ -70,13 +73,13 @@ pub fn decrypt_message_a(message_a: &[u8], s: SecretKey) -> Result<(Vec, Pub } #[cfg(feature = "receive")] -pub fn encrypt_message_b(raw_msg: &mut Vec, re_pub: PublicKey) -> Result, Error> { +pub fn encrypt_message_b(raw_msg: &mut Vec, re_pub: PublicKey) -> Result, HpkeError> { // let message b = [pubkey/AD][nonce][authentication tag][ciphertext] let secp = Secp256k1::new(); let (e_sec, e_pub) = secp.generate_keypair(&mut OsRng); let ee = SharedSecret::new(&re_pub, &e_sec); let cipher = ChaCha20Poly1305::new_from_slice(&ee.secret_bytes()) - .map_err(|_| InternalError::InvalidKeyLength)?; + .map_err(|_| HpkeError::InvalidKeyLength)?; let nonce = Nonce::from_slice(&[0u8; 12]); // key es encrypts only 1 message so 0 is unique let aad = &e_pub.serialize(); let msg = pad(raw_msg)?; @@ -89,21 +92,24 @@ pub fn encrypt_message_b(raw_msg: &mut Vec, re_pub: PublicKey) -> Result Result, Error> { +pub fn decrypt_message_b(message_b: &mut [u8], e: SecretKey) -> Result, HpkeError> { // let message b = [pubkey/AD][nonce][authentication tag][ciphertext] - let re = PublicKey::from_slice(&message_b[..33])?; - let nonce = Nonce::from_slice(&message_b[33..45]); + let re = PublicKey::from_slice(message_b.get(..33).ok_or(HpkeError::PayloadTooShort)?)?; + let nonce = Nonce::from_slice(message_b.get(33..45).ok_or(HpkeError::PayloadTooShort)?); let ee = SharedSecret::new(&re, &e); let cipher = ChaCha20Poly1305::new_from_slice(&ee.secret_bytes()) - .map_err(|_| InternalError::InvalidKeyLength)?; - let payload = Payload { msg: &message_b[45..], aad: &re.serialize() }; + .map_err(|_| HpkeError::InvalidKeyLength)?; + let payload = Payload { + msg: message_b.get(45..).ok_or(HpkeError::PayloadTooShort)?, + aad: &re.serialize(), + }; let buffer = cipher.decrypt(nonce, payload)?; Ok(buffer) } -fn pad(msg: &mut Vec) -> Result<&[u8], Error> { +fn pad(msg: &mut Vec) -> Result<&[u8], HpkeError> { if msg.len() > PADDED_MESSAGE_BYTES { - return Err(Error(InternalError::PayloadTooLarge)); + return Err(HpkeError::PayloadTooLarge); } while msg.len() < PADDED_MESSAGE_BYTES { msg.push(0); @@ -111,87 +117,56 @@ fn pad(msg: &mut Vec) -> Result<&[u8], Error> { Ok(msg) } -/// Error that may occur when de/encrypting or de/capsulating a v2 message. -/// -/// This is currently opaque type because we aren't sure which variants will stay. -/// You can only display it. -#[derive(Debug)] -pub struct Error(InternalError); - +/// Error from de/encrypting a v2 Hybrid Public Key Encryption payload. #[derive(Debug)] -pub(crate) enum InternalError { - Ohttp(ohttp::Error), - Bhttp(bhttp::Error), - ParseUrl(url::ParseError), +pub enum HpkeError { Secp256k1(bitcoin::secp256k1::Error), ChaCha20Poly1305(chacha20poly1305::aead::Error), InvalidKeyLength, PayloadTooLarge, + PayloadTooShort, } -impl From for Error { - fn from(value: ohttp::Error) -> Self { Self(InternalError::Ohttp(value)) } +impl From for HpkeError { + fn from(value: bitcoin::secp256k1::Error) -> Self { Self::Secp256k1(value) } } -impl From for Error { - fn from(value: bhttp::Error) -> Self { Self(InternalError::Bhttp(value)) } -} - -impl From for Error { - fn from(value: url::ParseError) -> Self { Self(InternalError::ParseUrl(value)) } -} - -impl From for Error { - fn from(value: bitcoin::secp256k1::Error) -> Self { Self(InternalError::Secp256k1(value)) } -} - -impl From for Error { - fn from(value: chacha20poly1305::aead::Error) -> Self { - Self(InternalError::ChaCha20Poly1305(value)) - } +impl From for HpkeError { + fn from(value: chacha20poly1305::aead::Error) -> Self { Self::ChaCha20Poly1305(value) } } -impl fmt::Display for Error { +impl fmt::Display for HpkeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - use InternalError::*; + use HpkeError::*; - match &self.0 { - Ohttp(e) => e.fmt(f), - Bhttp(e) => e.fmt(f), - ParseUrl(e) => e.fmt(f), + match &self { Secp256k1(e) => e.fmt(f), ChaCha20Poly1305(e) => e.fmt(f), InvalidKeyLength => write!(f, "Invalid Length"), PayloadTooLarge => write!(f, "Payload too large, max size is {} bytes", PADDED_MESSAGE_BYTES), + PayloadTooShort => write!(f, "Payload too small"), } } } -impl error::Error for Error { +impl error::Error for HpkeError { fn source(&self) -> Option<&(dyn error::Error + 'static)> { - use InternalError::*; + use HpkeError::*; - match &self.0 { - Ohttp(e) => Some(e), - Bhttp(e) => Some(e), - ParseUrl(e) => Some(e), + match &self { Secp256k1(e) => Some(e), - ChaCha20Poly1305(_) | InvalidKeyLength | PayloadTooLarge => None, + ChaCha20Poly1305(_) | InvalidKeyLength | PayloadTooLarge | PayloadTooShort => None, } } } -impl From for Error { - fn from(value: InternalError) -> Self { Self(value) } -} - pub fn ohttp_encapsulate( ohttp_keys: &mut ohttp::KeyConfig, method: &str, target_resource: &str, body: Option<&[u8]>, -) -> Result<(Vec, ohttp::ClientResponse), Error> { +) -> Result<(Vec, ohttp::ClientResponse), OhttpEncapsulationError> { let ctx = ohttp::ClientRequest::from_config(ohttp_keys)?; let url = url::Url::parse(target_resource)?; let authority_bytes = url.host().map_or_else(Vec::new, |host| { @@ -220,19 +195,64 @@ pub fn ohttp_encapsulate( pub fn ohttp_decapsulate( res_ctx: ohttp::ClientResponse, ohttp_body: &[u8], -) -> Result, Error> { +) -> Result, OhttpEncapsulationError> { let bhttp_body = res_ctx.decapsulate(ohttp_body)?; let mut r = std::io::Cursor::new(bhttp_body); let response = bhttp::Message::read_bhttp(&mut r)?; Ok(response.content().to_vec()) } +/// Error from de/encapsulating an Oblivious HTTP request or response. +#[derive(Debug)] +pub enum OhttpEncapsulationError { + Ohttp(ohttp::Error), + Bhttp(bhttp::Error), + ParseUrl(url::ParseError), +} + +impl From for OhttpEncapsulationError { + fn from(value: ohttp::Error) -> Self { Self::Ohttp(value) } +} + +impl From for OhttpEncapsulationError { + fn from(value: bhttp::Error) -> Self { Self::Bhttp(value) } +} + +impl From for OhttpEncapsulationError { + fn from(value: url::ParseError) -> Self { Self::ParseUrl(value) } +} + +impl fmt::Display for OhttpEncapsulationError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use OhttpEncapsulationError::*; + + match &self { + Ohttp(e) => e.fmt(f), + Bhttp(e) => e.fmt(f), + ParseUrl(e) => e.fmt(f), + } + } +} + +impl error::Error for OhttpEncapsulationError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + use OhttpEncapsulationError::*; + + match &self { + Ohttp(e) => Some(e), + Bhttp(e) => Some(e), + ParseUrl(e) => Some(e), + } + } +} + #[derive(Debug, Clone)] pub struct OhttpKeys(pub ohttp::KeyConfig); impl OhttpKeys { - pub fn decode(bytes: &[u8]) -> Result { - ohttp::KeyConfig::decode(bytes).map(Self).map_err(Error::from) + /// Decode an OHTTP KeyConfig + pub fn decode(bytes: &[u8]) -> Result { + ohttp::KeyConfig::decode(bytes).map(Self) } }