diff --git a/payjoin/src/core/hpke.rs b/payjoin/src/core/hpke.rs index 0deb2f5d1..4bc1b36e0 100644 --- a/payjoin/src/core/hpke.rs +++ b/payjoin/src/core/hpke.rs @@ -171,7 +171,7 @@ impl<'de> serde::Deserialize<'de> for HpkePublicKey { /// Message A is sent from the sender to the receiver containing an Original PSBT payload pub fn encrypt_message_a( - body: Vec, + body: &[u8; PADDED_PLAINTEXT_A_LENGTH], reply_pk: &HpkePublicKey, receiver_pk: &HpkePublicKey, ) -> Result, HpkeError> { @@ -182,8 +182,6 @@ pub fn encrypt_message_a( INFO_A, &mut OsRng, )?; - let mut body = body; - pad_plaintext(&mut body, PADDED_PLAINTEXT_A_LENGTH)?; let mut plaintext = compressed_bytes_from_pubkey(reply_pk).to_vec(); plaintext.extend(body); let ciphertext = encryption_context.seal(&plaintext, &[])?; @@ -223,7 +221,7 @@ pub fn decrypt_message_a( /// Message B is sent from the receiver to the sender containing a Payjoin PSBT payload or an error pub fn encrypt_message_b( - mut plaintext: Vec, + body: &[u8; PADDED_PLAINTEXT_B_LENGTH], receiver_keypair: &HpkeKeyPair, sender_pk: &HpkePublicKey, ) -> Result, HpkeError> { @@ -237,8 +235,7 @@ pub fn encrypt_message_b( INFO_B, &mut OsRng, )?; - let plaintext: &[u8] = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_B_LENGTH)?; - let ciphertext = encryption_context.seal(plaintext, &[])?; + let ciphertext = encryption_context.seal(body, &[])?; let mut message_b = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec(); message_b.extend(&ciphertext); Ok(message_b) @@ -261,14 +258,6 @@ pub fn decrypt_message_b( Ok(plaintext) } -fn pad_plaintext(msg: &mut Vec, padded_length: usize) -> Result<&[u8], HpkeError> { - if msg.len() > padded_length { - return Err(HpkeError::PayloadTooLarge { actual: msg.len(), max: padded_length }); - } - msg.resize(padded_length, 0); - Ok(msg) -} - /// Error from de/encrypting a v2 Hybrid Public Key Encryption payload. #[derive(Debug, PartialEq, Eq)] pub enum HpkeError { @@ -304,7 +293,7 @@ impl fmt::Display for HpkeError { PayloadTooLarge { actual, max } => { write!( f, - "Plaintext too large, max size is {max} bytes, actual size is {actual} bytes" + "Plaintext length incorrect, expected size is {max} bytes, actual size is {actual} bytes" ) } PayloadTooShort => write!(f, "Payload too small"), @@ -332,13 +321,13 @@ mod test { #[test] fn message_a_round_trip() { - let mut plaintext = "foo".as_bytes().to_vec(); + let mut plaintext = [0u8; PADDED_PLAINTEXT_A_LENGTH]; let reply_keypair = HpkeKeyPair::gen_keypair(); let receiver_keypair = HpkeKeyPair::gen_keypair(); let message_a = encrypt_message_a( - plaintext.clone(), + &plaintext, reply_keypair.public_key(), receiver_keypair.public_key(), ) @@ -350,14 +339,12 @@ mod test { assert_eq!(decrypted.0.len(), PADDED_PLAINTEXT_A_LENGTH); - // decrypted plaintext is padded, so pad the expected plaintext - plaintext.resize(PADDED_PLAINTEXT_A_LENGTH, 0); assert_eq!(decrypted, (plaintext.to_vec(), reply_keypair.public_key().clone())); // ensure full plaintext round trips plaintext[PADDED_PLAINTEXT_A_LENGTH - 1] = 42; let message_a = encrypt_message_a( - plaintext.clone(), + &plaintext, reply_keypair.public_key(), receiver_keypair.public_key(), ) @@ -387,30 +374,17 @@ mod test { decrypt_message_a(&corrupted_message_a, receiver_keypair.secret_key().clone()), Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) ); - - plaintext.resize(PADDED_PLAINTEXT_A_LENGTH + 1, 0); - assert_eq!( - encrypt_message_a( - plaintext.clone(), - reply_keypair.public_key(), - receiver_keypair.public_key(), - ), - Err(HpkeError::PayloadTooLarge { - actual: PADDED_PLAINTEXT_A_LENGTH + 1, - max: PADDED_PLAINTEXT_A_LENGTH, - }) - ); } #[test] fn message_b_round_trip() { - let mut plaintext = "foo".as_bytes().to_vec(); + let mut plaintext = [0u8; PADDED_PLAINTEXT_B_LENGTH]; let reply_keypair = HpkeKeyPair::gen_keypair(); let receiver_keypair = HpkeKeyPair::gen_keypair(); let message_b = - encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()) + encrypt_message_b(&plaintext, &receiver_keypair, reply_keypair.public_key()) .expect("encryption should work"); assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES); @@ -423,13 +397,11 @@ mod test { .expect("decryption should work"); assert_eq!(decrypted.len(), PADDED_PLAINTEXT_B_LENGTH); - // decrypted plaintext is padded, so pad the expected plaintext - plaintext.resize(PADDED_PLAINTEXT_B_LENGTH, 0); assert_eq!(decrypted, plaintext.to_vec()); plaintext[PADDED_PLAINTEXT_B_LENGTH - 1] = 42; let message_b = - encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()) + encrypt_message_b(&plaintext, &receiver_keypair, reply_keypair.public_key()) .expect("encryption should work"); assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES); @@ -481,15 +453,6 @@ mod test { ), Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) ); - - plaintext.resize(PADDED_PLAINTEXT_B_LENGTH + 1, 0); - assert_eq!( - encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()), - Err(HpkeError::PayloadTooLarge { - actual: PADDED_PLAINTEXT_B_LENGTH + 1, - max: PADDED_PLAINTEXT_B_LENGTH - }) - ); } /// Test that the encrypted payloads are uniform. @@ -508,17 +471,17 @@ mod test { let receiver_keypair = HpkeKeyPair::gen_keypair(); let reply_keypair = HpkeKeyPair::gen_keypair(); - let plaintext_a = vec![0u8; PADDED_PLAINTEXT_A_LENGTH]; + let plaintext_a = [0u8; PADDED_PLAINTEXT_A_LENGTH]; let message_a = encrypt_message_a( - plaintext_a, + &plaintext_a, reply_keypair.public_key(), receiver_keypair.public_key(), ) .expect("encryption should work"); - let plaintext_b = vec![0u8; PADDED_PLAINTEXT_B_LENGTH]; + let plaintext_b = [0u8; PADDED_PLAINTEXT_B_LENGTH]; let message_b = - encrypt_message_b(plaintext_b, &receiver_keypair, sender_keypair.public_key()) + encrypt_message_b(&plaintext_b, &receiver_keypair, sender_keypair.public_key()) .expect("encryption should work"); messages_a.push(message_a); diff --git a/payjoin/src/core/receive/error.rs b/payjoin/src/core/receive/error.rs index a7fcdb5a8..42b3c2c34 100644 --- a/payjoin/src/core/receive/error.rs +++ b/payjoin/src/core/receive/error.rs @@ -3,6 +3,7 @@ use std::{error, fmt}; use crate::error_codes::ErrorCode::{ self, NotEnoughMoney, OriginalPsbtRejected, Unavailable, VersionUnsupported, }; +// use crate::hpke::HpkeError::PayloadTooLarge; /// The top-level error type for the payjoin receiver #[derive(Debug)] @@ -14,6 +15,10 @@ pub enum Error { /// /// e.g. database errors, network failures, wallet errors Implementation(crate::ImplementationError), + PayloadTooLarge { + actual: usize, + max: usize, + }, } impl From<&Error> for JsonReply { @@ -21,6 +26,7 @@ impl From<&Error> for JsonReply { match e { Error::Protocol(e) => e.into(), Error::Implementation(_) => JsonReply::new(Unavailable, "Receiver error"), + Error::PayloadTooLarge { actual: _, max: _ } => todo!("unimplemented"), } } } @@ -34,6 +40,12 @@ impl fmt::Display for Error { match self { Error::Protocol(e) => write!(f, "Protocol error: {e}"), Error::Implementation(e) => write!(f, "Implementation error: {e}"), + Error::PayloadTooLarge { actual, max } => { + write!( + f, + "Plaintext length incorrect, expected size is {max} bytes, actual size is {actual} bytes" + ) + } } } } @@ -43,6 +55,7 @@ impl error::Error for Error { match self { Error::Protocol(e) => e.source(), Error::Implementation(e) => e.source(), + Error::PayloadTooLarge { .. } => None, } } } diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index c0f43e000..09e6aa685 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -24,6 +24,7 @@ //! Note: Even fresh requests may be linkable via metadata (e.g. client IP, request timing), //! but request reuse makes correlation trivial for the relay. +use std::io::Write; use std::str::FromStr; use std::time::{Duration, SystemTime}; @@ -42,7 +43,9 @@ use super::error::{Error, InputContributionError}; use super::{ common, InternalPayloadError, JsonReply, OutputSubstitutionError, ProtocolError, SelectionError, }; -use crate::hpke::{decrypt_message_a, encrypt_message_b, HpkeKeyPair, HpkePublicKey}; +use crate::hpke::{ + decrypt_message_a, encrypt_message_b, HpkeKeyPair, HpkePublicKey, PADDED_PLAINTEXT_B_LENGTH, +}; use crate::ohttp::{ ohttp_encapsulate, process_get_res, process_post_res, OhttpEncapsulationError, OhttpKeys, }; @@ -1035,7 +1038,17 @@ impl Receiver { let payjoin_bytes = self.psbt.serialize(); let sender_mailbox = short_id_from_pubkey(e); target_resource = mailbox_endpoint(&self.session_context.directory, &sender_mailbox); - body = encrypt_message_b(payjoin_bytes, &self.session_context.receiver_key, e)?; + + let mut buf = [0u8; PADDED_PLAINTEXT_B_LENGTH]; + + (&mut &mut buf[..]).write_all(&payjoin_bytes).map_err(|e| { + assert!(e.kind() == std::io::ErrorKind::WriteZero); + Error::PayloadTooLarge { + actual: payjoin_bytes.len(), + max: PADDED_PLAINTEXT_B_LENGTH, + } + })?; + body = encrypt_message_b(&buf, &self.session_context.receiver_key, e)?; method = "POST"; } else { // Prepare v2 wrapped and backwards-compatible v1 payload diff --git a/payjoin/src/core/send/v2/mod.rs b/payjoin/src/core/send/v2/mod.rs index 7cd204686..373d7f21c 100644 --- a/payjoin/src/core/send/v2/mod.rs +++ b/payjoin/src/core/send/v2/mod.rs @@ -28,6 +28,8 @@ //! Note: Even fresh requests may be linkable via metadata (e.g. client IP, request timing), //! but request reuse makes correlation trivial for the relay. +use std::io::Write; + use bitcoin::hashes::{sha256, Hash}; use bitcoin::Address; pub use error::{CreateRequestError, EncapsulationError}; @@ -39,7 +41,7 @@ use url::Url; use super::error::BuildSenderError; use super::*; -use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeSecretKey}; +use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeSecretKey, PADDED_PLAINTEXT_A_LENGTH}; use crate::ohttp::{ohttp_encapsulate, process_get_res, process_post_res}; use crate::persist::{ MaybeFatalTransition, MaybeSuccessTransitionWithNoResults, NextStateTransition, @@ -286,7 +288,7 @@ impl Sender { let (request, ohttp_ctx) = extract_request( ohttp_relay, self.reply_key.clone(), - body, + &body, base_url, self.pj_param.receiver_pubkey().clone(), &mut ohttp_keys, @@ -349,7 +351,7 @@ impl Sender { pub(crate) fn extract_request( ohttp_relay: impl IntoUrl, reply_key: HpkeSecretKey, - body: Vec, + body: &[u8; PADDED_PLAINTEXT_A_LENGTH], url: Url, receiver_pubkey: HpkePublicKey, ohttp_keys: &mut OhttpKeys, @@ -378,7 +380,7 @@ pub(crate) fn serialize_v2_body( output_substitution: OutputSubstitution, fee_contribution: Option, min_fee_rate: FeeRate, -) -> Result, CreateRequestError> { +) -> Result<[u8; PADDED_PLAINTEXT_A_LENGTH], CreateRequestError> { // Grug say localhost base be discarded anyway. no big brain needed. let base_url = Url::parse("http://localhost").expect("invalid URL"); @@ -386,7 +388,23 @@ pub(crate) fn serialize_v2_body( serialize_url(base_url, output_substitution, fee_contribution, min_fee_rate, Version::Two); let query_params = placeholder_url.query().unwrap_or_default(); let base64 = psbt.to_string(); - Ok(format!("{base64}\n{query_params}").into_bytes()) + + let mut buf = [0u8; PADDED_PLAINTEXT_A_LENGTH]; + // The body length needs an additional byte added to account for the newline between the + // base64 body and query params + let body_len = base64.len() + 1 + query_params.len(); + + write!(&mut &mut buf[..], "{base64}\n{query_params}").map_err(|e| { + assert!(e.kind() == std::io::ErrorKind::WriteZero); + CreateRequestError::from(InternalCreateRequestError::Hpke( + crate::hpke::HpkeError::PayloadTooLarge { + actual: body_len, + max: PADDED_PLAINTEXT_A_LENGTH, + }, + )) + })?; + + Ok(buf) } /// Data required to validate the POST response. @@ -430,7 +448,7 @@ impl Sender { .join(&mailbox.to_string()) .map_err(|e| InternalCreateRequestError::Url(e.into()))?; let body = encrypt_message_a( - Vec::new(), + &[0u8; PADDED_PLAINTEXT_A_LENGTH], &HpkeKeyPair::from_secret_key(&self.reply_key).public_key().clone(), &self.pj_param.receiver_pubkey().clone(), ) @@ -513,6 +531,7 @@ mod test { use std::time::{Duration, SystemTime}; use bitcoin::hex::FromHex; + use bitcoin::psbt::raw; use bitcoin::Address; use payjoin_test_utils::{BoxError, EXAMPLE_URL, KEM, KEY_ID, PARSED_ORIGINAL_PSBT, SYMMETRIC}; @@ -559,8 +578,64 @@ mod test { sender.psbt_ctx.output_substitution, sender.psbt_ctx.fee_contribution, sender.psbt_ctx.min_fee_rate, + )?; + + let expected_bytes = as FromHex>::from_hex(SERIALIZED_BODY_V2)?; + let expected_len = expected_bytes.len(); + assert_eq!(&body[..expected_len], &expected_bytes[..]); + Ok(()) + } + + #[test] + fn test_serialize_body_length() -> Result<(), BoxError> { + let sender = create_sender_context(SystemTime::now() + Duration::from_secs(60))?; + let mut padded_psbt = sender.psbt_ctx.original_psbt.clone(); + let base_url = Url::parse("http://localhost").expect("invalid URL"); + + let placeholder_url = serialize_url( + base_url, + sender.psbt_ctx.output_substitution, + sender.psbt_ctx.fee_contribution, + sender.psbt_ctx.min_fee_rate, + Version::Two, + ); + let query_params = placeholder_url.query().unwrap_or_default(); + let dummy_key = raw::ProprietaryKey { prefix: vec![0], subtype: 0, key: vec![0] }; + + let dummy_value = vec![0u8; 4952]; + padded_psbt.proprietary.insert(dummy_key.clone(), dummy_value.clone()); + let body = serialize_v2_body( + &padded_psbt, + sender.psbt_ctx.output_substitution, + sender.psbt_ctx.fee_contribution, + sender.psbt_ctx.min_fee_rate, + ); + assert!(body.is_ok()); + + let dummy_value = vec![0u8; PADDED_PLAINTEXT_A_LENGTH]; + padded_psbt.proprietary.insert(dummy_key, dummy_value); + let body = serialize_v2_body( + &padded_psbt, + sender.psbt_ctx.output_substitution, + sender.psbt_ctx.fee_contribution, + sender.psbt_ctx.min_fee_rate, ); - assert_eq!(body.as_ref().unwrap(), & as FromHex>::from_hex(SERIALIZED_BODY_V2)?,); + let padded_psbt_length = format!("{}\n{}", padded_psbt, query_params).len(); + match body { + Err(e) => { + assert_eq!( + e.to_string(), + CreateRequestError::from(InternalCreateRequestError::Hpke( + crate::hpke::HpkeError::PayloadTooLarge { + actual: padded_psbt_length, + max: PADDED_PLAINTEXT_A_LENGTH + }, + )) + .to_string() + ) + } + Ok(_) => panic!("Expected error, got Ok"), + } Ok(()) }