diff --git a/Cargo.lock b/Cargo.lock index 8899937..43572d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1834,6 +1834,12 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "dunce" version = "1.0.5" @@ -2067,6 +2073,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fragile" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dd6caf6059519a65843af8fe2a3ae298b14b80179855aeb4adc2c1934ee619" + [[package]] name = "fs_extra" version = "1.3.0" @@ -3183,6 +3195,32 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "mockall" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "native-tls" version = "0.2.14" @@ -3675,6 +3713,32 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" +[[package]] +name = "predicates" +version = "3.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" + +[[package]] +name = "predicates-tree" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "prettyplease" version = "0.2.35" @@ -3750,6 +3814,7 @@ dependencies = [ "der", "futures", "hex", + "mockall", "prost", "reqwest", "secp256k1", @@ -7283,6 +7348,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "termtree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" + [[package]] name = "thiserror" version = "1.0.69" diff --git a/Cargo.toml b/Cargo.toml index 835edfa..9570178 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,3 +33,7 @@ wormhole-vaas-serde = "0.1.0" [dev-dependencies] serde_json = "1.0.140" base64 = "0.22.1" +mockall = "0.13.1" + +[profile.release] +overflow-checks = true diff --git a/src/main.rs b/src/main.rs index 42a6e29..de2b615 100644 --- a/src/main.rs +++ b/src/main.rs @@ -61,6 +61,7 @@ const INVALID_UNRELIABLE_DATA_FORMAT: &str = "Invalid unreliable data format"; const INVALID_PDA_MESSAGE: &str = "Invalid PDA message"; const INVALID_EMITTER_CHAIN: &str = "Invalid emitter chain"; const INVALID_ACCUMULATOR_ADDRESS: &str = "Invalid accumulator address"; +const INVALID_VAA_VERSION: &str = "Invalid VAA version"; fn decode_and_verify_update( wormhole_pid: &Pubkey, @@ -87,6 +88,14 @@ fn decode_and_verify_update( anyhow::anyhow!(format!("{}: {}", INVALID_UNRELIABLE_DATA_FORMAT, e)) })?; + if unreliable_data.message.vaa_version != 1 { + tracing::error!( + vaa_version = unreliable_data.message.vaa_version, + "Unsupported VAA version" + ); + return Err(anyhow::anyhow!(INVALID_VAA_VERSION)); + } + if Chain::Pythnet != unreliable_data.emitter_chain.into() { tracing::error!( emitter_chain = unreliable_data.emitter_chain, @@ -320,7 +329,7 @@ mod tests { use secp256k1::SecretKey; use solana_account_decoder::{UiAccount, UiAccountData}; - use crate::posted_message::MessageData; + use crate::{posted_message::MessageData, signer::MockSigner}; fn get_wormhole_pid() -> Pubkey { Pubkey::from_str("H3fxXJ86ADW2PNuDDmZJg6mzTtPxkYCpNuQUTgmJ7AjU").unwrap() @@ -488,6 +497,18 @@ mod tests { assert_eq!(result.unwrap_err().to_string(), INVALID_EMITTER_CHAIN); } + #[test] + fn test_decode_and_verify_update_invalid_vaa_version() { + let mut unreliable_data = get_unreliable_data(); + unreliable_data.message.vaa_version = 2; + let result = decode_and_verify_update( + &get_wormhole_pid(), + &get_accumulator_address(), + get_update(unreliable_data), + ); + assert_eq!(result.unwrap_err().to_string(), INVALID_VAA_VERSION); + } + #[test] fn test_decode_and_verify_update_invalid_emitter_address() { let mut unreliable_data = get_unreliable_data(); @@ -530,4 +551,21 @@ mod tests { "30e41be3f10d3ac813f91e49e189bbb948d030be", ); } + + #[tokio::test] + async fn test_sign_error() { + let mut mock_signer = MockSigner::new(); + mock_signer + .expect_sign() + .return_once(|_| Err(anyhow::anyhow!("Mock signing error"))); + let unreliable_data = get_unreliable_data(); + let body = message_data_to_body(&unreliable_data); + let err = Observation::try_new(body, Arc::new(mock_signer)) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Failed to sign observation: Mock signing error" + ); + } } diff --git a/src/posted_message.rs b/src/posted_message.rs index 80777f5..92bffdc 100644 --- a/src/posted_message.rs +++ b/src/posted_message.rs @@ -127,4 +127,55 @@ mod tests { let decoded = PostedMessageUnreliableData::try_from_slice(encoded.as_slice()).unwrap(); assert_eq!(decoded, post_message_unreliable_data); } + + #[test] + fn test_invalid_magic() { + let post_message_unreliable_data = PostedMessageUnreliableData { + message: MessageData { + vaa_version: 1, + consistency_level: 2, + vaa_time: 3, + vaa_signature_account: [4u8; 32], + submission_time: 5, + nonce: 6, + sequence: 7, + emitter_chain: 8, + emitter_address: [9u8; 32], + payload: vec![10u8; 32], + }, + }; + + let mut encoded = borsh::to_vec(&post_message_unreliable_data).unwrap(); + encoded[0..3].copy_from_slice(b"foo"); // Invalid magic + + let err = PostedMessageUnreliableData::try_from_slice(encoded.as_slice()).unwrap_err(); + assert_eq!( + err.to_string(), + "Magic mismatch. Expected [109, 115, 117] but got [102, 111, 111]" + ); + } + + #[test] + fn test_invalid_data_length() { + let post_message_unreliable_data = PostedMessageUnreliableData { + message: MessageData { + vaa_version: 1, + consistency_level: 2, + vaa_time: 3, + vaa_signature_account: [4u8; 32], + submission_time: 5, + nonce: 6, + sequence: 7, + emitter_chain: 8, + emitter_address: [9u8; 32], + payload: vec![10u8; 32], + }, + }; + + let mut encoded = borsh::to_vec(&post_message_unreliable_data).unwrap(); + encoded = encoded[0..encoded.len() - 1].to_vec(); + + let err = PostedMessageUnreliableData::try_from_slice(encoded.as_slice()).unwrap_err(); + assert_eq!(err.to_string(), "Unexpected length of input"); + } } diff --git a/src/signer.rs b/src/signer.rs index a40ef39..18e6677 100644 --- a/src/signer.rs +++ b/src/signer.rs @@ -1,5 +1,5 @@ use der::{ - asn1::{AnyRef, BitStringRef, UintRef}, + asn1::{AnyRef, BitStringRef}, oid::ObjectIdentifier, Decode, Sequence, }; @@ -13,12 +13,13 @@ use std::{ use async_trait::async_trait; use prost::Message as ProstMessage; use secp256k1::{ - ecdsa::{RecoverableSignature, RecoveryId}, + ecdsa::{RecoverableSignature, RecoveryId, Signature}, Message, PublicKey, Secp256k1, SecretKey, }; use sequoia_openpgp::armor::{Kind, Reader, ReaderMode}; use sha3::{Digest, Keccak256}; +#[cfg_attr(test, mockall::automock)] #[async_trait] pub trait Signer: Send + Sync { async fn sign(&self, data: [u8; 32]) -> anyhow::Result<[u8; 65]>; @@ -181,12 +182,6 @@ pub struct SubjectPublicKeyInfo<'a> { pub subject_public_key: BitStringRef<'a>, } -#[derive(Sequence)] -struct EcdsaSignature<'a> { - r: UintRef<'a>, - s: UintRef<'a>, -} - #[async_trait] impl Signer for KMSSigner { async fn sign(&self, data: [u8; 32]) -> anyhow::Result<[u8; 65]> { @@ -204,14 +199,13 @@ impl Signer for KMSSigner { .signature .ok_or_else(|| anyhow::anyhow!("KMS did not return a signature"))?; - let decoded_signature = EcdsaSignature::from_der(kms_signature.as_ref()) - .map_err(|e| anyhow::anyhow!("Failed to decode SubjectPublicKeyInfo: {}", e))?; - - let r_bytes = decoded_signature.r.as_bytes(); - let s_bytes = decoded_signature.s.as_bytes(); - let mut signature = [0u8; 65]; - signature[(32 - r_bytes.len())..32].copy_from_slice(r_bytes); - signature[(64 - s_bytes.len())..64].copy_from_slice(decoded_signature.s.as_bytes()); + let mut signature = Signature::from_der(kms_signature.as_ref()) + .map_err(|e| anyhow::anyhow!("Failed to decode signature from KMS: {}", e))?; + // NOTE: AWS KMS does not guarantee that the ECDSA signature is normalized. + // Therefore, we must normalize it ourselves to prevent malleability, + // so that it can be successfully verified later using the secp256k1 standard libraries. + signature.normalize_s(); + let signature_bytes = signature.serialize_compact(); let public_key = self.get_public_key()?; for raw_id in 0..4 { @@ -220,10 +214,12 @@ impl Signer for KMSSigner { .map_err(|e| anyhow::anyhow!("Failed to create RecoveryId: {}", e))?; if let Ok(recovered_public_key) = secp.recover_ecdsa( &Message::from_digest(data), - &RecoverableSignature::from_compact(&signature[..64], recid) + &RecoverableSignature::from_compact(&signature_bytes, recid) .map_err(|e| anyhow::anyhow!("Failed to create RecoverableSignature: {}", e))?, ) { if recovered_public_key == public_key.0 { + let mut signature = [0u8; 65]; + signature[..64].copy_from_slice(&signature_bytes); signature[64] = raw_id as u8; return Ok(signature); }