diff --git a/Cargo.lock b/Cargo.lock index c7396d7b59..1611f53196 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -328,6 +328,16 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aegis" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "305080716a198f6a57096263ac2df9f681f9487728819a83e0137b18050fc8ad" +dependencies = [ + "cc", + "softaes", +] + [[package]] name = "aes" version = "0.8.4" @@ -2358,6 +2368,7 @@ dependencies = [ "anstyle", "clap_lex 0.7.5", "strsim", + "terminal_size", ] [[package]] @@ -2435,6 +2446,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "condtype" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf0a07a401f374238ab8e2f11a104d2851bf9ce711ec69804834de8af45c7af" + [[package]] name = "console" version = "0.15.11" @@ -2995,6 +3012,31 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "divan" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a405457ec78b8fe08b0e32b4a3570ab5dff6dd16eb9e76a5ee0a9d9cbd898933" +dependencies = [ + "cfg-if", + "clap 4.5.47", + "condtype", + "divan-macros", + "libc", + "regex-lite", +] + +[[package]] +name = "divan-macros" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9556bc800956545d6420a640173e5ba7dfa82f38d3ea5a167eb555bc69ac3323" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "dunce" version = "1.0.5" @@ -4241,6 +4283,7 @@ checksum = "154934ea70c58054b556dd430b99a98c2a7ff5309ac9891597e339b5c28f4371" dependencies = [ "console", "once_cell", + "serde", "similar", ] @@ -6236,6 +6279,65 @@ dependencies = [ "tempfile", ] +[[package]] +name = "monad-wireauth-api" +version = "0.1.0" +dependencies = [ + "bytes", + "clap 4.5.47", + "divan", + "hex", + "lru", + "monad-wireauth-protocol", + "monad-wireauth-session", + "monoio", + "proptest", + "rand 0.8.5", + "rstest", + "secp256k1", + "thiserror 1.0.69", + "tracing", + "tracing-subscriber", + "zerocopy", +] + +[[package]] +name = "monad-wireauth-protocol" +version = "0.1.0" +dependencies = [ + "aegis", + "blake3", + "bytes", + "hex", + "insta", + "rand 0.8.5", + "secp256k1", + "serde", + "tai64", + "thiserror 1.0.69", + "tracing", + "zerocopy", + "zeroize", +] + +[[package]] +name = "monad-wireauth-session" +version = "0.1.0" +dependencies = [ + "bytes", + "hex", + "monad-wireauth-protocol", + "proptest", + "rand 0.8.5", + "rstest", + "secp256k1", + "thiserror 1.0.69", + "tracing", + "tracing-subscriber", + "zerocopy", + "zeroize", +] + [[package]] name = "mongocrypt" version = "0.3.1" @@ -8602,6 +8704,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "softaes" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fef461faaeb36c340b6c887167a9054a034f6acfc50a014ead26a02b4356b3de" + [[package]] name = "sorted-vec" version = "0.8.8" @@ -8800,6 +8908,12 @@ dependencies = [ "libc", ] +[[package]] +name = "tai64" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "014639506e4f425c78e823eabf56e71c093f940ae55b43e58f682e7bc2f5887a" + [[package]] name = "take_mut" version = "0.2.2" @@ -8834,6 +8948,16 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "terminal_size" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b8cb979cb11c32ce1603f8137b22262a9d131aaa5c37b5678025f22b8becd0" +dependencies = [ + "rustix 1.0.8", + "windows-sys 0.60.2", +] + [[package]] name = "test-case" version = "3.3.1" diff --git a/Cargo.toml b/Cargo.toml index c3c38b1a68..ec26d980c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -84,6 +84,9 @@ monad-types = { path = "./monad-types" } monad-updaters = { path = "./monad-updaters" } monad-validator = { path = "./monad-validator" } monad-wal = { path = "./monad-wal" } +monad-wireauth-api = { path = "./monad-wireauth-api" } +monad-wireauth-protocol = { path = "./monad-wireauth-protocol" } +monad-wireauth-session = { path = "./monad-wireauth-session" } actix = "0.13" actix-http = "3.11.1" @@ -92,6 +95,7 @@ actix-test = "0.1" actix-rt = "2.9.0" actix-web = "4.5.1" actix-ws = "0.3.0" +aegis = "0.5" aes = "0.8.3" agent = { git = "https://github.com/category-labs/manytrace.git", tag = "v0.1.1" } alloy-consensus = "0.8" @@ -138,6 +142,7 @@ criterion = { version = "0.4", features = ["html_reports"] } insta = "1.42" ctr = "0.9.2" dashmap = "6.1.0" +divan = "0.1" enum_dispatch = "0.3.13" env_logger = "0.10" eyre = "0.6.12" @@ -199,6 +204,7 @@ sorted_vector_map = "0.2.0" sorted-vec = "0.8.3" strum = "0.26.3" syn = { version = "1.0.0", features = ["full"] } +tai64 = "4.0" tempfile = "3.5" test-case = "3.0" thiserror = "1.0" @@ -220,5 +226,5 @@ unicode-normalization = "0.1" url = "2.5.0" wasm-bindgen = "0.2" zerocopy = { version = "0.8", features = ["derive"] } -zeroize = "1.3" +zeroize = { version = "1.7", features = ["derive"] } zstd = "0.13" diff --git a/monad-wireauth-api/Cargo.toml b/monad-wireauth-api/Cargo.toml new file mode 100644 index 0000000000..153d00068d --- /dev/null +++ b/monad-wireauth-api/Cargo.toml @@ -0,0 +1,50 @@ +[package] +name = "monad-wireauth-api" +version = "0.1.0" +edition = "2021" + +[dependencies] +monad-wireauth-session = { workspace = true, features = ["bench"] } +monad-wireauth-protocol.workspace = true +bytes.workspace = true +rand.workspace = true +thiserror.workspace = true +secp256k1 = { workspace = true, features = ["global-context"] } +zerocopy.workspace = true +tracing.workspace = true +hex.workspace = true +lru.workspace = true + +[dependencies.clap] +workspace = true +features = ["derive"] +optional = true + +[dependencies.monoio] +workspace = true +features = ["sync", "macros"] +optional = true + +[dependencies.tracing-subscriber] +workspace = true +features = ["env-filter"] +optional = true + +[dev-dependencies] +monoio = { workspace = true, features = ["sync", "macros"] } +tracing-subscriber = { workspace = true, features = ["env-filter"] } +rstest.workspace = true +proptest.workspace = true +divan.workspace = true + +[features] +bench = ["monad-wireauth-session/bench"] + +[[bench]] +name = "manager_bench" +harness = false +required-features = ["bench"] + +[[example]] +name = "demo" +required-features = ["clap", "monoio", "tracing-subscriber"] diff --git a/monad-wireauth-api/README.md b/monad-wireauth-api/README.md new file mode 100644 index 0000000000..78db5bc81d --- /dev/null +++ b/monad-wireauth-api/README.md @@ -0,0 +1,32 @@ +# API + +## DoS Protection + +the filter operates in three modes based on load: + +| condition | action | +|-----------|--------| +| sessions >= `high_watermark_sessions` or handshakes >= `handshake_rate_limit` | drop request | +| sessions >= `low_watermark_sessions` and cookie invalid | send cookie reply | +| sessions >= `low_watermark_sessions` and cookie valid | apply per-ip rate limiting via lru cache | +| sessions < `low_watermark_sessions` | no additional measures | + +defaults: `high_watermark_sessions`=100,000, `handshake_rate_limit`=2000/sec, `low_watermark_sessions`=10,000, `ip_rate_limit_window`=10s, `max_sessions_per_ip`=10, `ip_history_capacity`=1,000,000 + +at 2000 handshakes/sec, approximately 400ms of cpu time per second is spent on handshake-related computation during such attack. + +## Benchmarks + +CPU: 12th Gen Intel(R) Core(TM) i9-12900KF + +RUSTFLAGS: `-C target-cpu=haswell -C opt-level=3` + +``` +Timer precision: 26 ns +manager_bench fastest │ slowest │ median │ mean │ samples │ iters +├─ bench_session_decrypt 162.8 ns │ 235.5 ns │ 166.1 ns │ 167.2 ns │ 100 │ 1600 +├─ bench_session_encrypt 133.5 ns │ 136 ns │ 134.6 ns │ 134.6 ns │ 100 │ 3200 +├─ bench_session_handle_init 131.7 µs │ 206.9 µs │ 135.7 µs │ 137.7 µs │ 100 │ 100 +├─ bench_session_handle_response 61.64 µs │ 72.83 µs │ 62.98 µs │ 63.74 µs │ 100 │ 100 +╰─ bench_session_send_init 74.48 µs │ 116.2 µs │ 76.31 µs │ 77.72 µs │ 100 │ 100 +``` diff --git a/monad-wireauth-api/benches/manager_bench.rs b/monad-wireauth-api/benches/manager_bench.rs new file mode 100644 index 0000000000..d4f5baa669 --- /dev/null +++ b/monad-wireauth-api/benches/manager_bench.rs @@ -0,0 +1,185 @@ +use std::net::SocketAddr; + +use divan::Bencher; +use monad_wireauth_api::{Config, TestContext, API}; +use monad_wireauth_protocol::{common::PublicKey, messages::DataPacketHeader}; +use secp256k1::rand::rng; +use zerocopy::{FromBytes, IntoBytes}; + +fn main() { + divan::main(); +} + +fn create_test_manager() -> (API, PublicKey, TestContext) { + let mut rng = rng(); + let (public_key, private_key) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let config = Config::default(); + let context = TestContext::new(); + let context_clone = context.clone(); + + let manager = API::new(config, private_key, public_key.clone(), context); + (manager, public_key, context_clone) +} + +fn establish_session( + peer1_manager: &mut API, + peer2_manager: &mut API, + _peer1_public: &PublicKey, + peer2_public: &PublicKey, +) { + let peer1_addr: SocketAddr = "127.0.0.1:51820".parse().unwrap(); + let peer2_addr: SocketAddr = "127.0.0.1:51821".parse().unwrap(); + + peer1_manager + .connect( + peer2_public.clone(), + peer2_addr, + monad_wireauth_session::DEFAULT_RETRY_ATTEMPTS, + ) + .expect("peer1 failed to init session"); + + let init_packet = peer1_manager.next_packet().unwrap().1; + + let mut init_packet_mut = init_packet.to_vec(); + peer2_manager + .dispatch(&mut init_packet_mut, peer1_addr) + .expect("peer2 failed to accept handshake"); + + let response_packet = peer2_manager.next_packet().unwrap().1; + + let mut response_packet_mut = response_packet.to_vec(); + peer1_manager + .dispatch(&mut response_packet_mut, peer2_addr) + .expect("peer1 failed to complete handshake"); + + while let Some((_addr, packet)) = peer1_manager.next_packet() { + let mut packet_mut = packet.to_vec(); + peer2_manager.dispatch(&mut packet_mut, peer1_addr).ok(); + } +} + +#[divan::bench] +fn bench_session_send_init(bencher: Bencher) { + bencher + .with_inputs(|| { + let (manager, _local_public, _) = create_test_manager(); + let (_peer2_manager, peer2_public, _) = create_test_manager(); + let peer2_addr: SocketAddr = "127.0.0.1:51821".parse().unwrap(); + + (manager, peer2_public, peer2_addr) + }) + .bench_local_values(|(mut manager, peer2_public, peer2_addr)| { + manager + .connect( + peer2_public, + peer2_addr, + monad_wireauth_session::DEFAULT_RETRY_ATTEMPTS, + ) + .expect("failed to init session"); + }); +} + +#[divan::bench] +fn bench_session_handle_init(bencher: Bencher) { + bencher + .with_inputs(|| { + let (mut peer1_manager, _peer1_public, _) = create_test_manager(); + let (peer2_manager, peer2_public, _) = create_test_manager(); + let peer1_addr: SocketAddr = "127.0.0.1:51820".parse().unwrap(); + let peer2_addr: SocketAddr = "127.0.0.1:51821".parse().unwrap(); + + peer1_manager + .connect( + peer2_public, + peer2_addr, + monad_wireauth_session::DEFAULT_RETRY_ATTEMPTS, + ) + .expect("failed to init session"); + let init_packet = peer1_manager.next_packet().unwrap().1; + + (peer2_manager, init_packet, peer1_addr) + }) + .bench_local_values(|(mut peer2_manager, init_packet, peer1_addr)| { + let mut init_packet_mut = init_packet.to_vec(); + peer2_manager + .dispatch(&mut init_packet_mut, peer1_addr) + .expect("failed to handle init"); + }); +} + +#[divan::bench] +fn bench_session_handle_response(bencher: Bencher) { + bencher + .with_inputs(|| { + let (mut mgr1, _peer1_public, _) = create_test_manager(); + let (mut mgr2, peer2_public, _) = create_test_manager(); + let peer1_addr: SocketAddr = "127.0.0.1:51820".parse().unwrap(); + let peer2_addr: SocketAddr = "127.0.0.1:51821".parse().unwrap(); + + mgr1.connect( + peer2_public, + peer2_addr, + monad_wireauth_session::DEFAULT_RETRY_ATTEMPTS, + ) + .expect("init failed"); + let init_packet = mgr1.next_packet().unwrap().1; + + let mut init_packet_mut = init_packet.to_vec(); + mgr2.dispatch(&mut init_packet_mut, peer1_addr) + .expect("dispatch failed"); + let response_packet = mgr2.next_packet().unwrap().1; + + (mgr1, response_packet, peer2_addr) + }) + .bench_local_values(|(mut mgr1, response_packet, peer2_addr)| { + let mut response_packet_mut = response_packet.to_vec(); + mgr1.dispatch(&mut response_packet_mut, peer2_addr) + .expect("handle response failed"); + }); +} + +#[divan::bench] +fn bench_session_encrypt(bencher: Bencher) { + let (mut mgr1, peer1_public, _) = create_test_manager(); + let (mut mgr2, peer2_public, _) = create_test_manager(); + + establish_session(&mut mgr1, &mut mgr2, &peer1_public, &peer2_public); + + let mut plaintext = vec![0u8; 1024]; + bencher.bench_local(|| { + mgr1.encrypt_by_public_key(&peer2_public, &mut plaintext) + .expect("encryption failed"); + }); +} + +#[divan::bench] +fn bench_session_decrypt(bencher: Bencher) { + let (mut mgr1, peer1_public, _) = create_test_manager(); + let (mut mgr2, peer2_public, _) = create_test_manager(); + + establish_session(&mut mgr1, &mut mgr2, &peer1_public, &peer2_public); + + let mut plaintext = vec![0u8; 1024]; + let header = mgr1 + .encrypt_by_public_key(&peer2_public, &mut plaintext) + .expect("encryption failed"); + + let mut packet_data = Vec::with_capacity(header.as_bytes().len() + plaintext.len()); + packet_data.extend_from_slice(header.as_bytes()); + packet_data.extend_from_slice(&plaintext); + + let header_ref = DataPacketHeader::ref_from_bytes(&packet_data[..DataPacketHeader::SIZE]) + .expect("failed to get header"); + let receiver_index = header_ref.receiver_index.get(); + + let peer1_addr: SocketAddr = "127.0.0.1:51820".parse().unwrap(); + + bencher.bench_local(|| { + mgr2.reset_replay_filter_for_receiver(receiver_index); + + let mut packet_clone = packet_data.clone(); + mgr2.dispatch(&mut packet_clone, peer1_addr) + .expect("decryption failed"); + }); +} diff --git a/monad-wireauth-api/examples/demo.rs b/monad-wireauth-api/examples/demo.rs new file mode 100644 index 0000000000..e8f70587ac --- /dev/null +++ b/monad-wireauth-api/examples/demo.rs @@ -0,0 +1,178 @@ +use std::{future::pending, net::SocketAddr, rc::Rc, time::Duration}; + +use clap::Parser; +use monad_wireauth_api::{Config, StdContext, API, RETRY_ALWAYS}; +use monad_wireauth_protocol::{common::PublicKey, crypto}; +use monoio::net::udp::UdpSocket; +use secp256k1::rand::{rngs::StdRng, SeedableRng}; +use tracing::{debug, info, warn}; +use zerocopy::IntoBytes; + +#[derive(Parser, Debug)] +#[command(version, about = "monoio-based authenticated protocol demo")] +struct Args { + #[arg(short, long, help = "listener address")] + listener: SocketAddr, + + #[arg(short, long, value_delimiter = ',', help = "peer addresses")] + peers: Vec, +} + +struct PeerNode { + manager: API, + socket: Rc, + id: u64, +} + +impl PeerNode { + fn new(addr: SocketAddr, seed: u64, id: u64) -> std::io::Result { + let mut rng = StdRng::seed_from_u64(seed); + let (public_key, private_key) = crypto::generate_keypair(&mut rng).unwrap(); + + info!(id = id, addr = %addr, "initializing node"); + + let config = Config { + session_timeout: Duration::from_secs(10), + session_timeout_jitter: Duration::ZERO, + keepalive_interval: Duration::from_secs(3), + keepalive_jitter: Duration::ZERO, + rekey_interval: Duration::from_secs(60), + rekey_jitter: Duration::ZERO, + ..Default::default() + }; + + let context = StdContext::new(); + let manager = API::new(config, private_key, public_key, context); + let socket = UdpSocket::bind(addr)?; + + Ok(Self { + manager, + socket: Rc::new(socket), + id, + }) + } + + fn connect(&mut self, peer_id: u64, peer_public: PublicKey, peer_addr: SocketAddr) { + info!( + id = self.id, + peer_id = peer_id, + peer_addr = %peer_addr, + "connecting to peer" + ); + self.manager + .connect(peer_public, peer_addr, RETRY_ALWAYS) + .unwrap(); + } + + async fn send_all_packets(&mut self) -> std::io::Result<()> { + while let Some((addr, packet)) = self.manager.next_packet() { + let packet_vec = packet.to_vec(); + let (result, _) = self.socket.send_to(packet_vec, addr).await; + result?; + } + Ok(()) + } + + fn encrypt_by_public_key( + &mut self, + peer_public: &PublicKey, + plaintext: &mut [u8], + ) -> monad_wireauth_api::Result { + self.manager.encrypt_by_public_key(peer_public, plaintext) + } + + fn next_timer(&mut self) -> Option { + self.manager.next_timer() + } +} + +#[monoio::main(timer_enabled = true)] +async fn main() -> std::io::Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")), + ) + .init(); + + let args = Args::parse(); + + if args.peers.is_empty() { + warn!("no peers specified"); + } + + let id = args.listener.port() as u64; + let mut node = PeerNode::new(args.listener, id, id)?; + + let mut peer_keys = Vec::new(); + for &peer_addr in &args.peers { + let peer_id = peer_addr.port() as u64; + let seed = peer_id; + let mut rng = StdRng::seed_from_u64(seed); + let (peer_public, _) = crypto::generate_keypair(&mut rng).unwrap(); + peer_keys.push((peer_id, peer_public, peer_addr)); + } + + for &(peer_id, ref peer_public, peer_addr) in &peer_keys { + node.connect(peer_id, peer_public.clone(), peer_addr); + } + + let mut tick_interval = monoio::time::interval(Duration::from_secs(1)); + + loop { + let socket = node.socket.clone(); + let buf = vec![0u8; 65536]; + + monoio::select! { + recv_result = socket.recv_from(buf) => { + let (result, mut buf) = recv_result; + if let Ok((len, src)) = result { + if let Ok(Some(data)) = node.manager.dispatch(&mut buf[..len], src) { + if let Ok(msg) = std::str::from_utf8(&data) { + info!(id = node.id, src = %src, message = %msg, "received message"); + } + } + } + }, + _ = async { + match node.next_timer() { + Some(duration) => { + debug!(?duration, "next timer"); + monoio::time::sleep(duration).await; + node.manager.tick(); + }, + None => pending().await, + } + } => {}, + _ = tick_interval.tick() => { + for &(peer_id, ref peer_public, peer_addr) in &peer_keys { + let message = format!("hello from {} to {}", node.id, peer_id); + let mut plaintext = message.as_bytes().to_vec(); + + match node.encrypt_by_public_key(peer_public, &mut plaintext) { + Ok(header) => { + let mut packet = Vec::new(); + packet.extend_from_slice(header.as_bytes()); + packet.extend_from_slice(&plaintext); + + let (result, _) = node.socket.send_to(packet, peer_addr).await; + match result { + Ok(_) => { + info!(from = node.id, to = peer_id, message = %message, "sent message"); + } + Err(e) => { + warn!(from = node.id, to = peer_id, error = %e, "failed to send message"); + } + } + } + Err(e) => { + warn!(from = node.id, to = peer_id, error = %e, "failed to encrypt message"); + } + } + } + }, + } + + node.send_all_packets().await?; + } +} diff --git a/monad-wireauth-api/src/api.rs b/monad-wireauth-api/src/api.rs new file mode 100644 index 0000000000..3c1f4addf6 --- /dev/null +++ b/monad-wireauth-api/src/api.rs @@ -0,0 +1,617 @@ +use std::{ + collections::{BTreeSet, VecDeque}, + convert::TryFrom, + net::SocketAddr, + time::Duration, +}; + +use bytes::Bytes; +use monad_wireauth_protocol::{ + common::{PrivateKey, PublicKey, SerializedPublicKey}, + messages::{ + CookieReply, DataPacket, DataPacketHeader, HandshakeInitiation, HandshakeResponse, + TYPE_COOKIE_REPLY, TYPE_DATA, TYPE_HANDSHAKE_INITIATION, TYPE_HANDSHAKE_RESPONSE, + }, +}; +use monad_wireauth_session::{Config, SessionIndex}; +use tracing::{debug, instrument, trace, Level}; + +use crate::{ + context::Context, + cookie::Cookies, + error::{Error, ProtocolErrorContext, Result, SessionErrorContext}, + filter::{Filter, FilterAction}, + state::State, + InitiatorState, ResponderState, TransportState, +}; + +pub struct API { + state: State, + next_timers: BTreeSet<(Duration, SessionIndex)>, + packet_queue: VecDeque<(SocketAddr, Bytes)>, + config: Config, + local_static_key: PrivateKey, + local_static_public: PublicKey, + local_serialized_public: SerializedPublicKey, + cookies: Cookies, + filter: Filter, + context: C, +} + +impl API { + pub fn new( + config: Config, + local_static_key: PrivateKey, + local_static_public: PublicKey, + mut context: C, + ) -> Self { + let cookies = Cookies::new( + context.rng(), + (&local_static_public).into(), + config.cookie_refresh_duration, + ); + + let filter = Filter::new( + config.handshake_rate_limit, + config.handshake_rate_reset_interval, + config.ip_rate_limit_window, + config.ip_history_capacity, + config.max_sessions_per_ip, + config.low_watermark_sessions, + config.high_watermark_sessions, + ); + let local_serialized_public: SerializedPublicKey = (&local_static_public).into(); + debug!(local_public_key=?local_serialized_public, "initialized manager"); + Self { + state: State::new(), + next_timers: BTreeSet::new(), + packet_queue: VecDeque::new(), + config, + local_static_key, + local_static_public, + local_serialized_public, + cookies, + filter, + context, + } + } + + #[instrument(level = Level::TRACE, skip(self), fields(local_public_key = ?self.local_serialized_public))] + pub fn next_packet(&mut self) -> Option<(SocketAddr, Bytes)> { + self.packet_queue.pop_front() + } + + #[instrument(level = Level::TRACE, skip(self), fields(local_public_key = ?self.local_serialized_public))] + pub fn next_timer(&self) -> Option { + let duration_since_start = self.context.duration_since_start(); + + let session_timer = self.next_timers.iter().next().map(|&(deadline, _)| { + if deadline > duration_since_start { + deadline - duration_since_start + } else { + Duration::ZERO + } + }); + + let filter_reset_time = self.filter.next_reset_time(); + let filter_timer = if filter_reset_time > duration_since_start { + filter_reset_time - duration_since_start + } else { + Duration::ZERO + }; + + match session_timer { + Some(st) => Some(st.min(filter_timer)), + None => Some(filter_timer), + } + } + + #[instrument(level = Level::TRACE, skip(self), fields(local_public_key = ?self.local_serialized_public))] + pub fn tick(&mut self) { + let duration_since_start = self.context.duration_since_start(); + + self.filter.tick(duration_since_start); + + let expired_sessions: Vec<(Duration, SessionIndex)> = self + .next_timers + .range(..=(duration_since_start, SessionIndex::new(u32::MAX))) + .copied() + .collect(); + + for timer in &expired_sessions { + self.next_timers.remove(timer); + } + + for (_, session_id) in expired_sessions { + let tick_result = if self.state.get_initiator(&session_id).is_some() { + self.state.get_initiator_mut(&session_id).and_then(|s| { + s.tick(duration_since_start) + .map(|(timer, r)| (timer, None, r.rekey, Some(r.terminated))) + }) + } else if self.state.get_responder(&session_id).is_some() { + self.state.get_responder_mut(&session_id).and_then(|s| { + s.tick(duration_since_start) + .map(|(timer, r)| (timer, None, r.rekey, Some(r.terminated))) + }) + } else if let Some(transport) = self.state.get_transport_mut(&session_id) { + Some(transport.tick(&self.config, duration_since_start)) + } else { + None + }; + + if let Some((timer, message, rekey, terminated)) = tick_result { + if let Some(message) = message { + self.packet_queue + .push_back((message.remote_addr, message.header.into())); + } + + if let Some(rekey) = &rekey { + self.handle_rekey_event( + rekey.remote_public_key.clone(), + rekey.remote_addr, + rekey.retry_attempts, + rekey.stored_cookie, + ); + } + + if let Some(timer) = timer { + self.next_timers.insert((timer, session_id)); + } + + if let Some(terminated) = &terminated { + self.handle_terminate_event( + session_id, + terminated.remote_public_key.clone(), + terminated.remote_addr, + ); + } + } + } + } + + fn handle_terminate_event( + &mut self, + session_id: SessionIndex, + remote_public_key: PublicKey, + remote_addr: SocketAddr, + ) { + self.filter.on_session_removed(remote_addr.ip()); + self.state + .handle_terminate(session_id, &(&remote_public_key).into(), remote_addr); + } + + fn handle_rekey_event( + &mut self, + remote_public_key: PublicKey, + remote_addr: SocketAddr, + retry_attempts: u64, + stored_cookie: Option<[u8; 16]>, + ) { + if let Ok((new_session_index, timer, message)) = self.init_session_with_cookie( + remote_public_key, + remote_addr, + stored_cookie, + retry_attempts, + ) { + self.packet_queue.push_back((remote_addr, message.into())); + self.next_timers.insert((timer, new_session_index)); + } + } + + fn handle_established(&mut self, session_id: SessionIndex, transport: TransportState) { + debug!(local_session_id=?session_id, "handling established session"); + let replaced_sessions = self + .state + .handle_established(session_id, transport, &self.config); + + for replaced_session_id in replaced_sessions { + let session = self + .state + .get_transport(&replaced_session_id) + .expect("session must exist"); + self.handle_terminate_event( + replaced_session_id, + session.remote_public_key.clone(), + session.remote_addr, + ); + } + } + + #[instrument(level = Level::TRACE, skip(self, remote_static_key), fields(local_public_key = ?self.local_serialized_public, remote_addr = ?remote_addr))] + pub fn connect( + &mut self, + remote_static_key: PublicKey, + remote_addr: SocketAddr, + retry_attempts: u64, + ) -> Result<()> { + debug!(retry_attempts, "initiating connection"); + let cookie = self + .state + .lookup_cookie_from_initiated_sessions(&(&remote_static_key).into()); + + let (local_index, timer, message) = + self.init_session_with_cookie(remote_static_key, remote_addr, cookie, retry_attempts)?; + + self.packet_queue.push_back((remote_addr, message.into())); + self.next_timers.insert((timer, local_index)); + + Ok(()) + } + + fn init_session_with_cookie( + &mut self, + remote_static_key: PublicKey, + remote_addr: SocketAddr, + cookie: Option<[u8; 16]>, + retry_attempts: u64, + ) -> Result<(SessionIndex, Duration, HandshakeInitiation)> { + let reservation = self + .state + .reserve_session_index() + .ok_or(Error::SessionIndexExhausted)?; + let local_index = reservation.index(); + debug!(local_session_id=?local_index, "allocating session index for new connection"); + let system_time = self.context.system_time(); + let duration_since_start = self.context.duration_since_start(); + let (session, (timer, message)) = InitiatorState::new( + self.context.rng(), + system_time, + duration_since_start, + &self.config, + local_index, + &self.local_static_key, + self.local_static_public.clone(), + remote_static_key.clone(), + remote_addr, + cookie, + retry_attempts, + ) + .map_err(|e| e.with_addr(remote_addr))?; + + reservation.commit(); + + self.state + .insert_initiator(local_index, session, (&remote_static_key).into()); + + self.filter.on_session_added(remote_addr.ip()); + + Ok((local_index, timer, message)) + } + + fn check_under_load( + &mut self, + remote_addr: SocketAddr, + sender_index: u32, + message: &M, + ) -> Result { + let duration_since_start = self.context.duration_since_start(); + let action = self.filter.apply( + remote_addr, + duration_since_start, + self.cookies + .verify(&remote_addr, message, duration_since_start) + .is_ok(), + ); + + match action { + FilterAction::Pass => Ok(true), + FilterAction::SendCookie => { + debug!(?remote_addr, sender_index, "sending cookie reply"); + let reply = self.cookies.create( + remote_addr, + sender_index, + message, + duration_since_start, + )?; + self.packet_queue.push_back((remote_addr, reply.into())); + Ok(false) + } + FilterAction::Drop => { + debug!( + ?remote_addr, + sender_index, "dropping packet due to rate limit" + ); + Ok(false) + } + } + } + + fn accept_handshake_init( + &mut self, + handshake_packet: &mut HandshakeInitiation, + remote_addr: SocketAddr, + ) -> Result<()> { + monad_wireauth_protocol::crypto::verify_mac1( + handshake_packet, + &(&self.local_static_public).into(), + ) + .map_err(|source| Error::Mac1VerificationFailed { + addr: remote_addr, + source, + })?; + + if !self.check_under_load( + remote_addr, + handshake_packet.sender_index.get(), + handshake_packet, + )? { + debug!(?remote_addr, "handshake initiation dropped under load"); + return Ok(()); + } + + let duration_since_start = self.context.duration_since_start(); + + let validated_init = ResponderState::validate_init( + &self.local_static_key, + &self.local_static_public, + handshake_packet, + ) + .map_err(|e| e.with_addr(remote_addr))?; + + let remote_key = SerializedPublicKey::from(&validated_init.remote_public_key); + if self + .state + .get_max_timestamp(&remote_key) + .is_some_and(|max| validated_init.system_time <= max) + { + debug!(?remote_addr, ?remote_key, "timestamp replay detected"); + return Err(Error::TimestampReplay { addr: remote_addr }); + } + + let stored_cookie = self.state.lookup_cookie_from_accepted_sessions(remote_key); + + let reservation = self + .state + .reserve_session_index() + .ok_or(Error::SessionIndexExhausted)?; + let local_index = reservation.index(); + + let (session, timer, message) = ResponderState::new( + self.context.rng(), + duration_since_start, + &self.config, + local_index, + stored_cookie.as_ref(), + validated_init, + remote_addr, + ) + .map_err(|e| e.with_addr(remote_addr))?; + + reservation.commit(); + + self.state + .insert_responder(local_index, session, remote_key); + + self.filter.on_session_added(remote_addr.ip()); + + self.packet_queue.push_back((remote_addr, message.into())); + self.next_timers.insert((timer, local_index)); + + Ok(()) + } + + fn accept_cookie( + &mut self, + cookie_reply: &mut CookieReply, + remote_addr: SocketAddr, + ) -> Result<()> { + let receiver_session_index = cookie_reply.receiver_index.into(); + + if let Some(session) = self.state.get_initiator_mut(&receiver_session_index) { + session + .handle_cookie(cookie_reply) + .map_err(|e| e.with_addr(remote_addr))?; + } else if let Some(session) = self.state.get_responder_mut(&receiver_session_index) { + session + .handle_cookie(cookie_reply) + .map_err(|e| e.with_addr(remote_addr))?; + } + Ok(()) + } + + fn decrypt_data_packet(&mut self, packet: &mut [u8], remote_addr: SocketAddr) -> Result { + let is_keepalive = packet.len() == DataPacketHeader::SIZE; + let data_packet = DataPacket::try_from(packet).map_err(|e| e.with_addr(remote_addr))?; + let receiver_index = data_packet.header.receiver_index.into(); + let nonce: u64 = data_packet.header.counter.into(); + trace!(local_session_id=?receiver_index, nonce, "decrypting data packet"); + + let timer = if let Some(transport) = self.state.get_transport_mut(&receiver_index) { + let duration_since_start = self.context.duration_since_start(); + let remote_addr = transport.remote_addr; + transport + .decrypt(&self.config, duration_since_start, data_packet) + .map_err(|e| e.with_addr(remote_addr))? + } else if let Some(responder) = self.state.get_responder_mut(&receiver_index) { + let duration_since_start = self.context.duration_since_start(); + match responder.decrypt(&self.config, duration_since_start, data_packet) { + Ok(_) => { + let responder = self.state.remove_responder(&receiver_index).unwrap(); + let (transport, establish_timer) = + responder.establish(self.context.rng(), &self.config, duration_since_start); + debug!(local_session_id=?receiver_index, "responder session established"); + self.handle_established(receiver_index, transport); + establish_timer + } + Err(e) => { + return Err(e.with_addr(remote_addr)); + } + } + } else { + return Err(Error::SessionIndexNotFound { + index: receiver_index, + }); + }; + + self.next_timers.insert((timer, receiver_index)); + + Ok(is_keepalive) + } + + fn complete_handshake( + &mut self, + response: &mut HandshakeResponse, + remote_addr: SocketAddr, + ) -> Result<()> { + monad_wireauth_protocol::crypto::verify_mac1(response, &(&self.local_static_public).into()) + .map_err(|source| Error::Mac1VerificationFailed { + addr: remote_addr, + source, + })?; + + if !self.check_under_load(remote_addr, response.sender_index.get(), response)? { + debug!(?remote_addr, "handshake response dropped under load"); + return Ok(()); + } + + let receiver_session_index = response.receiver_index.into(); + + let initiator = self + .state + .get_initiator_mut(&receiver_session_index) + .ok_or(Error::InvalidReceiverIndex { + index: receiver_session_index, + addr: remote_addr, + })?; + + let validated_response = initiator + .validate_response( + &self.config, + &self.local_static_key, + &self.local_static_public, + response, + ) + .map_err(|e| e.with_addr(remote_addr))?; + + let initiator = self + .state + .remove_initiator(&receiver_session_index) + .unwrap(); + + let duration_since_start = self.context.duration_since_start(); + debug!(local_session_id=?receiver_session_index, "initiator session established"); + let (transport, timer, message) = initiator.establish( + self.context.rng(), + &self.config, + duration_since_start, + validated_response, + remote_addr, + ); + + self.handle_established(receiver_session_index, transport); + + self.packet_queue.push_back((remote_addr, message.into())); + self.next_timers.insert((timer, receiver_session_index)); + + Ok(()) + } + + fn encrypt( + &mut self, + session_id: SessionIndex, + plaintext: &mut [u8], + ) -> Result { + let transport = self + .state + .get_transport_mut(&session_id) + .ok_or(Error::SessionNotFound)?; + let (header, timer) = + transport.encrypt(&self.config, self.context.duration_since_start(), plaintext); + self.next_timers.insert((timer, session_id)); + Ok(header) + } + + #[instrument(level = Level::TRACE, skip(self, public_key, plaintext), fields(local_public_key = ?self.local_serialized_public))] + pub fn encrypt_by_public_key( + &mut self, + public_key: &PublicKey, + plaintext: &mut [u8], + ) -> Result { + self.encrypt( + self.state + .get_session_id_by_public_key(&public_key.into()) + .ok_or(Error::SessionNotFound)?, + plaintext, + ) + } + + #[instrument(level = Level::TRACE, skip(self, plaintext), fields(local_public_key = ?self.local_serialized_public, socket_addr = ?socket_addr))] + pub fn encrypt_by_socket( + &mut self, + socket_addr: &SocketAddr, + plaintext: &mut [u8], + ) -> Result { + self.encrypt( + self.state + .get_session_id_by_socket(socket_addr) + .ok_or(Error::SessionNotEstablishedForAddress { addr: *socket_addr })?, + plaintext, + ) + } + + #[instrument(level = Level::TRACE, skip(self, packet), fields(local_public_key = ?self.local_serialized_public, remote_addr = ?remote_addr))] + pub fn dispatch( + &mut self, + packet: &mut [u8], + remote_addr: SocketAddr, + ) -> Result> { + if packet.is_empty() { + return Err(Error::EmptyPacket { addr: remote_addr }); + } + + match packet[0] { + TYPE_HANDSHAKE_INITIATION => { + debug!("processing handshake initiation"); + let handshake = <&mut HandshakeInitiation>::try_from(packet) + .map_err(|e| e.with_addr(remote_addr))?; + self.accept_handshake_init(handshake, remote_addr)?; + Ok(None) + } + TYPE_HANDSHAKE_RESPONSE => { + debug!("processing handshake response"); + let response = <&mut HandshakeResponse>::try_from(packet) + .map_err(|e| e.with_addr(remote_addr))?; + self.complete_handshake(response, remote_addr)?; + Ok(None) + } + TYPE_COOKIE_REPLY => { + debug!("processing cookie reply"); + let cookie_reply = + <&mut CookieReply>::try_from(packet).map_err(|e| e.with_addr(remote_addr))?; + self.accept_cookie(cookie_reply, remote_addr)?; + Ok(None) + } + TYPE_DATA => { + trace!("processing data packet"); + let is_keepalive = self.decrypt_data_packet(packet, remote_addr)?; + if is_keepalive { + trace!("keepalive packet"); + Ok(None) + } else { + Ok(Some(packet[DataPacketHeader::SIZE..].to_vec().into())) + } + } + _ => { + debug!(packet_type = packet[0], "unknown packet type"); + Err(Error::InvalidMessageType { + msg_type: packet[0] as u32, + addr: remote_addr, + }) + } + } + } + + #[instrument(level = Level::TRACE, skip(self, public_key), fields(local_public_key = ?self.local_serialized_public))] + pub fn disconnect(&mut self, public_key: &PublicKey) { + let terminated_addrs = self.state.terminate_by_public_key(&(public_key).into()); + for addr in terminated_addrs { + self.filter.on_session_removed(addr.ip()); + } + } + + #[cfg(any(test, feature = "bench"))] + pub fn reset_replay_filter_for_receiver(&mut self, receiver_index: u32) { + let receiver_session_index = SessionIndex::new(receiver_index); + self.state.reset_replay_filter(&receiver_session_index); + } +} diff --git a/monad-wireauth-api/src/context.rs b/monad-wireauth-api/src/context.rs new file mode 100644 index 0000000000..f982140069 --- /dev/null +++ b/monad-wireauth-api/src/context.rs @@ -0,0 +1,108 @@ +use std::{ + cell::RefCell, + rc::Rc, + time::{Duration, Instant, SystemTime}, +}; + +use secp256k1::rand::{rng, rngs::ThreadRng}; + +pub trait Context { + type Rng: secp256k1::rand::Rng + secp256k1::rand::CryptoRng; + + fn system_time(&self) -> SystemTime; + fn duration_since_start(&self) -> Duration; + fn rng(&mut self) -> &mut Self::Rng; +} + +pub struct StdContext { + rng: ThreadRng, + start_instant: Instant, +} + +impl StdContext { + pub fn new() -> Self { + Self { + rng: rng(), + start_instant: Instant::now(), + } + } +} + +impl Context for StdContext { + type Rng = ThreadRng; + + fn system_time(&self) -> SystemTime { + SystemTime::now() + } + + fn duration_since_start(&self) -> Duration { + self.start_instant.elapsed() + } + + fn rng(&mut self) -> &mut Self::Rng { + &mut self.rng + } +} + +impl Default for StdContext { + fn default() -> Self { + Self::new() + } +} + +#[derive(Clone)] +pub struct TestContext { + shared: Rc>, +} + +struct TestContextShared { + rng: ThreadRng, + time_offset: Duration, + start_time: SystemTime, +} + +impl TestContext { + pub fn new() -> Self { + Self { + shared: Rc::new(RefCell::new(TestContextShared { + rng: rng(), + time_offset: Duration::ZERO, + start_time: SystemTime::UNIX_EPOCH, + })), + } + } + + pub fn advance_time(&self, duration: Duration) { + let mut shared = self.shared.borrow_mut(); + shared.time_offset += duration; + } + + pub fn rewind_time(&self, duration: Duration) { + let mut shared = self.shared.borrow_mut(); + shared.time_offset = shared.time_offset.saturating_sub(duration); + } +} + +impl Context for TestContext { + type Rng = ThreadRng; + + fn system_time(&self) -> SystemTime { + let shared = self.shared.borrow(); + shared.start_time + shared.time_offset + } + + fn duration_since_start(&self) -> Duration { + let shared = self.shared.borrow(); + shared.time_offset + } + + fn rng(&mut self) -> &mut Self::Rng { + unsafe { &mut (*self.shared.as_ptr()).rng } + } +} + +impl Default for TestContext { + fn default() -> Self { + Self::new() + } +} diff --git a/monad-wireauth-api/src/cookie.rs b/monad-wireauth-api/src/cookie.rs new file mode 100644 index 0000000000..6c713a2a45 --- /dev/null +++ b/monad-wireauth-api/src/cookie.rs @@ -0,0 +1,288 @@ +use std::{net::SocketAddr, time::Duration}; + +use monad_wireauth_protocol::{ + common::SerializedPublicKey, errors::ProtocolError, messages::CookieReply, +}; + +use crate::error::{ProtocolErrorContext, Result}; + +pub struct Cookies { + nonce_secret: [u8; 32], + cookie_secret: [u8; 32], + nonce: u128, + local_static_public: SerializedPublicKey, + refresh_duration: Duration, +} + +impl Cookies { + pub fn new( + rng: &mut R, + local_static_public: SerializedPublicKey, + refresh_duration: Duration, + ) -> Self { + let mut cookie_secret = [0u8; 32]; + rng.fill_bytes(&mut cookie_secret); + + let mut nonce_secret = [0u8; 32]; + rng.fill_bytes(&mut nonce_secret); + + Self { + cookie_secret, + nonce_secret, + nonce: 0, + local_static_public, + refresh_duration, + } + } + + pub fn create( + &mut self, + addr: SocketAddr, + sender_index: u32, + message: &M, + duration_since_start: Duration, + ) -> Result { + let time_counter = duration_since_start.as_secs() / self.refresh_duration.as_secs(); + let cookie = monad_wireauth_protocol::cookies::generate_cookie( + &self.cookie_secret, + time_counter, + &addr, + ); + + let nonce_counter = self.nonce; + self.nonce += 1; + + monad_wireauth_protocol::cookies::send_cookie_reply( + &self.nonce_secret, + nonce_counter, + &self.local_static_public, + sender_index, + message.mac1().as_ref(), + &cookie, + ) + .map_err(|e| match e { + ProtocolError::Cookie(c) => c.with_addr(addr), + ProtocolError::Crypto(c) => c.with_addr(addr), + ProtocolError::Handshake(h) => h.with_addr(addr), + ProtocolError::Message(m) => m.with_addr(addr), + }) + } + + pub fn verify( + &self, + remote_addr: &SocketAddr, + message: &M, + duration_since_start: Duration, + ) -> Result<()> { + let time_counter = duration_since_start.as_secs() / self.refresh_duration.as_secs(); + + monad_wireauth_protocol::cookies::verify_cookie( + &self.cookie_secret, + time_counter, + remote_addr, + &self.local_static_public, + message, + ) + .map_err(|e| match e { + ProtocolError::Cookie(c) => c.with_addr(*remote_addr), + ProtocolError::Crypto(c) => c.with_addr(*remote_addr), + ProtocolError::Handshake(h) => h.with_addr(*remote_addr), + ProtocolError::Message(m) => m.with_addr(*remote_addr), + }) + } +} + +#[cfg(test)] +mod tests { + use std::time::SystemTime; + + use monad_wireauth_protocol::handshake::send_handshake_init; + use secp256k1::rand::rng; + + use super::*; + + #[test] + fn test_sanity() { + let mut rng = rng(); + let (initiator_public, initiator_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let (responder_public, _responder_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let refresh_duration = Duration::from_secs(120); + + let mut cookies = Cookies::new( + &mut rng, + SerializedPublicKey::from(&responder_public), + refresh_duration, + ); + + let addr: SocketAddr = "192.168.1.100:51820".parse().unwrap(); + let duration_since_start = Duration::from_secs(10); + + let (init_msg, _state) = send_handshake_init( + &mut rng, + SystemTime::now(), + 12345, + &initiator_private, + &SerializedPublicKey::from(&initiator_public), + &SerializedPublicKey::from(&responder_public), + None, + ) + .unwrap(); + + let mut cookie_reply = cookies + .create(addr, 12345, &init_msg, duration_since_start) + .unwrap(); + + let decrypted_cookie = monad_wireauth_protocol::cookies::accept_cookie_reply( + &SerializedPublicKey::from(&responder_public), + &mut cookie_reply, + init_msg.mac1.as_ref(), + ) + .unwrap(); + + let (init_msg_with_cookie, _state) = send_handshake_init( + &mut rng, + SystemTime::now(), + 12346, + &initiator_private, + &SerializedPublicKey::from(&initiator_public), + &SerializedPublicKey::from(&responder_public), + Some(&decrypted_cookie), + ) + .unwrap(); + + let verify_result = cookies.verify(&addr, &init_msg_with_cookie, duration_since_start); + assert!(verify_result.is_ok()); + } + + #[test] + fn test_rotation_invalidates_old_cookie() { + let mut rng = rng(); + let (initiator_public, initiator_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let (responder_public, _responder_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let refresh_duration = Duration::from_secs(10); + + let mut cookies = Cookies::new( + &mut rng, + SerializedPublicKey::from(&responder_public), + refresh_duration, + ); + + let addr: SocketAddr = "192.168.1.100:51820".parse().unwrap(); + + let (init_msg, _state) = send_handshake_init( + &mut rng, + SystemTime::now(), + 12345, + &initiator_private, + &SerializedPublicKey::from(&initiator_public), + &SerializedPublicKey::from(&responder_public), + None, + ) + .unwrap(); + + let duration_at_time_0 = Duration::from_secs(5); + let mut cookie_reply = cookies + .create(addr, 12345, &init_msg, duration_at_time_0) + .unwrap(); + + let decrypted_cookie = monad_wireauth_protocol::cookies::accept_cookie_reply( + &SerializedPublicKey::from(&responder_public), + &mut cookie_reply, + init_msg.mac1.as_ref(), + ) + .unwrap(); + + let (init_msg_with_cookie, _state) = send_handshake_init( + &mut rng, + SystemTime::now(), + 12346, + &initiator_private, + &SerializedPublicKey::from(&initiator_public), + &SerializedPublicKey::from(&responder_public), + Some(&decrypted_cookie), + ) + .unwrap(); + + let verify_before_rotation = + cookies.verify(&addr, &init_msg_with_cookie, duration_at_time_0); + assert!(verify_before_rotation.is_ok()); + + let duration_after_rotation = Duration::from_secs(25); + let verify_after_rotation = + cookies.verify(&addr, &init_msg_with_cookie, duration_after_rotation); + assert!(verify_after_rotation.is_err()); + } + + #[test] + fn test_cookies_different_after_reset() { + let mut rng = rng(); + let (initiator_public, initiator_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let (responder_public, _responder_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let refresh_duration = Duration::from_secs(120); + + let mut cookies1 = Cookies::new( + &mut rng, + SerializedPublicKey::from(&responder_public), + refresh_duration, + ); + + let addr: SocketAddr = "192.168.1.100:51820".parse().unwrap(); + let duration_since_start = Duration::from_secs(10); + + let (init_msg, _state) = send_handshake_init( + &mut rng, + SystemTime::now(), + 12345, + &initiator_private, + &SerializedPublicKey::from(&initiator_public), + &SerializedPublicKey::from(&responder_public), + None, + ) + .unwrap(); + + let cookie_reply_1 = cookies1 + .create(addr, 12345, &init_msg, duration_since_start) + .unwrap(); + let cookie_reply_2 = cookies1 + .create(addr, 12345, &init_msg, duration_since_start) + .unwrap(); + let cookie_reply_3 = cookies1 + .create(addr, 12345, &init_msg, duration_since_start) + .unwrap(); + + let mut cookies2 = Cookies::new( + &mut rng, + SerializedPublicKey::from(&responder_public), + refresh_duration, + ); + + let cookie_reply_4 = cookies2 + .create(addr, 12345, &init_msg, duration_since_start) + .unwrap(); + let cookie_reply_5 = cookies2 + .create(addr, 12345, &init_msg, duration_since_start) + .unwrap(); + let cookie_reply_6 = cookies2 + .create(addr, 12345, &init_msg, duration_since_start) + .unwrap(); + + assert_ne!( + cookie_reply_1.encrypted_cookie, + cookie_reply_4.encrypted_cookie + ); + assert_ne!( + cookie_reply_2.encrypted_cookie, + cookie_reply_5.encrypted_cookie + ); + assert_ne!( + cookie_reply_3.encrypted_cookie, + cookie_reply_6.encrypted_cookie + ); + } +} diff --git a/monad-wireauth-api/src/error.rs b/monad-wireauth-api/src/error.rs new file mode 100644 index 0000000000..97307cea2f --- /dev/null +++ b/monad-wireauth-api/src/error.rs @@ -0,0 +1,229 @@ +use std::net::SocketAddr; + +use monad_wireauth_protocol::errors::{CookieError, CryptoError, HandshakeError, MessageError}; +use monad_wireauth_session::{SessionError, SessionIndex}; +use thiserror::Error as ThisError; + +#[derive(ThisError, Debug)] +pub enum Error { + #[error("MAC1 verification failed from {addr}: {source}")] + Mac1VerificationFailed { + addr: SocketAddr, + #[source] + source: CryptoError, + }, + + #[error("MAC2 verification failed from {addr}: {source}")] + Mac2VerificationFailed { + addr: SocketAddr, + #[source] + source: CryptoError, + }, + + #[error("static key decryption failed from {addr}: {source}")] + StaticKeyDecryptionFailed { + addr: SocketAddr, + #[source] + source: CryptoError, + }, + + #[error("timestamp decryption failed from {addr}: {source}")] + TimestampDecryptionFailed { + addr: SocketAddr, + #[source] + source: CryptoError, + }, + + #[error("empty message decryption failed from {addr}: {source}")] + EmptyMessageDecryptionFailed { + addr: SocketAddr, + #[source] + source: CryptoError, + }, + + #[error( + "timestamp replay detected from {addr}: received timestamp is not newer than expected" + )] + TimestampReplay { addr: SocketAddr }, + + #[error("invalid message type {msg_type:#04x} from {addr}")] + InvalidMessageType { msg_type: u32, addr: SocketAddr }, + + #[error("invalid receiver index {index} from {addr}")] + InvalidReceiverIndex { + index: SessionIndex, + addr: SocketAddr, + }, + + #[error("buffer too small for message from {addr}: need {required} bytes, got {actual}")] + BufferTooSmall { + addr: SocketAddr, + required: usize, + actual: usize, + }, + + #[error("invalid packet header: malformed or unrecognized format from {addr}")] + InvalidPacketHeader { addr: SocketAddr }, + + #[error("MAC verification failed: data packet integrity check failed from {addr}")] + DataMacVerificationFailed { addr: SocketAddr }, + + #[error("cookie decryption failed from {addr}: {source}")] + CookieDecryptionFailed { + addr: SocketAddr, + #[source] + source: CryptoError, + }, + + #[error("invalid cookie MAC from {addr}: {source}")] + InvalidCookieMac { + addr: SocketAddr, + #[source] + source: CryptoError, + }, + + #[error("invalid key from {addr}: {error}")] + InvalidKey { addr: SocketAddr, error: String }, + + #[error("ECDH operation failed: unable to compute shared secret")] + EcdhFailed, + + #[error("session not found")] + SessionNotFound, + + #[error("session index exhausted")] + SessionIndexExhausted, + + #[error("session not established for address {addr}")] + SessionNotEstablishedForAddress { addr: SocketAddr }, + + #[error("session timeout from {addr}")] + SessionTimeout { addr: SocketAddr }, + + #[error("replay attack detected: packet counter already seen from {addr}")] + ReplayAttack { addr: SocketAddr }, + + #[error("invalid timestamp format: unable to parse TAI64N from {size} bytes")] + InvalidTimestamp { size: usize }, + + #[error("empty packet from {addr}")] + EmptyPacket { addr: SocketAddr }, + + #[error("session index not found: {index}")] + SessionIndexNotFound { index: SessionIndex }, +} + +pub type Result = std::result::Result; + +pub trait SessionErrorContext { + fn with_addr(self, addr: SocketAddr) -> Error; +} + +impl SessionErrorContext for SessionError { + fn with_addr(self, addr: SocketAddr) -> Error { + match self { + SessionError::InvalidHandshake(e) => { + use monad_wireauth_protocol::errors::ProtocolError; + match e { + ProtocolError::Handshake(h) => h.with_addr(addr), + ProtocolError::Crypto(c) => c.with_addr(addr), + ProtocolError::Message(m) => m.with_addr(addr), + ProtocolError::Cookie(c) => c.with_addr(addr), + } + } + SessionError::NotEstablished => Error::SessionNotEstablishedForAddress { addr }, + SessionError::InvalidPacket(e) => e.with_addr(addr), + SessionError::CryptoError(e) => e.with_addr(addr), + SessionError::InvalidMac(e) => match e { + CryptoError::MacVerificationFailed => Error::DataMacVerificationFailed { addr }, + CryptoError::InvalidKey(err) => Error::InvalidKey { + addr, + error: err.to_string(), + }, + CryptoError::EcdhFailed => Error::EcdhFailed, + }, + SessionError::InvalidCookie(e) => e.with_addr(addr), + SessionError::ReplayAttack { .. } => Error::ReplayAttack { addr }, + SessionError::TimestampReplay => Error::TimestampReplay { addr }, + SessionError::SessionTimeout => Error::SessionTimeout { addr }, + } + } +} + +pub trait ProtocolErrorContext { + fn with_addr(self, addr: SocketAddr) -> Error; +} + +impl ProtocolErrorContext for HandshakeError { + fn with_addr(self, addr: SocketAddr) -> Error { + match self { + HandshakeError::Mac1VerificationFailed(source) => { + Error::Mac1VerificationFailed { addr, source } + } + HandshakeError::Mac2VerificationFailed(source) => { + Error::Mac2VerificationFailed { addr, source } + } + HandshakeError::StaticKeyDecryptionFailed(source) => { + Error::StaticKeyDecryptionFailed { addr, source } + } + HandshakeError::TimestampDecryptionFailed(source) => { + Error::TimestampDecryptionFailed { addr, source } + } + HandshakeError::EmptyMessageDecryptionFailed(source) => { + Error::EmptyMessageDecryptionFailed { addr, source } + } + HandshakeError::TimestampReplay { .. } => Error::TimestampReplay { addr }, + HandshakeError::InvalidMessageType(msg_type) => { + Error::InvalidMessageType { msg_type, addr } + } + HandshakeError::InvalidReceiverIndex { index } => { + Error::InvalidReceiverIndex { index, addr } + } + HandshakeError::InvalidTimestamp { size } => Error::InvalidTimestamp { size }, + } + } +} + +impl ProtocolErrorContext for CryptoError { + fn with_addr(self, addr: SocketAddr) -> Error { + match self { + CryptoError::MacVerificationFailed => Error::DataMacVerificationFailed { addr }, + CryptoError::InvalidKey(e) => Error::InvalidKey { + addr, + error: e.to_string(), + }, + CryptoError::EcdhFailed => Error::EcdhFailed, + } + } +} + +impl ProtocolErrorContext for MessageError { + fn with_addr(self, addr: SocketAddr) -> Error { + match self { + MessageError::BufferTooSmall { required, actual } => Error::BufferTooSmall { + addr, + required, + actual, + }, + MessageError::InvalidMessageType(msg_type) => { + Error::InvalidMessageType { msg_type, addr } + } + MessageError::InvalidHeader => Error::InvalidPacketHeader { addr }, + MessageError::InvalidDataPacketHeader => Error::InvalidPacketHeader { addr }, + } + } +} + +impl ProtocolErrorContext for CookieError { + fn with_addr(self, addr: SocketAddr) -> Error { + match self { + CookieError::InvalidMessageType(msg_type) => { + Error::InvalidMessageType { msg_type, addr } + } + CookieError::CookieDecryptionFailed(source) => { + Error::CookieDecryptionFailed { addr, source } + } + CookieError::InvalidCookieMac(source) => Error::InvalidCookieMac { addr, source }, + } + } +} diff --git a/monad-wireauth-api/src/filter.rs b/monad-wireauth-api/src/filter.rs new file mode 100644 index 0000000000..c3d59154bc --- /dev/null +++ b/monad-wireauth-api/src/filter.rs @@ -0,0 +1,454 @@ +use std::{ + collections::HashMap, + net::{IpAddr, SocketAddr}, + num::NonZeroUsize, + time::Duration, +}; + +use lru::LruCache; +use tracing::debug; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FilterAction { + Pass, + SendCookie, + Drop, +} + +pub struct Filter { + counter: u64, + last_reset: Duration, + handshake_rate_limit: u64, + handshake_rate_reset_interval: Duration, + ip_request_history: LruCache, + ip_rate_limit_window: Duration, + ip_session_counts: HashMap, + max_sessions_per_ip: usize, + low_watermark_sessions: usize, + high_watermark_sessions: usize, + total_sessions: usize, +} + +impl Filter { + pub fn new( + handshake_rate_limit: u64, + handshake_rate_reset_interval: Duration, + ip_rate_limit_window: Duration, + ip_history_capacity: usize, + max_sessions_per_ip: usize, + low_watermark_sessions: usize, + high_watermark_sessions: usize, + ) -> Self { + Self { + counter: 0, + last_reset: Duration::ZERO, + handshake_rate_limit, + handshake_rate_reset_interval, + ip_request_history: LruCache::new(NonZeroUsize::new(ip_history_capacity).unwrap()), + ip_rate_limit_window, + ip_session_counts: HashMap::new(), + max_sessions_per_ip, + low_watermark_sessions, + high_watermark_sessions, + total_sessions: 0, + } + } + + pub fn tick(&mut self, duration_since_start: Duration) { + if duration_since_start.saturating_sub(self.last_reset) + >= self.handshake_rate_reset_interval + { + self.counter = 0; + self.last_reset = duration_since_start; + } + } + + pub fn next_reset_time(&self) -> Duration { + self.last_reset + self.handshake_rate_reset_interval + } + + pub fn apply( + &mut self, + remote_addr: SocketAddr, + duration_since_start: Duration, + cookie_valid: bool, + ) -> FilterAction { + self.counter += 1; + + if self.total_sessions >= self.high_watermark_sessions { + debug!( + remote_addr = %remote_addr, + sessions = self.total_sessions, + high_watermark = self.high_watermark_sessions, + "high load - rejecting new handshake" + ); + return FilterAction::Drop; + } + + let under_load = self.counter >= self.handshake_rate_limit; + + if under_load { + debug!( + remote_addr = %remote_addr, + counter = self.counter, + rate_limit = self.handshake_rate_limit, + "rate limit exceeded - dropping handshake" + ); + return FilterAction::Drop; + } + + if self.total_sessions < self.low_watermark_sessions { + return FilterAction::Pass; + } + + if !cookie_valid { + return FilterAction::SendCookie; + } + + let ip = remote_addr.ip(); + let window_start = duration_since_start.saturating_sub(self.ip_rate_limit_window); + + if let Some(last_time) = self.ip_request_history.get_mut(&ip) { + if *last_time >= window_start { + debug!(ip = %ip, "ip rate limit exceeded"); + return FilterAction::Drop; + } + *last_time = duration_since_start; + } else { + self.ip_request_history.put(ip, duration_since_start); + } + + let session_count = self.ip_session_counts.get(&ip).copied().unwrap_or(0); + if session_count >= self.max_sessions_per_ip { + debug!( + ip = %ip, + max = self.max_sessions_per_ip, + "too many sessions for ip" + ); + return FilterAction::Drop; + } + + FilterAction::Pass + } + + pub fn on_session_added(&mut self, ip: IpAddr) { + *self.ip_session_counts.entry(ip).or_insert(0) += 1; + self.total_sessions += 1; + } + + pub fn on_session_removed(&mut self, ip: IpAddr) { + if let Some(count) = self.ip_session_counts.get_mut(&ip) { + *count = count.saturating_sub(1); + if *count == 0 { + self.ip_session_counts.remove(&ip); + } + } + self.total_sessions = self.total_sessions.saturating_sub(1); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn default_filter() -> Filter { + Filter::new( + 100, + Duration::from_secs(60), + Duration::from_secs(60), + 1000, + 10, + 50, + 100, + ) + } + + #[test] + fn test_basic_pass_no_limits() { + let mut filter = default_filter(); + let addr = "127.0.0.1:8080".parse().unwrap(); + let action = filter.apply(addr, Duration::from_secs(1), false); + assert_eq!(action, FilterAction::Pass); + } + + #[test] + fn test_high_watermark_drops() { + let high_watermark = 10; + let mut filter = Filter::new( + 100, + Duration::from_secs(60), + Duration::from_secs(60), + 1000, + 10, + 5, + high_watermark, + ); + for i in 0..high_watermark { + filter.on_session_added(format!("10.0.0.{}", i).parse().unwrap()); + } + let addr = "127.0.0.1:8080".parse().unwrap(); + let action = filter.apply(addr, Duration::from_secs(1), false); + assert_eq!(action, FilterAction::Drop); + } + + #[test] + fn test_between_watermarks_requires_cookie() { + let low_watermark = 5; + let mut filter = Filter::new( + 100, + Duration::from_secs(60), + Duration::from_secs(60), + 1000, + 10, + low_watermark, + 10, + ); + for i in 0..low_watermark { + filter.on_session_added(format!("10.0.0.{}", i).parse().unwrap()); + } + let addr = "127.0.0.1:8080".parse().unwrap(); + let action = filter.apply(addr, Duration::from_secs(1), false); + assert_eq!(action, FilterAction::SendCookie); + } + + #[test] + fn test_between_watermarks_passes_with_cookie() { + let low_watermark = 5; + let mut filter = Filter::new( + 100, + Duration::from_secs(60), + Duration::from_secs(60), + 1000, + 10, + low_watermark, + 10, + ); + for i in 0..low_watermark { + filter.on_session_added(format!("10.0.0.{}", i).parse().unwrap()); + } + let addr = "127.0.0.1:8080".parse().unwrap(); + let action = filter.apply(addr, Duration::from_secs(1), true); + assert_eq!(action, FilterAction::Pass); + } + + #[test] + fn test_handshake_rate_limit_drops() { + let handshake_rate_limit = 5; + let mut filter = Filter::new( + handshake_rate_limit, + Duration::from_secs(60), + Duration::from_secs(60), + 1000, + 10, + 50, + 100, + ); + let addr = "127.0.0.1:8080".parse().unwrap(); + for _ in 0..handshake_rate_limit { + filter.apply(addr, Duration::from_secs(1), false); + } + let action = filter.apply(addr, Duration::from_secs(1), false); + assert_eq!(action, FilterAction::Drop); + } + + #[test] + fn test_handshake_rate_limit_drops_with_cookie() { + let handshake_rate_limit = 5; + let mut filter = Filter::new( + handshake_rate_limit, + Duration::from_secs(60), + Duration::from_secs(60), + 1000, + 10, + 50, + 100, + ); + let addr = "127.0.0.1:8080".parse().unwrap(); + for _ in 0..handshake_rate_limit { + filter.apply(addr, Duration::from_secs(1), false); + } + let action = filter.apply(addr, Duration::from_secs(1), true); + assert_eq!(action, FilterAction::Drop); + } + + #[test] + fn test_tick_resets_counter() { + let handshake_rate_limit = 5; + let mut filter = Filter::new( + handshake_rate_limit, + Duration::from_secs(1), + Duration::from_secs(60), + 1000, + 10, + 50, + 100, + ); + let addr = "127.0.0.1:8080".parse().unwrap(); + for _ in 0..handshake_rate_limit { + filter.apply(addr, Duration::from_secs(0), false); + } + filter.tick(Duration::from_secs(1)); + let action = filter.apply(addr, Duration::from_secs(1), false); + assert_eq!(action, FilterAction::Pass); + } + + #[test] + fn test_tick_does_not_reset_before_interval() { + let handshake_rate_limit = 5; + let mut filter = Filter::new( + handshake_rate_limit, + Duration::from_secs(10), + Duration::from_secs(60), + 1000, + 10, + 50, + 100, + ); + let addr = "127.0.0.1:8080".parse().unwrap(); + for _ in 0..handshake_rate_limit { + filter.apply(addr, Duration::from_secs(0), false); + } + filter.tick(Duration::from_secs(5)); + let action = filter.apply(addr, Duration::from_secs(5), false); + assert_eq!(action, FilterAction::Drop); + } + + #[test] + fn test_ip_rate_limit_within_window() { + let low_watermark = 5; + let mut filter = Filter::new( + 100, + Duration::from_secs(60), + Duration::from_secs(5), + 1000, + 10, + low_watermark, + 10, + ); + for i in 0..low_watermark { + filter.on_session_added(format!("10.0.0.{}", i).parse().unwrap()); + } + let addr = "127.0.0.1:8080".parse().unwrap(); + filter.apply(addr, Duration::from_secs(0), true); + let action = filter.apply(addr, Duration::from_secs(3), true); + assert_eq!(action, FilterAction::Drop); + } + + #[test] + fn test_ip_rate_limit_after_window() { + let low_watermark = 5; + let mut filter = Filter::new( + 100, + Duration::from_secs(60), + Duration::from_secs(5), + 1000, + 10, + low_watermark, + 10, + ); + for i in 0..low_watermark { + filter.on_session_added(format!("10.0.0.{}", i).parse().unwrap()); + } + let addr = "127.0.0.1:8080".parse().unwrap(); + filter.apply(addr, Duration::from_secs(0), true); + let action = filter.apply(addr, Duration::from_secs(6), true); + assert_eq!(action, FilterAction::Pass); + } + + #[test] + fn test_max_sessions_per_ip_drops() { + let low_watermark = 5; + let max_sessions_per_ip = 2; + let mut filter = Filter::new( + 100, + Duration::from_secs(60), + Duration::from_secs(60), + 1000, + max_sessions_per_ip, + low_watermark, + 10, + ); + for i in 0..low_watermark { + filter.on_session_added(format!("10.0.0.{}", i).parse().unwrap()); + } + let ip: IpAddr = "192.168.1.1".parse().unwrap(); + for _ in 0..max_sessions_per_ip { + filter.on_session_added(ip); + } + let addr = "192.168.1.1:8080".parse().unwrap(); + let action = filter.apply(addr, Duration::from_secs(1), true); + assert_eq!(action, FilterAction::Drop); + } + + #[test] + fn test_max_sessions_per_ip_passes_under_limit() { + let low_watermark = 5; + let max_sessions_per_ip = 2; + let mut filter = Filter::new( + 100, + Duration::from_secs(60), + Duration::from_secs(60), + 1000, + max_sessions_per_ip, + low_watermark, + 10, + ); + for i in 0..low_watermark { + filter.on_session_added(format!("10.0.0.{}", i).parse().unwrap()); + } + let ip: IpAddr = "192.168.1.1".parse().unwrap(); + filter.on_session_added(ip); + let addr = "192.168.1.1:8080".parse().unwrap(); + let action = filter.apply(addr, Duration::from_secs(1), true); + assert_eq!(action, FilterAction::Pass); + } + + #[test] + fn test_combined_rate_limit_and_watermark() { + let handshake_rate_limit = 5; + let low_watermark = 5; + let mut filter = Filter::new( + handshake_rate_limit, + Duration::from_secs(60), + Duration::from_secs(60), + 1000, + 10, + low_watermark, + 10, + ); + for i in 0..low_watermark { + filter.on_session_added(format!("10.0.0.{}", i).parse().unwrap()); + } + let addr = "127.0.0.1:8080".parse().unwrap(); + for _ in 0..handshake_rate_limit { + filter.apply(addr, Duration::from_secs(1), false); + } + let action = filter.apply(addr, Duration::from_secs(1), false); + assert_eq!(action, FilterAction::Drop); + } + + #[test] + fn test_lru_cache_eviction() { + let low_watermark = 5; + let mut filter = Filter::new( + 100, + Duration::from_secs(60), + Duration::from_secs(5), + 2, + 10, + low_watermark, + 10, + ); + for i in 0..low_watermark { + filter.on_session_added(format!("10.0.0.{}", i).parse().unwrap()); + } + let addr1 = "192.168.1.1:8080".parse().unwrap(); + let addr2 = "192.168.1.2:8080".parse().unwrap(); + let addr3 = "192.168.1.3:8080".parse().unwrap(); + filter.apply(addr1, Duration::from_secs(0), true); + filter.apply(addr2, Duration::from_secs(1), true); + filter.apply(addr3, Duration::from_secs(2), true); + let action = filter.apply(addr1, Duration::from_secs(3), true); + assert_eq!(action, FilterAction::Pass); + } +} diff --git a/monad-wireauth-api/src/lib.rs b/monad-wireauth-api/src/lib.rs new file mode 100644 index 0000000000..39c62645da --- /dev/null +++ b/monad-wireauth-api/src/lib.rs @@ -0,0 +1,11 @@ +mod api; +mod context; +mod cookie; +mod error; +mod filter; +mod state; + +pub use api::API; +pub use context::{Context, StdContext, TestContext}; +pub use error::{Error, Result}; +pub use monad_wireauth_session::*; diff --git a/monad-wireauth-api/src/state.rs b/monad-wireauth-api/src/state.rs new file mode 100644 index 0000000000..a1a9eb622a --- /dev/null +++ b/monad-wireauth-api/src/state.rs @@ -0,0 +1,1150 @@ +use std::{ + collections::{BTreeSet, HashMap, HashSet}, + net::SocketAddr, + time::{Duration, SystemTime}, +}; + +use monad_wireauth_protocol::common::SerializedPublicKey; +use monad_wireauth_session::{Config, SessionIndex}; + +use crate::{InitiatorState, ResponderState, TransportState}; + +#[derive(Default)] +struct EstablishedSessions { + initiator: Option<(SessionIndex, Duration)>, + responder: Option<(SessionIndex, Duration)>, +} + +impl EstablishedSessions { + fn get_latest(&self) -> Option { + match (&self.initiator, &self.responder) { + (Some((id0, ts0)), Some((id1, ts1))) => { + if ts0 >= ts1 { + Some(*id0) + } else { + Some(*id1) + } + } + (Some((id, _)), None) => Some(*id), + (None, Some((id, _))) => Some(*id), + (None, None) => None, + } + } + + fn is_empty(&self) -> bool { + self.initiator.is_none() && self.responder.is_none() + } +} + +pub(crate) struct SessionIndexReservation<'a> { + state: &'a mut State, + index: SessionIndex, +} + +impl<'a> SessionIndexReservation<'a> { + pub(crate) fn index(&self) -> SessionIndex { + self.index + } + + pub(crate) fn commit(self) { + self.state.next_session_index = self.index; + self.state.next_session_index.increment(); + self.state.allocated_indices.insert(self.index); + } +} + +pub struct State { + initiating_sessions: HashMap, + responding_sessions: HashMap, + transport_sessions: HashMap, + last_established_session_by_public_key: HashMap, + last_established_session_by_socket: HashMap, + allocated_indices: HashSet, + next_session_index: SessionIndex, + initiated_session_by_peer: HashMap, + accepted_sessions_by_peer: BTreeSet<(SerializedPublicKey, SessionIndex)>, +} + +impl State { + pub fn new() -> Self { + Self { + initiating_sessions: HashMap::new(), + responding_sessions: HashMap::new(), + transport_sessions: HashMap::new(), + last_established_session_by_public_key: HashMap::new(), + last_established_session_by_socket: HashMap::new(), + allocated_indices: HashSet::new(), + next_session_index: SessionIndex::new(0), + initiated_session_by_peer: HashMap::new(), + accepted_sessions_by_peer: BTreeSet::new(), + } + } + + pub fn get_transport(&self, session_index: &SessionIndex) -> Option<&TransportState> { + self.transport_sessions.get(session_index) + } + + pub fn get_transport_mut( + &mut self, + session_index: &SessionIndex, + ) -> Option<&mut TransportState> { + self.transport_sessions.get_mut(session_index) + } + + pub fn get_session_id_by_public_key( + &self, + public_key: &SerializedPublicKey, + ) -> Option { + self.last_established_session_by_public_key + .get(public_key) + .and_then(|sessions| sessions.get_latest()) + } + + pub fn get_session_id_by_socket(&self, socket_addr: &SocketAddr) -> Option { + self.last_established_session_by_socket + .get(socket_addr) + .and_then(|sessions| sessions.get_latest()) + } + + pub(crate) fn reserve_session_index(&mut self) -> Option> { + let start_index = self.next_session_index; + let mut candidate = self.next_session_index; + + loop { + if !self.allocated_indices.contains(&candidate) { + return Some(SessionIndexReservation { + state: self, + index: candidate, + }); + } + + candidate.increment(); + if candidate == start_index { + return None; + } + } + } + + pub fn handle_established( + &mut self, + session_id: SessionIndex, + transport: TransportState, + _config: &Config, + ) -> Vec { + let remote_public_key = &transport.remote_public_key; + let remote_addr = transport.remote_addr; + let created = transport.created; + let is_initiator = transport.is_initiator; + + let key_bytes = SerializedPublicKey::from(remote_public_key); + + let mut replaced_sessions = Vec::new(); + + let sessions = self + .last_established_session_by_public_key + .entry(key_bytes) + .or_default(); + + if is_initiator { + if let Some((existing_id, _)) = sessions.initiator { + replaced_sessions.push(existing_id); + } + sessions.initiator = Some((session_id, created)); + } else { + if let Some((existing_id, _)) = sessions.responder { + replaced_sessions.push(existing_id); + } + sessions.responder = Some((session_id, created)); + } + + let sessions = self + .last_established_session_by_socket + .entry(remote_addr) + .or_default(); + + if is_initiator { + if let Some((existing_id, _)) = sessions.initiator { + if !replaced_sessions.contains(&existing_id) { + replaced_sessions.push(existing_id); + } + } + sessions.initiator = Some((session_id, created)); + } else { + if let Some((existing_id, _)) = sessions.responder { + if !replaced_sessions.contains(&existing_id) { + replaced_sessions.push(existing_id); + } + } + sessions.responder = Some((session_id, created)); + } + + self.transport_sessions.insert(session_id, transport); + + replaced_sessions + } + + pub fn handle_terminate( + &mut self, + session_id: SessionIndex, + remote_public_key: &SerializedPublicKey, + remote_addr: SocketAddr, + ) { + let transport = self.transport_sessions.remove(&session_id); + self.initiating_sessions.remove(&session_id); + self.responding_sessions.remove(&session_id); + self.allocated_indices.remove(&session_id); + + if let Some(transport) = transport { + if let Some(sessions) = self + .last_established_session_by_socket + .get_mut(&remote_addr) + { + if transport.is_initiator { + if sessions.initiator.map(|(id, _)| id) == Some(session_id) { + sessions.initiator = None; + } + } else if sessions.responder.map(|(id, _)| id) == Some(session_id) { + sessions.responder = None; + } + + if sessions.is_empty() { + self.last_established_session_by_socket.remove(&remote_addr); + } + } + + if let Some(sessions) = self + .last_established_session_by_public_key + .get_mut(remote_public_key) + { + if transport.is_initiator { + if sessions.initiator.map(|(id, _)| id) == Some(session_id) { + sessions.initiator = None; + } + } else if sessions.responder.map(|(id, _)| id) == Some(session_id) { + sessions.responder = None; + } + + if sessions.is_empty() { + self.last_established_session_by_public_key + .remove(remote_public_key); + } + } + } + + if let Some(&initiated_id) = self.initiated_session_by_peer.get(remote_public_key) { + if initiated_id == session_id { + self.initiated_session_by_peer.remove(remote_public_key); + } + } + + self.accepted_sessions_by_peer + .remove(&(*remote_public_key, session_id)); + } + + pub fn get_initiator(&self, session_index: &SessionIndex) -> Option<&InitiatorState> { + self.initiating_sessions.get(session_index) + } + + pub fn get_initiator_mut( + &mut self, + session_index: &SessionIndex, + ) -> Option<&mut InitiatorState> { + self.initiating_sessions.get_mut(session_index) + } + + pub fn get_responder(&self, session_index: &SessionIndex) -> Option<&ResponderState> { + self.responding_sessions.get(session_index) + } + + pub fn get_responder_mut( + &mut self, + session_index: &SessionIndex, + ) -> Option<&mut ResponderState> { + self.responding_sessions.get_mut(session_index) + } + + pub fn remove_initiator(&mut self, session_index: &SessionIndex) -> Option { + self.initiating_sessions.remove(session_index) + } + + pub fn remove_responder(&mut self, session_index: &SessionIndex) -> Option { + self.responding_sessions.remove(session_index) + } + + pub fn insert_initiator( + &mut self, + session_index: SessionIndex, + session: InitiatorState, + remote_key: SerializedPublicKey, + ) { + self.initiating_sessions.insert(session_index, session); + self.initiated_session_by_peer + .insert(remote_key, session_index); + } + + pub fn insert_responder( + &mut self, + session_index: SessionIndex, + session: ResponderState, + remote_key: SerializedPublicKey, + ) { + self.responding_sessions.insert(session_index, session); + self.accepted_sessions_by_peer + .insert((remote_key, session_index)); + } + + pub fn lookup_cookie_from_initiated_sessions( + &self, + remote_key: &SerializedPublicKey, + ) -> Option<[u8; 16]> { + self.initiated_session_by_peer + .get(remote_key) + .and_then(|&session_id| { + self.initiating_sessions + .get(&session_id) + .and_then(|s| s.stored_cookie()) + }) + } + + pub fn lookup_cookie_from_accepted_sessions( + &self, + remote_key: SerializedPublicKey, + ) -> Option<[u8; 16]> { + self.accepted_sessions_by_peer + .range((remote_key, SessionIndex::new(0))..=(remote_key, SessionIndex::new(u32::MAX))) + .find_map(|(_, session_id)| { + self.responding_sessions + .get(session_id) + .and_then(|s| s.stored_cookie()) + }) + } + + pub fn get_max_timestamp(&self, remote_key: &SerializedPublicKey) -> Option { + let accepted_max = self + .accepted_sessions_by_peer + .range((*remote_key, SessionIndex::new(0))..=(*remote_key, SessionIndex::new(u32::MAX))) + .filter_map(|(_, session_id)| self.responding_sessions.get(session_id)) + .filter_map(|s| s.initiator_system_time()) + .max(); + + let open_max = self + .last_established_session_by_public_key + .get(remote_key) + .and_then(|sessions| sessions.responder) + .map(|(session_id, _)| session_id) + .and_then(|session_id| self.transport_sessions.get(&session_id)) + .and_then(|s| s.initiator_system_time()); + + match (accepted_max, open_max) { + (Some(a), Some(o)) => Some(a.max(o)), + (Some(a), None) => Some(a), + (None, Some(o)) => Some(o), + (None, None) => None, + } + } + + pub fn terminate_by_public_key(&mut self, public_key: &SerializedPublicKey) -> Vec { + let mut session_ids = HashSet::new(); + + if let Some(&session_id) = self.initiated_session_by_peer.get(public_key) { + session_ids.insert(session_id); + } + + for (key, session_id) in self + .accepted_sessions_by_peer + .range((*public_key, SessionIndex::new(0))..=(*public_key, SessionIndex::new(u32::MAX))) + { + if key == public_key { + session_ids.insert(*session_id); + } + } + + if let Some(sessions) = self.last_established_session_by_public_key.get(public_key) { + if let Some((session_id, _)) = sessions.initiator { + session_ids.insert(session_id); + } + if let Some((session_id, _)) = sessions.responder { + session_ids.insert(session_id); + } + } + + let mut terminated_addrs = Vec::new(); + + for session_id in session_ids { + let remote_addr = self + .transport_sessions + .get(&session_id) + .map(|t| t.remote_addr) + .or_else(|| { + self.initiating_sessions + .get(&session_id) + .map(|i| i.remote_addr) + }) + .or_else(|| { + self.responding_sessions + .get(&session_id) + .map(|r| r.remote_addr) + }); + + if let Some(addr) = remote_addr { + self.handle_terminate(session_id, public_key, addr); + terminated_addrs.push(addr); + } + } + + terminated_addrs + } + + #[cfg(any(test, feature = "bench"))] + pub fn reset_replay_filter(&mut self, session_id: &SessionIndex) { + if let Some(session) = self.transport_sessions.get_mut(session_id) { + session.reset_replay_filter(); + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + net::{IpAddr, Ipv4Addr}, + time::SystemTime, + }; + + use monad_wireauth_protocol::{common::PublicKey, crypto}; + use secp256k1::rand::rng; + + use super::*; + + fn create_dummy_hash_output() -> monad_wireauth_protocol::common::HashOutput { + monad_wireauth_protocol::common::HashOutput([0u8; 32]) + } + + fn create_test_transport( + session_index: SessionIndex, + remote_public_key: &PublicKey, + remote_addr: SocketAddr, + is_initiator: bool, + ) -> TransportState { + let hash1 = create_dummy_hash_output(); + let hash2 = create_dummy_hash_output(); + let send_key = monad_wireauth_protocol::common::CipherKey::from(&hash1); + let recv_key = monad_wireauth_protocol::common::CipherKey::from(&hash2); + let common = monad_wireauth_session::SessionState::new( + remote_addr, + remote_public_key.clone(), + session_index, + Duration::ZERO, + 0, + None, + is_initiator, + ); + TransportState::new(session_index, send_key, recv_key, common) + } + + fn create_test_initiator(remote_public_key: &PublicKey) -> InitiatorState { + let mut rng = rng(); + let (public_key, private_key) = crypto::generate_keypair(&mut rng).unwrap(); + let config = Config::default(); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let local_index = SessionIndex::new(1); + InitiatorState::new( + &mut rng, + SystemTime::now(), + Duration::ZERO, + &config, + local_index, + &private_key, + public_key, + remote_public_key.clone(), + remote_addr, + None, + 0, + ) + .unwrap() + .0 + } + + fn create_test_responder( + remote_public_key: &PublicKey, + _cookie: Option<[u8; 16]>, + ) -> ResponderState { + let mut rng = rng(); + let (_local_public_key, _local_private_key) = crypto::generate_keypair(&mut rng).unwrap(); + + let remote_index = SessionIndex::new(42); + let sender_index = SessionIndex::new(1); + + let hash1 = create_dummy_hash_output(); + let hash2 = create_dummy_hash_output(); + + let (ephemeral_public, ephemeral_private) = crypto::generate_keypair(&mut rng).unwrap(); + + let handshake_state = monad_wireauth_protocol::handshake::HandshakeState { + hash: hash1.into(), + chaining_key: hash2.into(), + remote_static: Some(SerializedPublicKey::from(remote_public_key)), + receiver_index: remote_index.as_u32(), + sender_index: sender_index.as_u32(), + ephemeral_private: Some(ephemeral_private), + remote_ephemeral: Some(SerializedPublicKey::from(&ephemeral_public)), + }; + + let validated_init = monad_wireauth_session::ValidatedHandshakeInit { + handshake_state, + remote_public_key: remote_public_key.clone(), + system_time: SystemTime::now(), + remote_index, + }; + + let config = Config::default(); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)), 51820); + let local_index = SessionIndex::new(2); + + ResponderState::new( + &mut rng, + Duration::ZERO, + &config, + local_index, + None, + validated_init, + remote_addr, + ) + .unwrap() + .0 + } + + #[test] + fn test_new() { + let state = State::new(); + assert_eq!(state.next_session_index, SessionIndex::new(0)); + assert!(state.allocated_indices.is_empty()); + assert!(state.transport_sessions.is_empty()); + assert!(state.initiating_sessions.is_empty()); + assert!(state.responding_sessions.is_empty()); + } + + #[test] + fn test_allocate_session_index() { + let mut state = State::new(); + + let reservation0 = state.reserve_session_index().unwrap(); + let idx0 = reservation0.index(); + reservation0.commit(); + + let reservation1 = state.reserve_session_index().unwrap(); + let idx1 = reservation1.index(); + reservation1.commit(); + + let reservation2 = state.reserve_session_index().unwrap(); + let idx2 = reservation2.index(); + reservation2.commit(); + + assert_eq!(idx0, SessionIndex::new(0)); + assert_eq!(idx1, SessionIndex::new(1)); + assert_eq!(idx2, SessionIndex::new(2)); + assert!(state.allocated_indices.contains(&idx0)); + assert!(state.allocated_indices.contains(&idx1)); + assert!(state.allocated_indices.contains(&idx2)); + } + + #[test] + fn test_allocate_session_index_skips_allocated() { + let mut state = State::new(); + + let reservation0 = state.reserve_session_index().unwrap(); + let idx0 = reservation0.index(); + reservation0.commit(); + + state.allocated_indices.remove(&idx0); + state.next_session_index = SessionIndex::new(0); + + let reservation1 = state.reserve_session_index().unwrap(); + let idx1 = reservation1.index(); + reservation1.commit(); + + assert_eq!(idx1, SessionIndex::new(0)); + } + + #[test] + fn test_get_nonexistent_transport() { + let state = State::new(); + let session_id = SessionIndex::new(42); + assert!(state.get_transport(&session_id).is_none()); + } + + #[test] + fn test_get_transport_mut() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let session_id = SessionIndex::new(100); + let config = Config::default(); + + let transport = create_test_transport(session_id, &public_key, remote_addr, true); + state.handle_established(session_id, transport, &config); + + assert!(state.get_transport_mut(&session_id).is_some()); + assert!(state.get_transport_mut(&SessionIndex::new(999)).is_none()); + } + + #[test] + fn test_get_transport() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let session_id = SessionIndex::new(100); + let config = Config::default(); + + let transport = create_test_transport(session_id, &public_key, remote_addr, true); + state.handle_established(session_id, transport, &config); + + assert!(state.get_transport(&session_id).is_some()); + assert!(state.get_transport(&SessionIndex::new(999)).is_none()); + } + + #[test] + fn test_get_session_id_by_public_key_empty() { + let state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + assert!(state.get_session_id_by_public_key(&key_bytes).is_none()); + } + + #[test] + fn test_get_session_id_by_public_key_single_initiator() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let session_id = SessionIndex::new(1); + let created = Duration::from_secs(100); + + state.last_established_session_by_public_key.insert( + key_bytes, + EstablishedSessions { + initiator: Some((session_id, created)), + responder: None, + }, + ); + + assert_eq!( + state.get_session_id_by_public_key(&key_bytes), + Some(session_id) + ); + } + + #[test] + fn test_get_session_id_by_public_key_single_responder() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let session_id = SessionIndex::new(2); + let created = Duration::from_secs(100); + + state.last_established_session_by_public_key.insert( + key_bytes, + EstablishedSessions { + initiator: None, + responder: Some((session_id, created)), + }, + ); + + assert_eq!( + state.get_session_id_by_public_key(&key_bytes), + Some(session_id) + ); + } + + #[test] + fn test_get_session_id_by_public_key_both_newer_initiator() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let session_id_init = SessionIndex::new(1); + let session_id_resp = SessionIndex::new(2); + + state.last_established_session_by_public_key.insert( + key_bytes, + EstablishedSessions { + initiator: Some((session_id_init, Duration::from_secs(200))), + responder: Some((session_id_resp, Duration::from_secs(100))), + }, + ); + + assert_eq!( + state.get_session_id_by_public_key(&key_bytes), + Some(session_id_init) + ); + } + + #[test] + fn test_get_session_id_by_public_key_both_newer_responder() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let session_id_init = SessionIndex::new(1); + let session_id_resp = SessionIndex::new(2); + + state.last_established_session_by_public_key.insert( + key_bytes, + EstablishedSessions { + initiator: Some((session_id_init, Duration::from_secs(100))), + responder: Some((session_id_resp, Duration::from_secs(200))), + }, + ); + + assert_eq!( + state.get_session_id_by_public_key(&key_bytes), + Some(session_id_resp) + ); + } + + #[test] + fn test_get_session_id_by_socket_empty() { + let state = State::new(); + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + assert!(state.get_session_id_by_socket(&addr).is_none()); + } + + #[test] + fn test_get_session_id_by_socket_single() { + let mut state = State::new(); + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let session_id = SessionIndex::new(5); + let created = Duration::from_secs(100); + + state.last_established_session_by_socket.insert( + addr, + EstablishedSessions { + initiator: Some((session_id, created)), + responder: None, + }, + ); + + assert_eq!(state.get_session_id_by_socket(&addr), Some(session_id)); + } + + #[test] + fn test_get_session_id_by_socket_both_newer_initiator() { + let mut state = State::new(); + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let session_id_init = SessionIndex::new(3); + let session_id_resp = SessionIndex::new(4); + + state.last_established_session_by_socket.insert( + addr, + EstablishedSessions { + initiator: Some((session_id_init, Duration::from_secs(300))), + responder: Some((session_id_resp, Duration::from_secs(100))), + }, + ); + + assert_eq!(state.get_session_id_by_socket(&addr), Some(session_id_init)); + } + + #[test] + fn test_insert_and_get_initiator() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let session_id = SessionIndex::new(10); + let initiator = create_test_initiator(&public_key); + + state.insert_initiator(session_id, initiator, key_bytes); + + assert!(state.get_initiator(&session_id).is_some()); + assert!(state.initiated_session_by_peer.contains_key(&key_bytes)); + assert_eq!(state.initiated_session_by_peer[&key_bytes], session_id); + } + + #[test] + fn test_insert_and_get_responder() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let session_id = SessionIndex::new(20); + let responder = create_test_responder(&public_key, None); + + state.insert_responder(session_id, responder, key_bytes); + + assert!(state.get_responder(&session_id).is_some()); + assert!(state + .accepted_sessions_by_peer + .contains(&(key_bytes, session_id))); + } + + #[test] + fn test_get_initiator_mut() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let session_id = SessionIndex::new(10); + let initiator = create_test_initiator(&public_key); + + state.insert_initiator(session_id, initiator, key_bytes); + assert!(state.get_initiator_mut(&session_id).is_some()); + } + + #[test] + fn test_get_responder_mut() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let session_id = SessionIndex::new(20); + let responder = create_test_responder(&public_key, None); + + state.insert_responder(session_id, responder, key_bytes); + assert!(state.get_responder_mut(&session_id).is_some()); + } + + #[test] + fn test_remove_initiator() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let session_id = SessionIndex::new(10); + let initiator = create_test_initiator(&public_key); + + state.insert_initiator(session_id, initiator, key_bytes); + assert!(state.remove_initiator(&session_id).is_some()); + assert!(state.get_initiator(&session_id).is_none()); + } + + #[test] + fn test_remove_responder() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let session_id = SessionIndex::new(20); + let responder = create_test_responder(&public_key, None); + + state.insert_responder(session_id, responder, key_bytes); + assert!(state.remove_responder(&session_id).is_some()); + assert!(state.get_responder(&session_id).is_none()); + } + + #[test] + fn test_handle_established_initiator() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let session_id = SessionIndex::new(100); + let config = Config::default(); + + let transport = create_test_transport(session_id, &public_key, remote_addr, true); + + let terminated = state.handle_established(session_id, transport, &config); + + assert!(terminated.is_empty()); + assert!(state.get_transport(&session_id).is_some()); + let key_bytes = SerializedPublicKey::from(&public_key); + assert!(state + .last_established_session_by_public_key + .contains_key(&key_bytes)); + assert!(state + .last_established_session_by_socket + .contains_key(&remote_addr)); + } + + #[test] + fn test_handle_established_replaces_old_initiator() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let config = Config::default(); + + let old_session_id = SessionIndex::new(100); + let transport1 = create_test_transport(old_session_id, &public_key, remote_addr, true); + state.handle_established(old_session_id, transport1, &config); + + let new_session_id = SessionIndex::new(101); + let transport2 = create_test_transport(new_session_id, &public_key, remote_addr, true); + let terminated = state.handle_established(new_session_id, transport2, &config); + + assert_eq!(terminated.len(), 1); + assert_eq!(terminated[0], old_session_id); + } + + #[test] + fn test_handle_established_responder() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let session_id = SessionIndex::new(200); + let config = Config::default(); + + let transport = create_test_transport(session_id, &public_key, remote_addr, false); + + let terminated = state.handle_established(session_id, transport, &config); + + assert!(terminated.is_empty()); + assert!(state.get_transport(&session_id).is_some()); + } + + #[test] + fn test_handle_established_both_initiator_and_responder() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let config = Config::default(); + + let init_session_id = SessionIndex::new(100); + let transport_init = create_test_transport(init_session_id, &public_key, remote_addr, true); + state.handle_established(init_session_id, transport_init, &config); + + let resp_session_id = SessionIndex::new(200); + let transport_resp = + create_test_transport(resp_session_id, &public_key, remote_addr, false); + let terminated = state.handle_established(resp_session_id, transport_resp, &config); + + assert!(terminated.is_empty()); + assert!(state.get_transport(&init_session_id).is_some()); + assert!(state.get_transport(&resp_session_id).is_some()); + + let key_bytes = SerializedPublicKey::from(&public_key); + let sessions = state + .last_established_session_by_public_key + .get(&key_bytes) + .unwrap(); + assert!(sessions.initiator.is_some()); + assert!(sessions.responder.is_some()); + } + + #[test] + fn test_handle_terminate_removes_transport() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let session_id = SessionIndex::new(100); + let config = Config::default(); + + let transport = create_test_transport(session_id, &public_key, remote_addr, true); + state.handle_established(session_id, transport, &config); + + let reservation = state.reserve_session_index().unwrap(); + reservation.commit(); + + state.handle_terminate(session_id, &key_bytes, remote_addr); + + assert!(state.get_transport(&session_id).is_none()); + assert!(!state.allocated_indices.contains(&session_id)); + } + + #[test] + fn test_handle_terminate_cleans_up_by_public_key() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let session_id = SessionIndex::new(100); + let config = Config::default(); + + let transport = create_test_transport(session_id, &public_key, remote_addr, true); + state.handle_established(session_id, transport, &config); + + state.handle_terminate(session_id, &key_bytes, remote_addr); + + assert!(!state + .last_established_session_by_public_key + .contains_key(&key_bytes)); + } + + #[test] + fn test_handle_terminate_preserves_other_slot() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let config = Config::default(); + + let init_session_id = SessionIndex::new(100); + let transport_init = create_test_transport(init_session_id, &public_key, remote_addr, true); + state.handle_established(init_session_id, transport_init, &config); + + let resp_session_id = SessionIndex::new(200); + let transport_resp = + create_test_transport(resp_session_id, &public_key, remote_addr, false); + state.handle_established(resp_session_id, transport_resp, &config); + + state.handle_terminate(init_session_id, &key_bytes, remote_addr); + + assert!(state + .last_established_session_by_public_key + .contains_key(&key_bytes)); + let sessions = state + .last_established_session_by_public_key + .get(&key_bytes) + .unwrap(); + assert!(sessions.initiator.is_none()); + assert!(sessions.responder.is_some()); + } + + #[test] + fn test_handle_terminate_cleans_up_by_socket() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let session_id = SessionIndex::new(100); + let config = Config::default(); + + let transport = create_test_transport(session_id, &public_key, remote_addr, true); + state.handle_established(session_id, transport, &config); + + state.handle_terminate(session_id, &key_bytes, remote_addr); + + assert!(!state + .last_established_session_by_socket + .contains_key(&remote_addr)); + } + + #[test] + fn test_handle_terminate_removes_initiator() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let session_id = SessionIndex::new(100); + + let initiator = create_test_initiator(&public_key); + state.insert_initiator(session_id, initiator, key_bytes); + + state.handle_terminate(session_id, &key_bytes, remote_addr); + + assert!(state.get_initiator(&session_id).is_none()); + } + + #[test] + fn test_handle_terminate_removes_responder() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let session_id = SessionIndex::new(200); + + let responder = create_test_responder(&public_key, None); + state.insert_responder(session_id, responder, key_bytes); + + state.handle_terminate(session_id, &key_bytes, remote_addr); + + assert!(state.get_responder(&session_id).is_none()); + assert!(!state + .accepted_sessions_by_peer + .contains(&(key_bytes, session_id))); + } + + #[test] + fn test_handle_terminate_removes_initiated_session_by_peer() { + let mut state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820); + let session_id = SessionIndex::new(100); + + let initiator = create_test_initiator(&public_key); + state.insert_initiator(session_id, initiator, key_bytes); + + state.handle_terminate(session_id, &key_bytes, remote_addr); + + assert!(!state.initiated_session_by_peer.contains_key(&key_bytes)); + } + + #[test] + fn test_lookup_cookie_from_initiated_sessions_none() { + let state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + assert!(state + .lookup_cookie_from_initiated_sessions(&key_bytes) + .is_none()); + } + + #[test] + fn test_lookup_cookie_from_accepted_sessions_none() { + let state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + assert!(state + .lookup_cookie_from_accepted_sessions(key_bytes) + .is_none()); + } + + #[test] + fn test_get_max_timestamp_empty() { + let state = State::new(); + let mut rng = rng(); + let (public_key, _) = crypto::generate_keypair(&mut rng).unwrap(); + let key_bytes = SerializedPublicKey::from(&public_key); + assert!(state.get_max_timestamp(&key_bytes).is_none()); + } + + #[test] + fn test_reserve_success_and_commit() { + let mut state = State::new(); + + let index = { + let reservation = state.reserve_session_index().unwrap(); + reservation.index() + }; + assert_eq!(index, SessionIndex::new(0)); + assert_eq!(state.next_session_index, SessionIndex::new(0)); + + let reservation = state.reserve_session_index().unwrap(); + assert_eq!(reservation.index(), SessionIndex::new(0)); + reservation.commit(); + assert_eq!(state.next_session_index, SessionIndex::new(1)); + assert!(state.allocated_indices.contains(&SessionIndex::new(0))); + + let reservation2 = state.reserve_session_index().unwrap(); + let index2 = reservation2.index(); + assert_eq!(index2, SessionIndex::new(1)); + reservation2.commit(); + assert_eq!(state.next_session_index, SessionIndex::new(2)); + assert!(state.allocated_indices.contains(&SessionIndex::new(1))); + } + + #[test] + fn test_reserve_drop_without_commit() { + let mut state = State::new(); + + { + let _reservation = state.reserve_session_index().unwrap(); + assert_eq!(state.next_session_index, SessionIndex::new(0)); + } + + assert_eq!(state.next_session_index, SessionIndex::new(0)); + + let reservation2 = state.reserve_session_index().unwrap(); + let index2 = reservation2.index(); + assert_eq!(index2, SessionIndex::new(0)); + reservation2.commit(); + assert_eq!(state.next_session_index, SessionIndex::new(1)); + assert!(state.allocated_indices.contains(&SessionIndex::new(0))); + } +} diff --git a/monad-wireauth-api/tests/e2e.rs b/monad-wireauth-api/tests/e2e.rs new file mode 100644 index 0000000000..c6d43b1d85 --- /dev/null +++ b/monad-wireauth-api/tests/e2e.rs @@ -0,0 +1,207 @@ +use std::{rc::Rc, time::Duration}; + +use bytes::Bytes; +use monad_wireauth_api::{Config, StdContext, API}; +use monad_wireauth_protocol::common::PublicKey; +use monad_wireauth_session::DEFAULT_RETRY_ATTEMPTS; +use monoio::net::udp::UdpSocket; +use secp256k1::rand::{rngs::StdRng, SeedableRng}; +use zerocopy::IntoBytes; + +struct PeerNode { + manager: API, + socket: Rc, + public_key: PublicKey, +} + +impl PeerNode { + fn new(port: u16, seed: u64) -> std::io::Result { + let mut rng = StdRng::seed_from_u64(seed); + let (public_key, private_key) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + + let config = Config { + session_timeout: Duration::from_secs(10), + session_timeout_jitter: Duration::ZERO, + keepalive_interval: Duration::from_secs(3), + keepalive_jitter: Duration::ZERO, + rekey_interval: Duration::from_secs(60), + rekey_jitter: Duration::ZERO, + ..Default::default() + }; + + let context = StdContext::new(); + let manager = API::new(config, private_key, public_key.clone(), context); + + let addr = format!("127.0.0.1:{}", port); + let socket = UdpSocket::bind(addr)?; + + Ok(Self { + manager, + socket: Rc::new(socket), + public_key, + }) + } + + fn connect(&mut self, peer_public: PublicKey, peer_addr: std::net::SocketAddr) { + self.manager + .connect(peer_public, peer_addr, DEFAULT_RETRY_ATTEMPTS) + .unwrap(); + } + + async fn send_all_packets(&mut self) -> std::io::Result<()> { + while let Some((addr, packet)) = self.manager.next_packet() { + let packet_vec = packet.to_vec(); + let (result, _) = self.socket.send_to(packet_vec, addr).await; + result?; + } + Ok(()) + } + + async fn recv_and_dispatch(&mut self) -> std::io::Result> { + let buf = vec![0u8; 65536]; + let (result, mut buf) = self.socket.recv_from(buf).await; + let (len, src) = result?; + let result = self.manager.dispatch(&mut buf[..len], src); + match result { + Ok(data) => Ok(data), + Err(_) => Ok(None), + } + } + + fn encrypt_by_public_key( + &mut self, + peer_public: &PublicKey, + plaintext: &mut [u8], + ) -> monad_wireauth_api::Result { + self.manager.encrypt_by_public_key(peer_public, plaintext) + } +} + +async fn exchange_handshake(alice: &mut PeerNode, bob: &mut PeerNode) -> std::io::Result<()> { + let mut iterations = 0; + let max_iterations = 10; + + while iterations < max_iterations { + alice.send_all_packets().await?; + bob.send_all_packets().await?; + + let alice_socket = alice.socket.clone(); + let bob_socket = bob.socket.clone(); + + let alice_task = async { + let buf = vec![0u8; 65536]; + monoio::select! { + recv_result = alice_socket.recv_from(buf) => { + let (result, mut buf) = recv_result; + if let Ok((len, src)) = result { + let _ = alice.manager.dispatch(&mut buf[..len], src); + true + } else { + false + } + }, + _ = monoio::time::sleep(Duration::from_millis(10)) => { + false + }, + } + }; + + let bob_task = async { + let buf = vec![0u8; 65536]; + monoio::select! { + recv_result = bob_socket.recv_from(buf) => { + let (result, mut buf) = recv_result; + if let Ok((len, src)) = result { + let _ = bob.manager.dispatch(&mut buf[..len], src); + true + } else { + false + } + }, + _ = monoio::time::sleep(Duration::from_millis(10)) => { + false + }, + } + }; + + let (alice_received, bob_received) = monoio::join!(alice_task, bob_task); + + if !alice_received && !bob_received { + break; + } + + iterations += 1; + } + + Ok(()) +} + +#[monoio::test(timer_enabled = true)] +async fn test_e2e_handshake_and_data() { + let mut alice = PeerNode::new(28001, 1).unwrap(); + let mut bob = PeerNode::new(28002, 2).unwrap(); + + let bob_addr = bob.socket.local_addr().unwrap(); + alice.connect(bob.public_key.clone(), bob_addr); + + exchange_handshake(&mut alice, &mut bob).await.unwrap(); + + let mut plaintext = b"hello from alice".to_vec(); + let header = alice + .encrypt_by_public_key(&bob.public_key, &mut plaintext) + .expect("alice encrypt failed"); + + let mut packet = Vec::new(); + packet.extend_from_slice(header.as_bytes()); + packet.extend_from_slice(&plaintext); + + let (result, _) = alice.socket.send_to(packet, bob_addr).await; + result.unwrap(); + + let received = bob.recv_and_dispatch().await.unwrap(); + assert_eq!(received, Some(Bytes::from(&b"hello from alice"[..]))); +} + +#[monoio::test(timer_enabled = true)] +async fn test_e2e_bidirectional() { + let mut alice = PeerNode::new(28003, 3).unwrap(); + let mut bob = PeerNode::new(28004, 4).unwrap(); + + let bob_addr = bob.socket.local_addr().unwrap(); + let alice_addr = alice.socket.local_addr().unwrap(); + + alice.connect(bob.public_key.clone(), bob_addr); + + exchange_handshake(&mut alice, &mut bob).await.unwrap(); + + let mut msg_alice = b"message from alice".to_vec(); + let header_alice = alice + .encrypt_by_public_key(&bob.public_key, &mut msg_alice) + .unwrap(); + + let mut packet_alice = Vec::new(); + packet_alice.extend_from_slice(header_alice.as_bytes()); + packet_alice.extend_from_slice(&msg_alice); + + let (result, _) = alice.socket.send_to(packet_alice, bob_addr).await; + result.unwrap(); + + let received_bob = bob.recv_and_dispatch().await.unwrap(); + assert_eq!(received_bob, Some(Bytes::from(&b"message from alice"[..]))); + + let mut msg_bob = b"message from bob".to_vec(); + let header_bob = bob + .encrypt_by_public_key(&alice.public_key, &mut msg_bob) + .unwrap(); + + let mut packet_bob = Vec::new(); + packet_bob.extend_from_slice(header_bob.as_bytes()); + packet_bob.extend_from_slice(&msg_bob); + + let (result, _) = bob.socket.send_to(packet_bob, alice_addr).await; + result.unwrap(); + + let received_alice = alice.recv_and_dispatch().await.unwrap(); + assert_eq!(received_alice, Some(Bytes::from(&b"message from bob"[..]))); +} diff --git a/monad-wireauth-api/tests/tests.rs b/monad-wireauth-api/tests/tests.rs new file mode 100644 index 0000000000..82066684cf --- /dev/null +++ b/monad-wireauth-api/tests/tests.rs @@ -0,0 +1,739 @@ +use std::{net::SocketAddr, time::Duration}; + +use monad_wireauth_api::{Config, TestContext, API}; +use monad_wireauth_protocol::{ + common::PublicKey, + messages::{CookieReply, DataPacketHeader, HandshakeInitiation, HandshakeResponse}, +}; +use monad_wireauth_session::DEFAULT_RETRY_ATTEMPTS; +use secp256k1::rand::rng; +use tracing_subscriber::EnvFilter; +use zerocopy::IntoBytes; + +fn init_tracing() { + let _ = tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .try_init(); +} + +fn create_manager() -> (API, PublicKey, TestContext, Config) { + let mut rng = rng(); + let (public_key, private_key) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let config = Config::default(); + let context = TestContext::new(); + let context_clone = context.clone(); + let manager = API::new(config.clone(), private_key, public_key.clone(), context); + (manager, public_key, context_clone, config) +} + +fn collect(manager: &mut API) -> Vec +where + for<'a> &'a T: std::convert::TryFrom<&'a [u8]>, + for<'a> <&'a T as std::convert::TryFrom<&'a [u8]>>::Error: std::fmt::Debug, +{ + let (_, packet) = manager.next_packet().unwrap(); + let bytes = packet.to_vec(); + let _ = <&T>::try_from(&bytes[..]).unwrap(); + bytes +} + +fn dispatch(manager: &mut API, packet: &[u8], from: SocketAddr) -> Option> { + let mut packet_mut = packet.to_vec(); + manager + .dispatch(&mut packet_mut, from) + .ok() + .flatten() + .map(|b| b.to_vec()) +} + +fn encrypt( + manager: &mut API, + peer_pubkey: &PublicKey, + plaintext: &mut [u8], +) -> Vec { + let header = manager + .encrypt_by_public_key(peer_pubkey, plaintext) + .unwrap(); + let mut packet = Vec::with_capacity(DataPacketHeader::SIZE + plaintext.len()); + packet.extend_from_slice(header.as_bytes()); + packet.extend_from_slice(plaintext); + packet +} + +fn decrypt(manager: &mut API, packet: &[u8], from: SocketAddr) -> Vec { + dispatch(manager, packet, from).unwrap() +} + +//1. peer1 initiates to peer2 +//2. peer2 initiates to peer1 +//3. peer2 receives peer1 init and sends response +//4. peer1 receives peer2 init and sends response +//5. peer1 receives peer2 response +//6. peer2 receives peer1 response +//7. peer1 encrypts message to peer2 +//8. peer2 decrypts message from peer1 +//9. peer2 encrypts message to peer1 +//10. peer1 decrypts message from peer2 +#[test] +fn test_concurrent_init() { + init_tracing(); + let (mut peer1, peer1_pubkey, _, _) = create_manager(); + let (mut peer2, peer2_pubkey, _, _) = create_manager(); + let peer1_addr: SocketAddr = "127.0.0.1:8001".parse().unwrap(); + let peer2_addr: SocketAddr = "127.0.0.1:8002".parse().unwrap(); + + peer1 + .connect(peer2_pubkey.clone(), peer2_addr, DEFAULT_RETRY_ATTEMPTS) + .unwrap(); + peer2 + .connect(peer1_pubkey.clone(), peer1_addr, DEFAULT_RETRY_ATTEMPTS) + .unwrap(); + + let init1 = collect::(&mut peer1); + let init2 = collect::(&mut peer2); + + dispatch(&mut peer2, &init1, peer1_addr); + dispatch(&mut peer1, &init2, peer2_addr); + + let resp2 = collect::(&mut peer2); + let resp1 = collect::(&mut peer1); + + dispatch(&mut peer1, &resp2, peer2_addr); + dispatch(&mut peer2, &resp1, peer1_addr); + + let mut plaintext1 = b"hello from peer1".to_vec(); + let packet1 = encrypt(&mut peer1, &peer2_pubkey, &mut plaintext1); + let decrypted1 = decrypt(&mut peer2, &packet1, peer1_addr); + assert_eq!(decrypted1, b"hello from peer1"); + + let mut plaintext2 = b"hello from peer2".to_vec(); + let packet2 = encrypt(&mut peer2, &peer1_pubkey, &mut plaintext2); + let decrypted2 = decrypt(&mut peer1, &packet2, peer2_addr); + assert_eq!(decrypted2, b"hello from peer2"); +} + +//1. peer1 connects to peer2 with 2 retries +//2. peer1 sends first init - dropped +//3. advance time and tick - peer1 retries +//4. peer1 sends second init - dropped +//5. advance time and tick - peer1 retries +//6. peer1 sends third init - delivered to peer2 +//7. peer2 sends response +//8. peer1 receives response and completes handshake +//9. exchange several messages +#[test] +fn test_retries() { + init_tracing(); + let (mut peer1, _, peer1_ctx, _) = create_manager(); + let (mut peer2, peer2_pubkey, _, _) = create_manager(); + let peer1_addr: SocketAddr = "127.0.0.1:8001".parse().unwrap(); + let peer2_addr: SocketAddr = "127.0.0.1:8002".parse().unwrap(); + + peer1.connect(peer2_pubkey.clone(), peer2_addr, 2).unwrap(); + + let _init1 = collect::(&mut peer1); + + peer1_ctx.advance_time(Duration::from_secs(11)); + peer1.tick(); + let _init2 = collect::(&mut peer1); + + peer1_ctx.advance_time(Duration::from_secs(11)); + peer1.tick(); + let init3 = collect::(&mut peer1); + + dispatch(&mut peer2, &init3, peer1_addr); + let resp = collect::(&mut peer2); + dispatch(&mut peer1, &resp, peer2_addr); + + let mut plaintext1 = b"message1".to_vec(); + let packet1 = encrypt(&mut peer1, &peer2_pubkey, &mut plaintext1); + let decrypted1 = decrypt(&mut peer2, &packet1, peer1_addr); + assert_eq!(decrypted1, b"message1"); + + let mut plaintext2 = b"message2".to_vec(); + let packet2 = encrypt(&mut peer1, &peer2_pubkey, &mut plaintext2); + let decrypted2 = decrypt(&mut peer2, &packet2, peer1_addr); + assert_eq!(decrypted2, b"message2"); + + let mut plaintext3 = b"message3".to_vec(); + let packet3 = encrypt(&mut peer1, &peer2_pubkey, &mut plaintext3); + let decrypted3 = decrypt(&mut peer2, &packet3, peer1_addr); + assert_eq!(decrypted3, b"message3"); +} + +//1. create 5 peers with independent managers +//2. each peer initiates to all other peers +//3. complete all handshakes +//4. each peer exchanges messages with all other peers +#[test] +fn test_five_peers() { + init_tracing(); + let (mut m0, pk0, _, _) = create_manager(); + let (mut m1, pk1, _, _) = create_manager(); + let (mut m2, pk2, _, _) = create_manager(); + let (mut m3, pk3, _, _) = create_manager(); + let (mut m4, pk4, _, _) = create_manager(); + let a0: SocketAddr = "127.0.0.1:8000".parse().unwrap(); + let a1: SocketAddr = "127.0.0.1:8001".parse().unwrap(); + let a2: SocketAddr = "127.0.0.1:8002".parse().unwrap(); + let a3: SocketAddr = "127.0.0.1:8003".parse().unwrap(); + let a4: SocketAddr = "127.0.0.1:8004".parse().unwrap(); + + m0.connect(pk1.clone(), a1, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m0.connect(pk2.clone(), a2, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m0.connect(pk3.clone(), a3, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m0.connect(pk4.clone(), a4, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m1.connect(pk0.clone(), a0, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m1.connect(pk2.clone(), a2, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m1.connect(pk3.clone(), a3, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m1.connect(pk4.clone(), a4, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m2.connect(pk0.clone(), a0, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m2.connect(pk1.clone(), a1, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m2.connect(pk3.clone(), a3, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m2.connect(pk4.clone(), a4, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m3.connect(pk0.clone(), a0, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m3.connect(pk1.clone(), a1, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m3.connect(pk2.clone(), a2, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m3.connect(pk4, a4, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m4.connect(pk0.clone(), a0, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m4.connect(pk1.clone(), a1, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m4.connect(pk2, a2, DEFAULT_RETRY_ATTEMPTS).unwrap(); + m4.connect(pk3.clone(), a3, DEFAULT_RETRY_ATTEMPTS).unwrap(); + + let i01 = collect::(&mut m0); + let i02 = collect::(&mut m0); + let i03 = collect::(&mut m0); + let i04 = collect::(&mut m0); + let i10 = collect::(&mut m1); + let i12 = collect::(&mut m1); + let i13 = collect::(&mut m1); + let i14 = collect::(&mut m1); + let i20 = collect::(&mut m2); + let i21 = collect::(&mut m2); + let i23 = collect::(&mut m2); + let i24 = collect::(&mut m2); + let i30 = collect::(&mut m3); + let i31 = collect::(&mut m3); + let i32 = collect::(&mut m3); + let i34 = collect::(&mut m3); + let i40 = collect::(&mut m4); + let i41 = collect::(&mut m4); + let i42 = collect::(&mut m4); + let i43 = collect::(&mut m4); + + dispatch(&mut m1, &i01, a0); + dispatch(&mut m2, &i02, a0); + dispatch(&mut m3, &i03, a0); + dispatch(&mut m4, &i04, a0); + dispatch(&mut m0, &i10, a1); + dispatch(&mut m2, &i12, a1); + dispatch(&mut m3, &i13, a1); + dispatch(&mut m4, &i14, a1); + dispatch(&mut m0, &i20, a2); + dispatch(&mut m1, &i21, a2); + dispatch(&mut m3, &i23, a2); + dispatch(&mut m4, &i24, a2); + dispatch(&mut m0, &i30, a3); + dispatch(&mut m1, &i31, a3); + dispatch(&mut m2, &i32, a3); + dispatch(&mut m4, &i34, a3); + dispatch(&mut m0, &i40, a4); + dispatch(&mut m1, &i41, a4); + dispatch(&mut m2, &i42, a4); + dispatch(&mut m3, &i43, a4); + + let r10 = collect::(&mut m1); + let r20 = collect::(&mut m2); + let r30 = collect::(&mut m3); + let r40 = collect::(&mut m4); + let r01 = collect::(&mut m0); + let r21 = collect::(&mut m2); + let r31 = collect::(&mut m3); + let r41 = collect::(&mut m4); + let r02 = collect::(&mut m0); + let r12 = collect::(&mut m1); + let r32 = collect::(&mut m3); + let r42 = collect::(&mut m4); + let r03 = collect::(&mut m0); + let r13 = collect::(&mut m1); + let r23 = collect::(&mut m2); + let r43 = collect::(&mut m4); + let r04 = collect::(&mut m0); + let r14 = collect::(&mut m1); + let r24 = collect::(&mut m2); + let r34 = collect::(&mut m3); + + dispatch(&mut m0, &r10, a1); + dispatch(&mut m0, &r20, a2); + dispatch(&mut m0, &r30, a3); + dispatch(&mut m0, &r40, a4); + dispatch(&mut m1, &r01, a0); + dispatch(&mut m1, &r21, a2); + dispatch(&mut m1, &r31, a3); + dispatch(&mut m1, &r41, a4); + dispatch(&mut m2, &r02, a0); + dispatch(&mut m2, &r12, a1); + dispatch(&mut m2, &r32, a3); + dispatch(&mut m2, &r42, a4); + dispatch(&mut m3, &r03, a0); + dispatch(&mut m3, &r13, a1); + dispatch(&mut m3, &r23, a2); + dispatch(&mut m3, &r43, a4); + dispatch(&mut m4, &r04, a0); + dispatch(&mut m4, &r14, a1); + dispatch(&mut m4, &r24, a2); + dispatch(&mut m4, &r34, a3); + + let mut plaintext = b"0->1".to_vec(); + let packet = encrypt(&mut m0, &pk1, &mut plaintext); + let decrypted = decrypt(&mut m1, &packet, a0); + assert_eq!(decrypted, b"0->1"); + + let mut plaintext = b"2->3".to_vec(); + let packet = encrypt(&mut m2, &pk3, &mut plaintext); + let decrypted = decrypt(&mut m3, &packet, a2); + assert_eq!(decrypted, b"2->3"); + + let mut plaintext = b"4->0".to_vec(); + let packet = encrypt(&mut m4, &pk0, &mut plaintext); + let decrypted = decrypt(&mut m0, &packet, a4); + assert_eq!(decrypted, b"4->0"); +} + +//1. peer1 initiates to peer2 +//2. complete handshake +//3. peer1 encrypts by public key and sends to peer2 +//4. peer1 encrypts by socket and sends to peer2 +//5. peer2 decrypts both messages +#[test] +fn test_encrypt_by_pubkey_and_socket() { + init_tracing(); + let (mut peer1, _, _, _) = create_manager(); + let (mut peer2, peer2_pubkey, _, _) = create_manager(); + let peer1_addr: SocketAddr = "127.0.0.1:8001".parse().unwrap(); + let peer2_addr: SocketAddr = "127.0.0.1:8002".parse().unwrap(); + + peer1 + .connect(peer2_pubkey.clone(), peer2_addr, DEFAULT_RETRY_ATTEMPTS) + .unwrap(); + + let init = collect::(&mut peer1); + dispatch(&mut peer2, &init, peer1_addr); + let resp = collect::(&mut peer2); + dispatch(&mut peer1, &resp, peer2_addr); + + let mut plaintext1 = b"by pubkey".to_vec(); + let packet1 = encrypt(&mut peer1, &peer2_pubkey, &mut plaintext1); + let decrypted1 = decrypt(&mut peer2, &packet1, peer1_addr); + assert_eq!(decrypted1, b"by pubkey"); + + let mut plaintext2 = b"by socket".to_vec(); + let header2 = peer1 + .encrypt_by_socket(&peer2_addr, &mut plaintext2) + .unwrap(); + let mut packet2 = Vec::with_capacity(DataPacketHeader::SIZE + plaintext2.len()); + packet2.extend_from_slice(header2.as_bytes()); + packet2.extend_from_slice(&plaintext2); + let decrypted2 = decrypt(&mut peer2, &packet2, peer1_addr); + assert_eq!(decrypted2, b"by socket"); +} + +//1. peer1 initiates to peer2 with low handshake rate limit +//2. peer2 sends cookie reply to 1st init +//3. peer1 receives cookie reply and stores it +//4. advance time past session timeout +//5. peer1 tick triggers retry with stored cookie +//6. peer2 accepts init with valid mac2 +#[test] +fn test_cookie_reply_on_init() { + init_tracing(); + let config = Config { + handshake_rate_limit: 10, + low_watermark_sessions: 1, + ..Config::default() + }; + + let mut rng = rng(); + let (public_key1, private_key1) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let context1 = TestContext::new(); + let mut peer1 = API::new(config.clone(), private_key1, public_key1, context1.clone()); + + let (public_key2, private_key2) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let context2 = TestContext::new(); + let mut peer2 = API::new(config, private_key2, public_key2.clone(), context2); + + let peer1_addr: SocketAddr = "127.0.0.1:8001".parse().unwrap(); + let peer2_addr: SocketAddr = "127.0.0.1:8002".parse().unwrap(); + + peer1 + .connect(public_key2.clone(), peer2_addr, DEFAULT_RETRY_ATTEMPTS) + .unwrap(); + let init1 = collect::(&mut peer1); + dispatch(&mut peer2, &init1, peer1_addr); + + let resp = collect::(&mut peer2); + dispatch(&mut peer1, &resp, peer2_addr); + + let data = collect::(&mut peer1); + dispatch(&mut peer2, &data, peer1_addr); + + peer1 + .connect(public_key2, peer2_addr, DEFAULT_RETRY_ATTEMPTS) + .unwrap(); + let init2 = collect::(&mut peer1); + dispatch(&mut peer2, &init2, peer1_addr); + + let cookie = collect::(&mut peer2); + dispatch(&mut peer1, &cookie, peer2_addr); + + context1.advance_time(Duration::from_secs(11)); + peer1.tick(); + + let _init2 = collect::(&mut peer1); +} + +//1. peer1 establishes session with peer2 +//2. peer1 attempts connect again to peer2 +//3. exchange messages to verify session still works +#[test] +fn test_connect_after_established() { + init_tracing(); + let (mut peer1, _, _, _) = create_manager(); + let (mut peer2, peer2_pubkey, _, _) = create_manager(); + let peer1_addr: SocketAddr = "127.0.0.1:8001".parse().unwrap(); + let peer2_addr: SocketAddr = "127.0.0.1:8002".parse().unwrap(); + + peer1 + .connect(peer2_pubkey.clone(), peer2_addr, DEFAULT_RETRY_ATTEMPTS) + .unwrap(); + let init = collect::(&mut peer1); + dispatch(&mut peer2, &init, peer1_addr); + let resp = collect::(&mut peer2); + dispatch(&mut peer1, &resp, peer2_addr); + + let mut plaintext = b"before reconnect".to_vec(); + let packet = encrypt(&mut peer1, &peer2_pubkey, &mut plaintext); + let decrypted = decrypt(&mut peer2, &packet, peer1_addr); + assert_eq!(decrypted, b"before reconnect"); + + let _ = peer1.connect(peer2_pubkey.clone(), peer2_addr, DEFAULT_RETRY_ATTEMPTS); + + let mut plaintext = b"after reconnect".to_vec(); + let packet = encrypt(&mut peer1, &peer2_pubkey, &mut plaintext); + let decrypted = decrypt(&mut peer2, &packet, peer1_addr); + assert_eq!(decrypted, b"after reconnect"); +} + +//1. peer1 initiates to peer2 +//2. peer2 accepts init and sends response +//3. peer1 sends same init again +//4. verify peer2 rejects replay +#[test] +fn test_timestamp_replay() { + init_tracing(); + let (mut peer1, _, _, _) = create_manager(); + let (mut peer2, peer2_pubkey, _, _) = create_manager(); + let peer1_addr: SocketAddr = "127.0.0.1:8001".parse().unwrap(); + let peer2_addr: SocketAddr = "127.0.0.1:8002".parse().unwrap(); + + peer1 + .connect(peer2_pubkey, peer2_addr, DEFAULT_RETRY_ATTEMPTS) + .unwrap(); + let init = collect::(&mut peer1); + + dispatch(&mut peer2, &init, peer1_addr); + let _resp = collect::(&mut peer2); + + let result2 = dispatch(&mut peer2, &init, peer1_addr); + assert!(result2.is_none()); +} + +//1. create 10 peer managers with same keypair +//2. each initiates to one responder with distinct key +//3. verify responder hits max accepted sessions limit +#[test] +fn test_too_many_accepted_sessions() { + init_tracing(); + let config = Config::default(); + + let mut rng = rng(); + let (responder_public, responder_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let responder_ctx = TestContext::new(); + let mut responder = API::new( + config.clone(), + responder_private, + responder_public.clone(), + responder_ctx, + ); + let responder_addr: SocketAddr = "127.0.0.1:9000".parse().unwrap(); + + let (shared_public, shared_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + + for i in 0..5 { + let initiator_ctx = TestContext::new(); + let mut initiator = API::new( + config.clone(), + shared_private.clone(), + shared_public.clone(), + initiator_ctx, + ); + let initiator_addr: SocketAddr = format!("127.0.0.1:800{}", i).parse().unwrap(); + + initiator + .connect( + responder_public.clone(), + responder_addr, + DEFAULT_RETRY_ATTEMPTS, + ) + .unwrap(); + + let init = collect::(&mut initiator); + dispatch(&mut responder, &init, initiator_addr); + } + + assert!(responder.next_packet().is_some()); +} + +//1. dispatch random invalid packet +//2. verify error is returned +#[test] +fn test_random_packet_error() { + init_tracing(); + let (mut peer, _, _, _) = create_manager(); + let addr: SocketAddr = "127.0.0.1:8001".parse().unwrap(); + + let random_packet = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + let mut packet = random_packet; + let result = peer.dispatch(&mut packet, addr); + assert!(result.is_err()); +} + +//1. create manager with low handshake rate limit +//2. exceed rate limit with multiple inits +//3. verify packets get dropped due to rate limit +#[test] +fn test_filter_drop_rate_limit() { + init_tracing(); + let config = Config { + handshake_rate_limit: 2, + ..Config::default() + }; + + let mut rng = rng(); + let (responder_public, responder_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let responder_ctx = TestContext::new(); + let mut responder = API::new( + config.clone(), + responder_private, + responder_public.clone(), + responder_ctx, + ); + let responder_addr: SocketAddr = "127.0.0.1:9000".parse().unwrap(); + + for i in 0..4 { + let (initiator_public, initiator_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let initiator_ctx = TestContext::new(); + let mut initiator = API::new( + config.clone(), + initiator_private, + initiator_public.clone(), + initiator_ctx, + ); + let initiator_addr: SocketAddr = format!("127.0.0.1:800{}", i).parse().unwrap(); + + initiator + .connect( + responder_public.clone(), + responder_addr, + DEFAULT_RETRY_ATTEMPTS, + ) + .unwrap(); + + let init = collect::(&mut initiator); + dispatch(&mut responder, &init, initiator_addr); + } + + let response_count = (0..4).filter(|_| responder.next_packet().is_some()).count(); + + assert!(response_count <= 2); +} + +//1. create manager and initiate connection +//2. check next_timer returns some duration +//3. advance time partially +//4. verify timer decreases +//5. advance time past deadline +//6. verify timer returns zero +#[test] +fn test_next_timer() { + init_tracing(); + let (mut peer1, _, peer1_ctx, config) = create_manager(); + let (_, peer2_pubkey, _, _) = create_manager(); + let peer2_addr: SocketAddr = "127.0.0.1:8002".parse().unwrap(); + + let timer_before = peer1.next_timer(); + assert_eq!(timer_before, Some(config.handshake_rate_reset_interval)); + + peer1 + .connect(peer2_pubkey, peer2_addr, DEFAULT_RETRY_ATTEMPTS) + .unwrap(); + + let timer_after = peer1.next_timer(); + assert!(timer_after.is_some()); + let initial_timer = timer_after.unwrap(); + assert!(initial_timer.as_secs() <= 10); + + peer1_ctx.advance_time(Duration::from_secs(5)); + + let timer_decreased = peer1.next_timer(); + assert!(timer_decreased.is_some()); + assert!(timer_decreased.unwrap() < initial_timer); + + peer1_ctx.advance_time(Duration::from_secs(20)); + + let timer_zero = peer1.next_timer(); + assert!(timer_zero.is_some()); + assert_eq!(timer_zero.unwrap(), Duration::ZERO); +} + +//1. peer1 initiates to peer2 +//2. peer2 receives init and creates responder session +//3. advance time past session timeout on peer2 +//4. peer2 tick triggers responder timeout +//5. verify responder session terminated +#[test] +fn test_responder_timeout() { + init_tracing(); + let config = Config::default(); + + let mut rng = rng(); + let (peer1_public, peer1_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let peer1_ctx = TestContext::new(); + let mut peer1 = API::new(config.clone(), peer1_private, peer1_public, peer1_ctx); + + let (peer2_public, peer2_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let peer2_ctx = TestContext::new(); + let mut peer2 = API::new( + config, + peer2_private, + peer2_public.clone(), + peer2_ctx.clone(), + ); + + let peer1_addr: SocketAddr = "127.0.0.1:8001".parse().unwrap(); + let peer2_addr: SocketAddr = "127.0.0.1:8002".parse().unwrap(); + + peer1 + .connect(peer2_public, peer2_addr, DEFAULT_RETRY_ATTEMPTS) + .unwrap(); + + let init = collect::(&mut peer1); + dispatch(&mut peer2, &init, peer1_addr); + + let timer_before_timeout = peer2.next_timer(); + assert!(timer_before_timeout.is_some()); + + peer2_ctx.advance_time(Duration::from_secs(11)); + peer2.tick(); + + let timer_after_timeout = peer2.next_timer(); + assert!(timer_after_timeout.is_some()); +} + +#[test] +fn test_next_timer_includes_filter_reset() { + init_tracing(); + let mut rng = rng(); + let config = Config { + handshake_rate_reset_interval: Duration::from_secs(5), + ..Config::default() + }; + + let (peer_public, peer_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let peer_ctx = TestContext::new(); + let peer = API::new(config, peer_private, peer_public, peer_ctx); + + let next_timer = peer.next_timer(); + assert!(next_timer.is_some()); + assert_eq!(next_timer.unwrap(), Duration::from_secs(5)); +} + +#[test] +fn test_next_timer_returns_minimum_of_session_and_filter() { + init_tracing(); + let mut rng = rng(); + let config = Config::default(); + + let (peer1_public, peer1_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let peer1_ctx = TestContext::new(); + let mut peer1 = API::new(config.clone(), peer1_private, peer1_public, peer1_ctx); + + let (peer2_public, peer2_private) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let peer2_ctx = TestContext::new(); + let mut peer2 = API::new(config, peer2_private, peer2_public.clone(), peer2_ctx); + + let peer2_addr: SocketAddr = "127.0.0.1:8002".parse().unwrap(); + let peer1_addr: SocketAddr = "127.0.0.1:8001".parse().unwrap(); + + peer1 + .connect(peer2_public, peer2_addr, DEFAULT_RETRY_ATTEMPTS) + .unwrap(); + + let init = collect::(&mut peer1); + dispatch(&mut peer2, &init, peer1_addr); + + let response = collect::(&mut peer2); + dispatch(&mut peer1, &response, peer2_addr); + + collect::(&mut peer1); + + let next_timer = peer1.next_timer(); + assert!(next_timer.is_some()); + let timer_value = next_timer.unwrap(); + assert!(timer_value <= Duration::from_secs(3)); + assert!(timer_value > Duration::ZERO); +} + +#[test] +fn test_disconnect() { + init_tracing(); + let (mut peer1, _peer1_pubkey, _, _) = create_manager(); + let (mut peer2, peer2_pubkey, _, _) = create_manager(); + let peer1_addr: SocketAddr = "127.0.0.1:8001".parse().unwrap(); + let peer2_addr: SocketAddr = "127.0.0.1:8002".parse().unwrap(); + + peer1 + .connect(peer2_pubkey.clone(), peer2_addr, DEFAULT_RETRY_ATTEMPTS) + .unwrap(); + + let init = collect::(&mut peer1); + dispatch(&mut peer2, &init, peer1_addr); + + let response = collect::(&mut peer2); + dispatch(&mut peer1, &response, peer2_addr); + + collect::(&mut peer1); + + let mut plaintext = b"hello".to_vec(); + let encrypted = encrypt(&mut peer1, &peer2_pubkey, &mut plaintext); + let decrypted = decrypt(&mut peer2, &encrypted, peer1_addr); + assert_eq!(&decrypted, b"hello"); + + peer1.disconnect(&peer2_pubkey); + + let mut plaintext2 = b"world".to_vec(); + let result = peer1.encrypt_by_public_key(&peer2_pubkey, &mut plaintext2); + assert!(result.is_err()); +} diff --git a/monad-wireauth-api/tests/two_peer_model.rs b/monad-wireauth-api/tests/two_peer_model.rs new file mode 100644 index 0000000000..9681f445c5 --- /dev/null +++ b/monad-wireauth-api/tests/two_peer_model.rs @@ -0,0 +1,421 @@ +mod tests { + use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Once, + time::Duration, + }; + + use monad_wireauth_api::{Config, TestContext, API}; + use monad_wireauth_protocol::common::PublicKey; + use proptest::prelude::*; + use secp256k1::rand::rng; + use tracing_subscriber::EnvFilter; + use zerocopy::IntoBytes; + + static INIT: Once = Once::new(); + + const TIME_ADVANCE_MILLIS: u64 = 1; + const REKEY_INTERVAL_SECS: u64 = 10; + const SESSION_TIMEOUT_SECS: u64 = 20; + + fn init_logging() { + INIT.call_once(|| { + let _ = tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .try_init(); + }); + } + + fn test_config() -> Config { + Config { + rekey_interval: Duration::from_secs(REKEY_INTERVAL_SECS), + session_timeout: Duration::from_secs(SESSION_TIMEOUT_SECS), + ..Config::default() + } + } + + struct PeerState { + public_key: PublicKey, + private_key: monad_wireauth_protocol::common::PrivateKey, + manager: API, + context: TestContext, + addr: SocketAddr, + sent_data: Vec>, + received_data: Vec>, + } + + impl PeerState { + fn new(peer_id: u8) -> Self { + let mut rng = rng(); + let (public_key, private_key) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + let context = TestContext::new(); + let config = test_config(); + let manager = API::new( + config, + private_key.clone(), + public_key.clone(), + context.clone(), + ); + let addr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(10, 0, 0, peer_id)), + 30000 + peer_id as u16, + ); + + Self { + public_key, + private_key, + manager, + context, + addr, + sent_data: Vec::new(), + received_data: Vec::new(), + } + } + } + + #[derive(Debug, Clone, Copy, PartialEq)] + enum ConnectionState { + Connected, + Disconnected, + } + + struct TwoPeerModel { + peers: [PeerState; 2], + initiated: bool, + expect_success: bool, + connection_state: ConnectionState, + connected_for: Duration, + } + + impl TwoPeerModel { + fn new() -> Self { + Self { + peers: [PeerState::new(1), PeerState::new(2)], + initiated: false, + expect_success: false, + connection_state: ConnectionState::Connected, + connected_for: Duration::ZERO, + } + } + + fn should_deliver_message(&self) -> bool { + match self.connection_state { + ConnectionState::Connected => true, + ConnectionState::Disconnected => false, + } + } + + fn process_all_messages(&mut self) { + loop { + let mut had_packets = false; + let mut packets_to_process = Vec::new(); + + for i in 0..2 { + while let Some((_dst, packet)) = self.peers[i].manager.next_packet() { + had_packets = true; + let other = i ^ 1; + packets_to_process.push((i, other, self.peers[i].addr, packet.to_vec())); + } + } + + if !had_packets { + break; + } + + for (_sender_idx, receiver_idx, src_addr, packet) in packets_to_process { + if !self.should_deliver_message() { + continue; + } + + let mut packet_copy = packet; + if let Ok(Some(plaintext)) = self.peers[receiver_idx] + .manager + .dispatch(&mut packet_copy, src_addr) + { + self.peers[receiver_idx] + .received_data + .push(plaintext.to_vec()); + } + } + } + } + + fn apply_action(&mut self, action: Action) { + match action { + Action::Disconnect => { + self.connection_state = ConnectionState::Disconnected; + self.connected_for = Duration::ZERO; + self.initiated = false; + self.expect_success = false; + } + + Action::Connect => { + if self.connection_state == ConnectionState::Disconnected { + self.connection_state = ConnectionState::Connected; + self.connected_for = Duration::ZERO; + } + } + + Action::Initiate { from } => { + let time_advance = Duration::from_millis(TIME_ADVANCE_MILLIS); + self.peers[0].context.advance_time(time_advance); + self.peers[1].context.advance_time(time_advance); + + let initiator_idx = (from - 1) as usize; + let responder_idx = initiator_idx ^ 1; + + let responder_pubkey = self.peers[responder_idx].public_key.clone(); + let responder_addr = self.peers[responder_idx].addr; + + let _ = self.peers[initiator_idx].manager.connect( + responder_pubkey, + responder_addr, + monad_wireauth_session::RETRY_ALWAYS, + ); + + self.process_all_messages(); + + if self.connection_state == ConnectionState::Connected { + self.expect_success = true; + } + self.initiated = true + } + + Action::Tick { seconds } => { + let duration = Duration::from_secs(seconds as u64); + + for i in 0..2 { + self.peers[i].context.advance_time(duration); + self.peers[i].manager.tick(); + } + + self.process_all_messages(); + + if self.connection_state == ConnectionState::Connected && self.initiated { + self.connected_for = self.connected_for.saturating_add(duration); + let rekey_interval = Duration::from_secs(REKEY_INTERVAL_SECS); + let session_timeout = Duration::from_secs(SESSION_TIMEOUT_SECS); + + if self.connected_for >= rekey_interval + && self.connected_for >= session_timeout + { + self.expect_success = true; + } + } + } + + Action::Send { from, data } => { + let data_bytes = vec![from; data as usize]; + let expect_success = self.expect_success; + let should_deliver = self.should_deliver_message(); + + let sender_idx = (from - 1) as usize; + let receiver_idx = sender_idx ^ 1; + + let receiver_pubkey = self.peers[receiver_idx].public_key.clone(); + let sender_addr = self.peers[sender_idx].addr; + + self.peers[sender_idx].sent_data.push(data_bytes.clone()); + let mut plaintext = data_bytes.clone(); + let send_result = self.peers[sender_idx] + .manager + .encrypt_by_public_key(&receiver_pubkey, &mut plaintext); + + if expect_success { + let header = send_result.unwrap_or_else(|e| { + panic!("after initiation and while connected, send should succeed, error={e:?}") + }); + + let mut packet = header.as_bytes().to_vec(); + packet.extend_from_slice(&plaintext); + + if should_deliver { + let dispatch_result = self.peers[receiver_idx] + .manager + .dispatch(&mut packet, sender_addr); + + let received = match dispatch_result { + Ok(Some(data)) => data, + Ok(None) => panic!("when send succeeds and connected, dispatch should return some data, got none"), + Err(e) => panic!("when send succeeds and connected, dispatch should succeed, error={e:?}"), + }; + + self.peers[receiver_idx] + .received_data + .push(received.to_vec()); + assert_eq!( + data_bytes, + received.to_vec(), + "dispatch should return decrypted data matching sent data" + ); + } + } + + self.process_all_messages(); + } + + Action::Reset { peer } => { + let peer_idx = (peer - 1) as usize; + let context = self.peers[peer_idx].context.clone(); + let config = test_config(); + + self.peers[peer_idx].manager = API::new( + config, + self.peers[peer_idx].private_key.clone(), + self.peers[peer_idx].public_key.clone(), + context.clone(), + ); + self.peers[peer_idx].context = context; + + self.connected_for = Duration::ZERO; + self.initiated = false; + self.expect_success = false; + } + + Action::Migrate { peer, new_addr } => { + let peer_idx = (peer - 1) as usize; + self.peers[peer_idx].addr = new_addr; + + let context = self.peers[peer_idx].context.clone(); + let config = test_config(); + + self.peers[peer_idx].manager = API::new( + config, + self.peers[peer_idx].private_key.clone(), + self.peers[peer_idx].public_key.clone(), + context.clone(), + ); + self.peers[peer_idx].context = context; + + self.connected_for = Duration::ZERO; + self.initiated = false; + self.expect_success = false; + } + } + } + } + + #[derive(Debug, Clone)] + enum Action { + Initiate { from: u8 }, + Tick { seconds: u8 }, + Send { from: u8, data: u8 }, + Reset { peer: u8 }, + Migrate { peer: u8, new_addr: SocketAddr }, + Disconnect, + Connect, + } + + fn basic_action_strategy() -> impl Strategy { + prop_oneof![ + (1..=2u8).prop_map(|from| Action::Initiate { from }), + (1..=3u8).prop_map(|seconds| Action::Tick { seconds }), + (1..=2u8, 1..=100u8).prop_map(|(from, data)| Action::Send { from, data }), + ] + } + + fn reset_action_strategy() -> impl Strategy { + prop_oneof![ + (1..=2u8).prop_map(|from| Action::Initiate { from }), + (1..=3u8).prop_map(|seconds| Action::Tick { seconds }), + (1..=2u8, 1..=100u8).prop_map(|(from, data)| Action::Send { from, data }), + (1..=2u8).prop_map(|peer| Action::Reset { peer }), + ] + } + + fn migration_action_strategy() -> impl Strategy { + prop_oneof![ + (1..=2u8).prop_map(|from| Action::Initiate { from }), + (1..=3u8).prop_map(|seconds| Action::Tick { seconds }), + (1..=2u8, 1..=100u8).prop_map(|(from, data)| Action::Send { from, data }), + (1..=2u8, 0..=255u8, 0..=255u8).prop_map(|(peer, octet3, octet4)| Action::Migrate { + peer, + new_addr: SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(10, 0, octet3, octet4)), + 30000 + peer as u16, + ) + }), + ] + } + + fn message_loss_action_strategy() -> impl Strategy { + prop_oneof![ + (1..=2u8).prop_map(|from| Action::Initiate { from }), + (1..=3u8).prop_map(|seconds| Action::Tick { seconds }), + (1..=2u8, 1..=100u8).prop_map(|(from, data)| Action::Send { from, data }), + Just(Action::Disconnect), + Just(Action::Connect), + ] + } + + fn all_actions_strategy() -> impl Strategy { + prop_oneof![ + (1..=2u8).prop_map(|from| Action::Initiate { from }), + (1..=3u8).prop_map(|seconds| Action::Tick { seconds }), + (1..=2u8, 1..=100u8).prop_map(|(from, data)| Action::Send { from, data }), + (1..=2u8).prop_map(|peer| Action::Reset { peer }), + (1..=2u8, 0..=255u8, 0..=255u8).prop_map(|(peer, octet3, octet4)| Action::Migrate { + peer, + new_addr: SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(10, 0, octet3, octet4)), + 30000 + peer as u16, + ) + }), + Just(Action::Disconnect), + Just(Action::Connect), + ] + } + + proptest! { + #[test] + fn test_two_peer_model(actions in prop::collection::vec(basic_action_strategy(), 1..50)) { + init_logging(); + let mut model = TwoPeerModel::new(); + + for action in actions { + model.apply_action(action); + } + } + + #[test] + fn test_two_peer_model_with_memory_reset(actions in prop::collection::vec(reset_action_strategy(), 1..50)) { + init_logging(); + let mut model = TwoPeerModel::new(); + + for action in actions { + model.apply_action(action); + } + } + + #[test] + fn test_two_peer_model_with_ip_migration(actions in prop::collection::vec(migration_action_strategy(), 1..50)) { + init_logging(); + let mut model = TwoPeerModel::new(); + + for action in actions { + model.apply_action(action); + } + } + + #[test] + fn test_two_peer_model_with_message_loss(actions in prop::collection::vec(message_loss_action_strategy(), 1..100)) { + init_logging(); + let mut model = TwoPeerModel::new(); + + for action in actions { + model.apply_action(action); + } + } + + #[test] + fn test_two_peer_model_combined_all_failure_scenarios(actions in prop::collection::vec(all_actions_strategy(), 1..200)) { + init_logging(); + let mut model = TwoPeerModel::new(); + + for action in actions { + model.apply_action(action); + } + } + } +} diff --git a/monad-wireauth-protocol/Cargo.toml b/monad-wireauth-protocol/Cargo.toml new file mode 100644 index 0000000000..0e2d3699ed --- /dev/null +++ b/monad-wireauth-protocol/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "monad-wireauth-protocol" +version = "0.1.0" +edition = "2021" + +[dependencies] +zerocopy = { workspace = true, features = ["derive"] } +blake3.workspace = true +aegis.workspace = true +secp256k1 = { workspace = true, features = ["global-context", "recovery", "rand"] } +rand.workspace = true +tai64.workspace = true +bytes.workspace = true +thiserror.workspace = true +zeroize = { workspace = true, features = ["derive"] } +tracing.workspace = true +hex.workspace = true + +[dev-dependencies] +insta = { workspace = true, features = ["yaml"] } +serde = { workspace = true, features = ["derive"] } diff --git a/monad-wireauth-protocol/src/common.rs b/monad-wireauth-protocol/src/common.rs new file mode 100644 index 0000000000..c3c822cd9b --- /dev/null +++ b/monad-wireauth-protocol/src/common.rs @@ -0,0 +1,343 @@ +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, LE, U32}; +use zeroize::{Zeroize, ZeroizeOnDrop}; + +use crate::errors::CryptoError; + +pub const CIPHER_TAG_SIZE: usize = 16; +pub const MAC_TAG_SIZE: usize = 16; +pub const PUBLIC_KEY_SIZE: usize = 33; +pub const HASH_OUTPUT_SIZE: usize = 32; + +#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct SessionIndex(u32); + +impl SessionIndex { + pub const MIN: SessionIndex = SessionIndex(0); + pub const MAX: SessionIndex = SessionIndex(u32::MAX); + + pub fn new(value: u32) -> Self { + SessionIndex(value) + } + + pub fn as_u32(&self) -> u32 { + self.0 + } + + pub fn increment(&mut self) { + self.0 = self.0.wrapping_add(1); + } +} + +impl From for SessionIndex { + fn from(value: u32) -> Self { + SessionIndex(value) + } +} + +impl From> for SessionIndex { + fn from(value: U32) -> Self { + SessionIndex(value.get()) + } +} + +impl std::fmt::Display for SessionIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::fmt::Debug for SessionIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Clone, Zeroize, ZeroizeOnDrop, Debug, PartialEq)] +pub struct CipherKey([u8; 16]); + +impl From<&HashOutput> for CipherKey { + fn from(hash: &HashOutput) -> Self { + let mut key = [0u8; 16]; + key.copy_from_slice(&hash.0[..16]); + CipherKey(key) + } +} + +impl AsRef<[u8]> for CipherKey { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl AsRef<[u8; 16]> for CipherKey { + fn as_ref(&self) -> &[u8; 16] { + &self.0 + } +} + +#[repr(transparent)] +#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable, KnownLayout, Debug)] +pub struct CipherNonce(pub [u8; 16]); + +impl From for CipherNonce { + fn from(value: u64) -> Self { + let mut nonce = [0u8; 16]; + nonce[..8].copy_from_slice(&value.to_le_bytes()); + CipherNonce(nonce) + } +} + +impl AsRef<[u8]> for CipherNonce { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl AsRef<[u8; 16]> for CipherNonce { + fn as_ref(&self) -> &[u8; 16] { + &self.0 + } +} + +#[derive(Clone, Debug, Zeroize)] +pub struct HashOutput(pub [u8; 32]); + +impl AsRef<[u8]> for HashOutput { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl AsRef<[u8; 32]> for HashOutput { + fn as_ref(&self) -> &[u8; 32] { + &self.0 + } +} + +#[repr(transparent)] +#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable, KnownLayout, Eq)] +pub struct MacTag(pub [u8; 16]); + +impl AsRef<[u8]> for MacTag { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl AsRef<[u8; 16]> for MacTag { + fn as_ref(&self) -> &[u8; 16] { + &self.0 + } +} + +impl PartialEq<[u8; 16]> for MacTag { + fn eq(&self, other: &[u8; 16]) -> bool { + self.0 == *other + } +} + +impl PartialEq for MacTag { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl From for [u8; 16] { + fn from(tag: MacTag) -> Self { + tag.0 + } +} + +impl From<[u8; 16]> for MacTag { + fn from(bytes: [u8; 16]) -> Self { + MacTag(bytes) + } +} + +impl From for MacTag { + fn from(hash: HashOutput) -> Self { + let mut tag = [0u8; 16]; + tag.copy_from_slice(&hash.0[..16]); + MacTag(tag) + } +} + +#[repr(transparent)] +#[derive( + Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, FromBytes, IntoBytes, Immutable, KnownLayout, +)] +pub struct SerializedPublicKey([u8; 33]); + +impl SerializedPublicKey { + pub fn as_bytes(&self) -> &[u8; 33] { + &self.0 + } + + pub fn as_mut_bytes(&mut self) -> &mut [u8; 33] { + &mut self.0 + } +} + +impl From<&PublicKey> for SerializedPublicKey { + fn from(public_key: &PublicKey) -> Self { + SerializedPublicKey(public_key.0.serialize()) + } +} + +impl From for SerializedPublicKey { + fn from(public_key: PublicKey) -> Self { + SerializedPublicKey(public_key.0.serialize()) + } +} + +impl From<[u8; 33]> for SerializedPublicKey { + fn from(bytes: [u8; 33]) -> Self { + SerializedPublicKey(bytes) + } +} + +impl std::fmt::Debug for SerializedPublicKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{:02x}{:02x}{:02x}{:02x}", + self.0[0], self.0[1], self.0[2], self.0[3] + ) + } +} + +impl std::fmt::Display for SerializedPublicKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for byte in &self.0 { + write!(f, "{:02x}", byte)?; + } + Ok(()) + } +} + +#[derive(Clone)] +pub struct PublicKey(secp256k1::PublicKey); + +impl std::fmt::Debug for PublicKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let bytes = self.0.serialize(); + write!( + f, + "{:02x}{:02x}{:02x}{:02x}", + bytes[0], bytes[1], bytes[2], bytes[3] + ) + } +} + +impl From for PublicKey { + fn from(key: secp256k1::PublicKey) -> Self { + PublicKey(key) + } +} + +impl From for [u8; 33] { + fn from(public_key: PublicKey) -> Self { + public_key.0.serialize() + } +} + +impl From<&PublicKey> for [u8; 33] { + fn from(public_key: &PublicKey) -> Self { + public_key.0.serialize() + } +} + +impl TryFrom<[u8; 33]> for PublicKey { + type Error = CryptoError; + + fn try_from(bytes: [u8; 33]) -> Result { + let public_key = secp256k1::PublicKey::from_slice(&bytes)?; + Ok(PublicKey(public_key)) + } +} + +impl TryFrom<&[u8; 33]> for PublicKey { + type Error = CryptoError; + + fn try_from(bytes: &[u8; 33]) -> Result { + let public_key = secp256k1::PublicKey::from_slice(bytes)?; + Ok(PublicKey(public_key)) + } +} + +impl TryFrom for PublicKey { + type Error = CryptoError; + + fn try_from(key: SerializedPublicKey) -> Result { + let public_key = secp256k1::PublicKey::from_slice(key.as_bytes())?; + Ok(PublicKey(public_key)) + } +} + +impl TryFrom<&SerializedPublicKey> for PublicKey { + type Error = CryptoError; + + fn try_from(key: &SerializedPublicKey) -> Result { + let public_key = secp256k1::PublicKey::from_slice(key.as_bytes())?; + Ok(PublicKey(public_key)) + } +} + +impl PublicKey { + pub(crate) fn inner(&self) -> &secp256k1::PublicKey { + &self.0 + } +} + +#[derive(Clone)] +pub struct PrivateKey(secp256k1::SecretKey); + +impl Drop for PrivateKey { + fn drop(&mut self) { + self.0.non_secure_erase(); + } +} + +impl PrivateKey { + pub(crate) fn inner(&self) -> &secp256k1::SecretKey { + &self.0 + } + + pub(crate) fn from_inner(key: secp256k1::SecretKey) -> Self { + PrivateKey(key) + } + + pub fn from_bytes(bytes: &[u8]) -> Result { + let array: [u8; 32] = bytes + .try_into() + .map_err(|_| secp256k1::Error::InvalidSecretKey)?; + let secret = secp256k1::SecretKey::from_byte_array(array)?; + Ok(PrivateKey(secret)) + } +} + +#[derive(Clone, Zeroize, ZeroizeOnDrop)] +pub struct SharedSecret(pub [u8; 32]); + +impl AsRef<[u8]> for SharedSecret { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl AsRef<[u8; 32]> for SharedSecret { + fn as_ref(&self) -> &[u8; 32] { + &self.0 + } +} + +impl From<[u8; 32]> for SharedSecret { + fn from(bytes: [u8; 32]) -> Self { + SharedSecret(bytes) + } +} + +pub struct TransportKeys { + pub send_key: CipherKey, + pub recv_key: CipherKey, +} diff --git a/monad-wireauth-protocol/src/cookies.rs b/monad-wireauth-protocol/src/cookies.rs new file mode 100644 index 0000000000..69031890ac --- /dev/null +++ b/monad-wireauth-protocol/src/cookies.rs @@ -0,0 +1,392 @@ +use std::net::SocketAddr; + +use crate::{ + common::*, + crypto::{decrypt_in_place, encrypt_in_place, LABEL_COOKIE}, + errors::{CookieError, ProtocolError}, + hash, keyed_hash, + messages::*, +}; + +pub fn send_cookie_reply( + nonce_secret: &[u8; 32], + nonce_counter: u128, + responder_static_public: &SerializedPublicKey, + msg_sender_index: u32, + msg_mac1: &[u8; 16], + cookie: &[u8; 16], +) -> Result { + let nonce_hash = keyed_hash!(nonce_secret, &nonce_counter.to_le_bytes()); + let hash_bytes: &[u8] = nonce_hash.as_ref(); + let mut nonce_bytes = [0u8; 16]; + nonce_bytes.copy_from_slice(&hash_bytes[..16]); + let nonce = CipherNonce(nonce_bytes); + + let mut reply = CookieReply { + receiver_index: msg_sender_index.into(), + nonce, + ..Default::default() + }; + + let temp_key = hash!(LABEL_COOKIE, responder_static_public.as_bytes()); + + reply.encrypted_cookie = *cookie; + reply.encrypted_cookie_tag = encrypt_in_place( + &(&temp_key).into(), + &reply.nonce, + &mut reply.encrypted_cookie, + msg_mac1, + ); + + Ok(reply) +} + +/// decrypts cookie in place and returns the decrypted cookie as a separate buffer for convenience +pub fn accept_cookie_reply( + responder_static_public: &SerializedPublicKey, + reply: &mut CookieReply, + msg_mac1: &[u8; 16], +) -> Result<[u8; 16], ProtocolError> { + let temp_key = hash!(LABEL_COOKIE, responder_static_public.as_bytes()); + + decrypt_in_place( + &(&temp_key).into(), + &reply.nonce, + &mut reply.encrypted_cookie, + &reply.encrypted_cookie_tag, + msg_mac1, + ) + .map_err(CookieError::CookieDecryptionFailed)?; + + Ok(reply.encrypted_cookie) +} + +pub fn generate_cookie(cookie_secret: &[u8; 32], nonce: u64, remote_addr: &SocketAddr) -> [u8; 16] { + let mut address_bytes = [0u8; 18]; + match remote_addr.ip() { + std::net::IpAddr::V4(addr) => { + address_bytes[..4].copy_from_slice(&addr.octets()); + } + std::net::IpAddr::V6(addr) => { + address_bytes[..16].copy_from_slice(&addr.octets()); + } + }; + address_bytes[16..18].copy_from_slice(&remote_addr.port().to_le_bytes()); + + let cookie_hash = keyed_hash!(cookie_secret, &nonce.to_le_bytes(), &address_bytes); + let mut cookie = [0u8; 16]; + let hash_bytes: &[u8] = cookie_hash.as_ref(); + cookie.copy_from_slice(&hash_bytes[..16]); + cookie +} + +pub fn verify_cookie( + cookie_secret: &[u8; 32], + nonce: u64, + remote_addr: &SocketAddr, + static_public: &SerializedPublicKey, + message: &M, +) -> Result<(), ProtocolError> { + let expected_cookie = generate_cookie(cookie_secret, nonce, remote_addr); + crate::crypto::verify_mac2(message, static_public, &expected_cookie) + .map_err(|e| CookieError::InvalidCookieMac(e).into()) +} + +#[cfg(test)] +mod tests { + use std::convert::TryFrom; + + use secp256k1::rand::rng; + use zerocopy::IntoBytes; + + use super::*; + use crate::messages::{CookieReply, HandshakeInitiation, HandshakeResponse}; + + #[test] + fn test_cookie_send_and_accept() { + let mut rng = rng(); + + let (_initiator_public, _initiator_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + let (responder_public, _responder_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + + let msg_sender_index = 12345u32; + let msg_mac1 = [0x42u8; 16]; + + let cookie_secret = [0x33u8; 32]; + let nonce = 555u64; + let initiator_addr: SocketAddr = "127.0.0.1:51820".parse().unwrap(); + let cookie = generate_cookie(&cookie_secret, nonce, &initiator_addr); + + let nonce_secret = [0x44u8; 32]; + let cookie_nonce = 0u128; + let reply = send_cookie_reply( + &nonce_secret, + cookie_nonce, + &SerializedPublicKey::from(&responder_public), + msg_sender_index, + &msg_mac1, + &cookie, + ) + .expect("Failed to create cookie reply"); + + let reply_bytes = reply.as_bytes(); + let mut reply_bytes_mut = reply_bytes.to_vec(); + let reply = <&mut CookieReply>::try_from(reply_bytes_mut.as_mut_slice()) + .expect("Failed to parse cookie reply"); + + let decrypted_cookie = accept_cookie_reply( + &SerializedPublicKey::from(&responder_public), + reply, + &msg_mac1, + ) + .expect("Failed to accept cookie reply"); + + assert_eq!(cookie, decrypted_cookie); + } + + #[test] + fn test_cookie_with_wrong_mac1_fails() { + let mut rng = rng(); + + let (_initiator_public, _initiator_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + let (responder_public, _responder_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + + let msg_sender_index = 12345u32; + let msg_mac1 = [0x42u8; 16]; + let wrong_mac1 = [0x99u8; 16]; + + let cookie_secret = [0x33u8; 32]; + let nonce = 555u64; + let initiator_addr: SocketAddr = "127.0.0.1:51820".parse().unwrap(); + let cookie = generate_cookie(&cookie_secret, nonce, &initiator_addr); + + let nonce_secret = [0x55u8; 32]; + let cookie_nonce = 1u128; + let reply = send_cookie_reply( + &nonce_secret, + cookie_nonce, + &SerializedPublicKey::from(&responder_public), + msg_sender_index, + &msg_mac1, + &cookie, + ) + .expect("Failed to create cookie reply"); + + let reply_bytes = reply.as_bytes(); + let mut reply_bytes_mut = reply_bytes.to_vec(); + let reply = <&mut CookieReply>::try_from(reply_bytes_mut.as_mut_slice()) + .expect("Failed to parse cookie reply"); + + let result = accept_cookie_reply( + &SerializedPublicKey::from(&responder_public), + reply, + &wrong_mac1, + ); + + assert!(result.is_err()); + } + + #[test] + fn test_cookie_with_wrong_public_key_fails() { + let mut rng = rng(); + + let (_initiator_public, _initiator_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + let (responder_public, _responder_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + let (wrong_public, _wrong_private) = crate::crypto::generate_keypair(&mut rng).unwrap(); + + let msg_sender_index = 12345u32; + let msg_mac1 = [0x42u8; 16]; + + let cookie_secret = [0x33u8; 32]; + let nonce = 555u64; + let initiator_addr: SocketAddr = "127.0.0.1:51820".parse().unwrap(); + let cookie = generate_cookie(&cookie_secret, nonce, &initiator_addr); + + let nonce_secret = [0x66u8; 32]; + let cookie_nonce = 2u128; + let reply = send_cookie_reply( + &nonce_secret, + cookie_nonce, + &SerializedPublicKey::from(&responder_public), + msg_sender_index, + &msg_mac1, + &cookie, + ) + .expect("Failed to create cookie reply"); + + let reply_bytes = reply.as_bytes(); + let mut reply_bytes_mut = reply_bytes.to_vec(); + let reply = <&mut CookieReply>::try_from(reply_bytes_mut.as_mut_slice()) + .expect("Failed to parse cookie reply"); + + let result = + accept_cookie_reply(&SerializedPublicKey::from(&wrong_public), reply, &msg_mac1); + + assert!(result.is_err()); + } + + #[test] + fn test_generate_cookie_ipv4() { + let cookie_secret = [0x11u8; 32]; + let nonce = 42u64; + let addr: SocketAddr = "192.168.1.1:8080".parse().unwrap(); + + let cookie1 = generate_cookie(&cookie_secret, nonce, &addr); + let cookie2 = generate_cookie(&cookie_secret, nonce, &addr); + assert_eq!(cookie1, cookie2); + + let addr2: SocketAddr = "192.168.1.2:8080".parse().unwrap(); + let cookie3 = generate_cookie(&cookie_secret, nonce, &addr2); + assert_ne!(cookie1, cookie3); + } + + #[test] + fn test_generate_cookie_ipv6() { + let cookie_secret = [0x22u8; 32]; + let nonce = 99u64; + let addr: SocketAddr = "[2001:db8::1]:8080".parse().unwrap(); + + let cookie1 = generate_cookie(&cookie_secret, nonce, &addr); + let cookie2 = generate_cookie(&cookie_secret, nonce, &addr); + assert_eq!(cookie1, cookie2); + + let addr2: SocketAddr = "[2001:db8::2]:8080".parse().unwrap(); + let cookie3 = generate_cookie(&cookie_secret, nonce, &addr2); + assert_ne!(cookie1, cookie3); + } + + #[test] + fn test_verify_cookie_with_zero_mac2() { + let mut rng = rng(); + let (responder_public, _) = crate::crypto::generate_keypair(&mut rng).unwrap(); + + let cookie_secret = [0x33u8; 32]; + let nonce = 555u64; + let addr: SocketAddr = "127.0.0.1:51820".parse().unwrap(); + + let msg = HandshakeInitiation { + mac2: [0u8; 16].into(), + ..Default::default() + }; + + let result = verify_cookie( + &cookie_secret, + nonce, + &addr, + &SerializedPublicKey::from(&responder_public), + &msg, + ); + assert!(result.is_err()); + } + + #[test] + fn test_verify_cookie_with_valid_mac2() { + let mut rng = rng(); + let (responder_public, _) = crate::crypto::generate_keypair(&mut rng).unwrap(); + + let cookie_secret = [0x33u8; 32]; + let nonce = 555u64; + let addr: SocketAddr = "127.0.0.1:51820".parse().unwrap(); + + let cookie = generate_cookie(&cookie_secret, nonce, &addr); + + let mut msg = HandshakeInitiation::default(); + + let responder_static_bytes: [u8; 33] = (&responder_public).into(); + let cookie_key = crate::hash!(crate::crypto::LABEL_COOKIE, &responder_static_bytes); + let mac2: crate::common::MacTag = + crate::keyed_hash!(cookie_key.as_ref(), msg.mac2_input(), &cookie).into(); + msg.mac2 = mac2; + + let result = verify_cookie( + &cookie_secret, + nonce, + &addr, + &SerializedPublicKey::from(&responder_public), + &msg, + ); + assert!(result.is_ok()); + } + + #[test] + fn test_verify_cookie_with_wrong_mac2() { + let mut rng = rng(); + let (responder_public, _) = crate::crypto::generate_keypair(&mut rng).unwrap(); + + let cookie_secret = [0x33u8; 32]; + let nonce = 555u64; + let addr: SocketAddr = "127.0.0.1:51820".parse().unwrap(); + + let msg = HandshakeInitiation { + mac2: [0xFFu8; 16].into(), + ..Default::default() + }; + + let result = verify_cookie( + &cookie_secret, + nonce, + &addr, + &SerializedPublicKey::from(&responder_public), + &msg, + ); + assert!(result.is_err()); + } + + #[test] + fn test_verify_cookie_response() { + let mut rng = rng(); + let (initiator_public, _) = crate::crypto::generate_keypair(&mut rng).unwrap(); + + let cookie_secret = [0x44u8; 32]; + let nonce = 666u64; + let addr: SocketAddr = "10.0.0.1:12345".parse().unwrap(); + + let cookie = generate_cookie(&cookie_secret, nonce, &addr); + + let mut msg = HandshakeResponse::default(); + + let initiator_static_bytes: [u8; 33] = (&initiator_public).into(); + let cookie_key = crate::hash!(crate::crypto::LABEL_COOKIE, &initiator_static_bytes); + let mac2 = crate::keyed_hash!(cookie_key.as_ref(), msg.mac2_input(), &cookie).into(); + msg.mac2 = mac2; + + let result = verify_cookie( + &cookie_secret, + nonce, + &addr, + &SerializedPublicKey::from(&initiator_public), + &msg, + ); + assert!(result.is_ok()); + } + + #[test] + fn test_verify_cookie_response_with_wrong_mac2() { + let mut rng = rng(); + let (initiator_public, _) = crate::crypto::generate_keypair(&mut rng).unwrap(); + + let cookie_secret = [0x55u8; 32]; + let nonce = 777u64; + let addr: SocketAddr = "10.0.0.2:54321".parse().unwrap(); + + let msg = HandshakeResponse { + mac2: [0xAAu8; 16].into(), + ..Default::default() + }; + + let result = verify_cookie( + &cookie_secret, + nonce, + &addr, + &SerializedPublicKey::from(&initiator_public), + &msg, + ); + assert!(result.is_err()); + } +} diff --git a/monad-wireauth-protocol/src/crypto.rs b/monad-wireauth-protocol/src/crypto.rs new file mode 100644 index 0000000000..d62dcb282d --- /dev/null +++ b/monad-wireauth-protocol/src/crypto.rs @@ -0,0 +1,123 @@ +use crate::{common::*, errors::CryptoError}; + +pub const CONSTRUCTION: &[u8] = b"Noise_IKpsk2_secp256k1_AEGIS128L_BLAKE3"; +pub const IDENTIFIER: &[u8] = b"authenticated udp v1 -- monad"; +pub const LABEL_MAC1: &[u8] = b"mac1----"; +pub const LABEL_COOKIE: &[u8] = b"cookie--"; + +pub fn encrypt_in_place( + key: &CipherKey, + nonce: &CipherNonce, + data: &mut [u8], + ad: &[u8], +) -> [u8; 16] { + let cipher = aegis::aegis128l::Aegis128L::<16>::new(key.as_ref(), nonce.as_ref()); + cipher.encrypt_in_place(data, ad) +} + +pub fn decrypt_in_place( + key: &CipherKey, + nonce: &CipherNonce, + data: &mut [u8], + tag: &[u8; 16], + ad: &[u8], +) -> Result<(), CryptoError> { + let cipher = aegis::aegis128l::Aegis128L::<16>::new(key.as_ref(), nonce.as_ref()); + cipher + .decrypt_in_place(data, tag, ad) + .map_err(|_| CryptoError::MacVerificationFailed) +} + +#[macro_export] +macro_rules! hash { + ($data:expr) => {{ + use $crate::common::HashOutput; + HashOutput(blake3::hash($data).into()) + }}; + ($data1:expr, $data2:expr) => {{ + use $crate::common::HashOutput; + let mut hasher = blake3::Hasher::new(); + hasher.update($data1); + hasher.update($data2); + HashOutput(hasher.finalize().into()) + }}; + ($data1:expr, $data2:expr, $data3:expr) => {{ + use $crate::common::HashOutput; + let mut hasher = blake3::Hasher::new(); + hasher.update($data1); + hasher.update($data2); + hasher.update($data3); + HashOutput(hasher.finalize().into()) + }}; +} + +#[macro_export] +macro_rules! keyed_hash { + ($key:expr, $data:expr) => {{ + use $crate::common::HashOutput; + HashOutput(blake3::keyed_hash($key, $data).into()) + }}; + ($key:expr, $data1:expr, $data2:expr) => {{ + use $crate::common::HashOutput; + let mut hasher = blake3::Hasher::new_keyed($key); + hasher.update($data1); + hasher.update($data2); + HashOutput(hasher.finalize().into()) + }}; + ($key:expr, $data1:expr, $data2:expr, $data3:expr) => {{ + use $crate::common::HashOutput; + let mut hasher = blake3::Hasher::new_keyed($key); + hasher.update($data1); + hasher.update($data2); + hasher.update($data3); + HashOutput(hasher.finalize().into()) + }}; +} + +pub fn verify_keyed_hash(key: &HashOutput, data: &[u8], tag: &[u8; 16]) -> Result<(), CryptoError> { + let computed: MacTag = keyed_hash!(key.as_ref(), data).into(); + if computed == *tag { + Ok(()) + } else { + Err(CryptoError::MacVerificationFailed) + } +} + +pub fn generate_keypair( + rng: &mut R, +) -> Result<(PublicKey, PrivateKey), CryptoError> { + let (secret_key, public_key) = secp256k1::SECP256K1.generate_keypair(rng); + Ok((public_key.into(), PrivateKey::from_inner(secret_key))) +} + +pub fn ecdh( + private_key: &PrivateKey, + public_key: &SerializedPublicKey, +) -> Result { + let public_key = PublicKey::try_from(public_key.as_bytes())?; + let shared_secret = secp256k1::ecdh::SharedSecret::new(public_key.inner(), private_key.inner()); + Ok(SharedSecret(shared_secret.secret_bytes())) +} + +pub fn verify_mac1( + message: &M, + static_public: &SerializedPublicKey, +) -> Result<(), CryptoError> { + let mac_key = hash!(LABEL_MAC1, static_public.as_bytes()); + verify_keyed_hash(&mac_key, message.mac1_input(), message.mac1().as_ref()) +} + +pub fn verify_mac2( + message: &M, + static_public: &SerializedPublicKey, + cookie: &[u8; 16], +) -> Result<(), CryptoError> { + let cookie_key = hash!(LABEL_COOKIE, static_public.as_bytes()); + let expected_mac2: MacTag = + keyed_hash!(cookie_key.as_ref(), message.mac2_input(), cookie).into(); + if message.mac2() == &expected_mac2 { + Ok(()) + } else { + Err(CryptoError::MacVerificationFailed) + } +} diff --git a/monad-wireauth-protocol/src/errors.rs b/monad-wireauth-protocol/src/errors.rs new file mode 100644 index 0000000000..af514018fa --- /dev/null +++ b/monad-wireauth-protocol/src/errors.rs @@ -0,0 +1,88 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum HandshakeError { + #[error("MAC1 verification failed: {0}")] + Mac1VerificationFailed(#[source] CryptoError), + + #[error("MAC2 verification failed: {0}")] + Mac2VerificationFailed(#[source] CryptoError), + + #[error("static key decryption failed: {0}")] + StaticKeyDecryptionFailed(#[source] CryptoError), + + #[error("timestamp decryption failed: {0}")] + TimestampDecryptionFailed(#[source] CryptoError), + + #[error("invalid timestamp format: unable to parse TAI64N from {size} bytes")] + InvalidTimestamp { size: usize }, + + #[error("empty message decryption failed: {0}")] + EmptyMessageDecryptionFailed(#[source] CryptoError), + + #[error("timestamp replay detected: received {received:?}, expected after {expected:?}")] + TimestampReplay { + received: tai64::Tai64N, + expected: tai64::Tai64N, + }, + + #[error("invalid message type: {0:#04x} is not a recognized handshake message")] + InvalidMessageType(u32), + + #[error("invalid receiver index: {index} does not match any active session")] + InvalidReceiverIndex { index: crate::SessionIndex }, +} + +#[derive(Error, Debug)] +pub enum MessageError { + #[error("buffer too small: need at least {required} bytes, got {actual}")] + BufferTooSmall { required: usize, actual: usize }, + + #[error("invalid message type: {0:#04x} is not a recognized protocol message")] + InvalidMessageType(u32), + + #[error("invalid message header: unable to parse or malformed structure")] + InvalidHeader, + + #[error("invalid data packet header: unable to parse or malformed structure")] + InvalidDataPacketHeader, +} + +#[derive(Error, Debug)] +pub enum CryptoError { + #[error("MAC verification failed: message authentication code does not match")] + MacVerificationFailed, + + #[error("invalid key: {0}")] + InvalidKey(#[from] secp256k1::Error), + + #[error("ECDH operation failed: unable to compute shared secret")] + EcdhFailed, +} + +#[derive(Error, Debug)] +pub enum CookieError { + #[error("invalid message type: {0:#04x} is not a cookie reply message")] + InvalidMessageType(u32), + + #[error("cookie decryption failed: {0}")] + CookieDecryptionFailed(#[source] CryptoError), + + #[error("invalid cookie MAC: {0}")] + InvalidCookieMac(#[source] CryptoError), +} + +#[derive(Error, Debug)] +pub enum ProtocolError { + #[error("handshake error: {0}")] + Handshake(#[from] HandshakeError), + + #[error("message error: {0}")] + Message(#[from] MessageError), + + #[error("crypto error: {0}")] + Crypto(#[from] CryptoError), + + #[error("cookie error: {0}")] + Cookie(#[from] CookieError), +} diff --git a/monad-wireauth-protocol/src/handshake.rs b/monad-wireauth-protocol/src/handshake.rs new file mode 100644 index 0000000000..af83d3291c --- /dev/null +++ b/monad-wireauth-protocol/src/handshake.rs @@ -0,0 +1,853 @@ +use std::time::SystemTime; + +use tai64::Tai64N; +use zeroize::Zeroizing; + +use crate::{ + common::*, + crypto::{ + decrypt_in_place, ecdh, encrypt_in_place, generate_keypair, CONSTRUCTION, IDENTIFIER, + LABEL_COOKIE, LABEL_MAC1, + }, + errors::{HandshakeError, ProtocolError}, + hash, keyed_hash, + messages::*, +}; + +pub struct HandshakeState { + pub chaining_key: Zeroizing, + pub hash: Zeroizing, + pub ephemeral_private: Option, + pub remote_ephemeral: Option, + pub remote_static: Option, + pub sender_index: u32, + pub receiver_index: u32, +} + +impl Default for HandshakeState { + fn default() -> Self { + let zero_hash = hash!(&[]); + Self { + chaining_key: zero_hash.clone().into(), + hash: zero_hash.into(), + ephemeral_private: None, + remote_ephemeral: None, + remote_static: None, + sender_index: 0, + receiver_index: 0, + } + } +} + +pub fn send_handshake_init( + rng: &mut R, + system_time: SystemTime, + local_session_index: u32, + initiator_static_private: &PrivateKey, + initiator_static_public: &SerializedPublicKey, + responder_static_public: &SerializedPublicKey, + stored_cookie: Option<&[u8; 16]>, +) -> Result<(HandshakeInitiation, HandshakeState), ProtocolError> { + let (ephemeral_public, ephemeral_private) = generate_keypair(rng)?; + let mut msg = HandshakeInitiation { + ephemeral_public: ephemeral_public.into(), + sender_index: local_session_index.into(), + ..Default::default() + }; + let mut inititiator = HandshakeState { + chaining_key: hash!(CONSTRUCTION).into(), + sender_index: local_session_index, + ephemeral_private: Some(ephemeral_private), + ..Default::default() + }; + + inititiator.hash = hash!( + hash!(inititiator.chaining_key.as_ref(), IDENTIFIER).as_ref(), + responder_static_public.as_bytes() + ) + .into(); + inititiator.hash = hash!(inititiator.hash.as_ref(), msg.ephemeral_public.as_bytes()).into(); + + let temp = keyed_hash!( + inititiator.chaining_key.as_ref(), + msg.ephemeral_public.as_bytes() + ); + inititiator.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + + let ecdh_es = ecdh( + inititiator + .ephemeral_private + .as_ref() + .expect("ephemeral private key must be set"), + responder_static_public, + )?; + let temp = keyed_hash!(inititiator.chaining_key.as_ref(), ecdh_es.as_ref()); + inititiator.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + let key = keyed_hash!(temp.as_ref(), inititiator.chaining_key.as_ref(), &[0x2]); + + msg.encrypted_static = *initiator_static_public; + msg.encrypted_static_tag = encrypt_in_place( + &(&key).into(), + &(0u64.into()), + msg.encrypted_static.as_mut_bytes(), + inititiator.hash.as_ref(), + ); + inititiator.hash = hash!( + inititiator.hash.as_ref(), + msg.encrypted_static.as_bytes(), + &msg.encrypted_static_tag + ) + .into(); + + let ecdh_ss = ecdh(initiator_static_private, responder_static_public)?; + let temp = keyed_hash!(inititiator.chaining_key.as_ref(), ecdh_ss.as_ref()); + inititiator.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + let key = keyed_hash!(temp.as_ref(), inititiator.chaining_key.as_ref(), &[0x2]); + + let timestamp: Tai64N = system_time.into(); + msg.encrypted_timestamp = timestamp.to_bytes(); + msg.encrypted_timestamp_tag = encrypt_in_place( + &(&key).into(), + &(0u64.into()), + &mut msg.encrypted_timestamp, + inititiator.hash.as_ref(), + ); + inititiator.hash = hash!( + inititiator.hash.as_ref(), + &msg.encrypted_timestamp, + &msg.encrypted_timestamp_tag + ) + .into(); + + let mac_key = hash!(LABEL_MAC1, responder_static_public.as_bytes()); + msg.mac1 = keyed_hash!(mac_key.as_ref(), msg.mac1_input()).into(); + if let Some(cookie) = stored_cookie { + let cookie_key = hash!(LABEL_COOKIE, responder_static_public.as_bytes()); + msg.mac2 = keyed_hash!(cookie_key.as_ref(), msg.mac2_input(), cookie).into(); + } + + Ok((msg, inititiator)) +} + +pub fn accept_handshake_init( + responder_static_private: &PrivateKey, + responder_static_public: &SerializedPublicKey, + msg: &mut HandshakeInitiation, +) -> Result<(HandshakeState, SystemTime), ProtocolError> { + crate::crypto::verify_mac1(msg, responder_static_public) + .map_err(HandshakeError::Mac1VerificationFailed)?; + + let mut state = HandshakeState { + chaining_key: hash!(CONSTRUCTION).into(), + ..Default::default() + }; + + let hash_ck_id = hash!(state.chaining_key.as_ref(), IDENTIFIER); + state.hash = hash!(hash_ck_id.as_ref(), responder_static_public.as_bytes()).into(); + + state.hash = hash!(state.hash.as_ref(), msg.ephemeral_public.as_bytes()).into(); + state.remote_ephemeral = Some(msg.ephemeral_public); + state.receiver_index = msg.sender_index.get(); + + let temp = keyed_hash!(state.chaining_key.as_ref(), msg.ephemeral_public.as_bytes()); + state.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + + let ecdh_ee = ecdh(responder_static_private, &msg.ephemeral_public)?; + let temp = keyed_hash!(state.chaining_key.as_ref(), ecdh_ee.as_ref()); + state.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + let key = keyed_hash!(temp.as_ref(), state.chaining_key.as_ref(), &[0x2]); + + let hash = state.hash.clone(); + state.hash = hash!( + state.hash.as_ref(), + msg.encrypted_static.as_bytes(), + &msg.encrypted_static_tag + ) + .into(); + decrypt_in_place( + &(&key).into(), + &(0u64).into(), + msg.encrypted_static.as_mut_bytes(), + &msg.encrypted_static_tag, + hash.as_ref(), + ) + .map_err(HandshakeError::StaticKeyDecryptionFailed)?; + + state.remote_static = Some(msg.encrypted_static); + + let ecdh_ss = ecdh(responder_static_private, &msg.encrypted_static)?; + let temp = keyed_hash!(state.chaining_key.as_ref(), ecdh_ss.as_ref()); + state.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + let key = keyed_hash!(temp.as_ref(), state.chaining_key.as_ref(), &[0x2]); + + let hash = state.hash.clone(); + state.hash = hash!( + state.hash.as_ref(), + &msg.encrypted_timestamp, + &msg.encrypted_timestamp_tag + ) + .into(); + decrypt_in_place( + &(&key).into(), + &(0u64).into(), + &mut msg.encrypted_timestamp, + &msg.encrypted_timestamp_tag, + hash.as_ref(), + ) + .map_err(HandshakeError::TimestampDecryptionFailed)?; + + let timestamp = Tai64N::from_slice(&msg.encrypted_timestamp).map_err(|_| { + HandshakeError::InvalidTimestamp { + size: msg.encrypted_timestamp.len(), + } + })?; + let system_time = timestamp.to_system_time(); + + Ok((state, system_time)) +} + +pub fn send_handshake_response( + rng: &mut R, + local_session_index: u32, + state: &mut HandshakeState, + psk: &[u8; 32], + stored_cookie: Option<&[u8; 16]>, +) -> Result<(HandshakeResponse, TransportKeys), ProtocolError> { + let (ephemeral_public, ephemeral_private) = generate_keypair(rng)?; + state.sender_index = local_session_index; + let initiator_static_public = state + .remote_static + .as_ref() + .expect("remote static key must be set"); + let mut msg = HandshakeResponse { + sender_index: local_session_index.into(), + receiver_index: state.receiver_index.into(), + ephemeral_public: (&ephemeral_public).into(), + ..Default::default() + }; + + state.hash = hash!(state.hash.as_ref(), msg.ephemeral_public.as_bytes()).into(); + + let temp = keyed_hash!(state.chaining_key.as_ref(), msg.ephemeral_public.as_bytes()); + state.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + + let remote_ephemeral = state + .remote_ephemeral + .as_ref() + .expect("remote ephemeral key must be set"); + + let ecdh_ee = ecdh(&ephemeral_private, remote_ephemeral)?; + let temp = keyed_hash!(state.chaining_key.as_ref(), ecdh_ee.as_ref()); + state.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + + let ecdh_se = ecdh(&ephemeral_private, initiator_static_public)?; + let temp = keyed_hash!(state.chaining_key.as_ref(), ecdh_se.as_ref()); + state.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + + let temp = keyed_hash!(state.chaining_key.as_ref(), psk); + state.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + let temp2 = keyed_hash!(temp.as_ref(), state.chaining_key.as_ref(), &[0x2]); + let key = keyed_hash!(temp.as_ref(), temp2.as_ref(), &[0x3]); + + state.hash = hash!(state.hash.as_ref(), temp2.as_ref()).into(); + + msg.encrypted_nothing_tag = + encrypt_in_place(&(&key).into(), &(0u64).into(), &mut [], state.hash.as_ref()); + state.hash = hash!(state.hash.as_ref(), &msg.encrypted_nothing_tag).into(); + + let mac_key = hash!(LABEL_MAC1, initiator_static_public.as_bytes()); + msg.mac1 = keyed_hash!(mac_key.as_ref(), msg.mac1_input()).into(); + + if let Some(cookie) = stored_cookie { + let cookie_key = hash!(LABEL_COOKIE, initiator_static_public.as_bytes()); + msg.mac2 = keyed_hash!(cookie_key.as_ref(), msg.mac2_input(), cookie).into(); + } + + Ok((msg, derive_transport_keys(&state.chaining_key, false))) +} + +pub fn accept_handshake_response( + initiator_static_private: &PrivateKey, + initiator_static_public: &SerializedPublicKey, + msg: &mut HandshakeResponse, + state: &mut HandshakeState, + psk: &[u8; 32], +) -> Result { + crate::crypto::verify_mac1(msg, initiator_static_public) + .map_err(HandshakeError::Mac1VerificationFailed)?; + + state.receiver_index = msg.sender_index.get(); + state.remote_ephemeral = Some(msg.ephemeral_public); + state.hash = hash!(state.hash.as_ref(), msg.ephemeral_public.as_bytes()).into(); + + let temp = keyed_hash!(state.chaining_key.as_ref(), msg.ephemeral_public.as_bytes()); + state.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + + let ecdh_ee = ecdh( + state + .ephemeral_private + .as_ref() + .expect("ephemeral private key must be set"), + &msg.ephemeral_public, + )?; + let temp = keyed_hash!(state.chaining_key.as_ref(), ecdh_ee.as_ref()); + state.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + + let ecdh_se = ecdh(initiator_static_private, &msg.ephemeral_public)?; + let temp = keyed_hash!(state.chaining_key.as_ref(), ecdh_se.as_ref()); + state.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + + let temp = keyed_hash!(state.chaining_key.as_ref(), psk); + state.chaining_key = keyed_hash!(temp.as_ref(), &[0x1]).into(); + let temp2 = keyed_hash!(temp.as_ref(), state.chaining_key.as_ref(), &[0x2]); + let key = keyed_hash!(temp.as_ref(), temp2.as_ref(), &[0x3]); + state.hash = hash!(state.hash.as_ref(), temp2.as_ref()).into(); + + let hash = state.hash.clone(); + state.hash = hash!(state.hash.as_ref(), &msg.encrypted_nothing_tag).into(); + decrypt_in_place( + &(&key).into(), + &(0u64).into(), + &mut [], + &msg.encrypted_nothing_tag, + hash.as_ref(), + ) + .map_err(HandshakeError::EmptyMessageDecryptionFailed)?; + + Ok(derive_transport_keys(&state.chaining_key, true)) +} + +pub fn derive_transport_keys(chaining_key: &HashOutput, is_initiator: bool) -> TransportKeys { + let temp1 = keyed_hash!(chaining_key.as_ref(), &[]); + let temp2 = keyed_hash!(temp1.as_ref(), &[0x1]); + let temp3 = keyed_hash!(temp1.as_ref(), temp2.as_ref(), &[0x2]); + + if is_initiator { + TransportKeys { + send_key: CipherKey::from(&temp2), + recv_key: CipherKey::from(&temp3), + } + } else { + TransportKeys { + send_key: CipherKey::from(&temp3), + recv_key: CipherKey::from(&temp2), + } + } +} + +#[cfg(test)] +mod tests { + use std::time::{Duration, SystemTime}; + + use secp256k1::rand::{rngs::StdRng, Rng, SeedableRng}; + use serde::Serialize; + use zerocopy::IntoBytes; + + use super::*; + + #[derive(Serialize)] + struct ProtocolTrace { + test_name: String, + seed: u64, + timestamp: u64, + initiator_static_private: String, + initiator_static_public: String, + responder_static_private: String, + responder_static_public: String, + initiator_session_index: u32, + responder_session_index: u32, + init_message: MessageTrace, + response_message: MessageTrace, + transport_keys: TransportKeysTrace, + } + + #[derive(Serialize)] + struct MessageTrace { + raw_bytes: String, + sender_index: u32, + receiver_index: u32, + ephemeral_public: String, + encrypted_static: String, + encrypted_static_tag: String, + encrypted_timestamp: String, + encrypted_timestamp_tag: String, + encrypted_nothing_tag: Option, + mac1: String, + mac2: String, + } + + #[derive(Serialize)] + struct TransportKeysTrace { + initiator_send_key: String, + initiator_recv_key: String, + responder_send_key: String, + responder_recv_key: String, + } + + #[derive(Serialize)] + struct CookieProtocolTrace { + test_name: String, + seed: u64, + cookie_secret: String, + cookie_value: String, + init_without_cookie: MessageTrace, + cookie_reply: CookieReplyTrace, + init_with_cookie: MessageTrace, + response_with_cookie: MessageTrace, + } + + #[derive(Serialize)] + struct CookieReplyTrace { + raw_bytes: String, + receiver_index: u32, + nonce: String, + encrypted_cookie: String, + } + + #[derive(Serialize)] + struct EncryptedDataTrace { + test_name: String, + plaintext: String, + nonce: String, + sender_index: u32, + receiver_index: u32, + encrypted_payload: String, + auth_tag: String, + complete_packet: String, + } + + fn to_hex(data: &[u8]) -> String { + hex::encode(data) + } + + fn extract_init_message_trace(msg: &HandshakeInitiation) -> MessageTrace { + MessageTrace { + raw_bytes: to_hex(msg.as_bytes()), + sender_index: msg.sender_index.get(), + receiver_index: 0, + ephemeral_public: to_hex(msg.ephemeral_public.as_bytes()), + encrypted_static: to_hex(msg.encrypted_static.as_bytes()), + encrypted_static_tag: to_hex(&msg.encrypted_static_tag), + encrypted_timestamp: to_hex(&msg.encrypted_timestamp), + encrypted_timestamp_tag: to_hex(&msg.encrypted_timestamp_tag), + encrypted_nothing_tag: None, + mac1: to_hex(msg.mac1.as_ref()), + mac2: to_hex(msg.mac2.as_ref()), + } + } + + fn extract_response_trace(msg: &HandshakeResponse) -> MessageTrace { + MessageTrace { + raw_bytes: to_hex(msg.as_bytes()), + sender_index: msg.sender_index.get(), + receiver_index: msg.receiver_index.get(), + ephemeral_public: to_hex(msg.ephemeral_public.as_bytes()), + encrypted_static: String::new(), + encrypted_static_tag: String::new(), + encrypted_timestamp: String::new(), + encrypted_timestamp_tag: String::new(), + encrypted_nothing_tag: Some(to_hex(&msg.encrypted_nothing_tag)), + mac1: to_hex(msg.mac1.as_ref()), + mac2: to_hex(msg.mac2.as_ref()), + } + } + + #[test] + fn test_complete_handshake_trace() { + let seed = 42u64; + let mut rng = StdRng::seed_from_u64(seed); + let timestamp = 1700000000u64; + let system_time = SystemTime::UNIX_EPOCH + Duration::from_secs(timestamp); + + let (initiator_static_public, initiator_static_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + let (responder_static_public, responder_static_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + + let initiator_index = 100u32; + let responder_index = 200u32; + + let (init_msg, init_state) = send_handshake_init( + &mut rng, + system_time, + initiator_index, + &initiator_static_private, + &SerializedPublicKey::from(&initiator_static_public), + &SerializedPublicKey::from(&responder_static_public), + None, + ) + .unwrap(); + + let init_trace = extract_init_message_trace(&init_msg); + + let mut init_msg_mut = init_msg; + let (responder_state, _timestamp) = accept_handshake_init( + &responder_static_private, + &SerializedPublicKey::from(&responder_static_public), + &mut init_msg_mut, + ) + .unwrap(); + + let mut responder_state_mut = responder_state; + let psk = [0u8; 32]; + let (resp_msg, responder_transport_keys) = send_handshake_response( + &mut rng, + responder_index, + &mut responder_state_mut, + &psk, + None, + ) + .unwrap(); + + let response_trace = extract_response_trace(&resp_msg); + + let mut resp_msg_mut = resp_msg; + let mut init_state_mut = init_state; + let initiator_transport_keys = accept_handshake_response( + &initiator_static_private, + &SerializedPublicKey::from(&initiator_static_public), + &mut resp_msg_mut, + &mut init_state_mut, + &psk, + ) + .unwrap(); + + let trace = ProtocolTrace { + test_name: "standard_handshake".to_string(), + seed, + timestamp, + initiator_static_private: to_hex(&initiator_static_private.inner().secret_bytes()), + initiator_static_public: to_hex(&<[u8; 33]>::from(&initiator_static_public)), + responder_static_private: to_hex(&responder_static_private.inner().secret_bytes()), + responder_static_public: to_hex(&<[u8; 33]>::from(&responder_static_public)), + initiator_session_index: initiator_index, + responder_session_index: responder_index, + init_message: init_trace, + response_message: response_trace, + transport_keys: TransportKeysTrace { + initiator_send_key: to_hex(initiator_transport_keys.send_key.as_ref()), + initiator_recv_key: to_hex(initiator_transport_keys.recv_key.as_ref()), + responder_send_key: to_hex(responder_transport_keys.send_key.as_ref()), + responder_recv_key: to_hex(responder_transport_keys.recv_key.as_ref()), + }, + }; + + insta::assert_yaml_snapshot!("complete_handshake_trace", trace); + } + + #[test] + fn test_cookie_handshake_trace() { + let seed = 43u64; + let mut rng = StdRng::seed_from_u64(seed); + let timestamp = 1700000000u64; + let system_time = SystemTime::UNIX_EPOCH + Duration::from_secs(timestamp); + + let (initiator_static_public, initiator_static_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + let (responder_static_public, responder_static_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + + let initiator_index = 101u32; + let responder_index = 201u32; + + let (init_msg, _init_state) = send_handshake_init( + &mut rng, + system_time, + initiator_index, + &initiator_static_private, + &SerializedPublicKey::from(&initiator_static_public), + &SerializedPublicKey::from(&responder_static_public), + None, + ) + .unwrap(); + + let init_no_cookie_trace = extract_init_message_trace(&init_msg); + let init_mac1 = init_msg.mac1; + let init_sender_index = init_msg.sender_index.get(); + + let cookie_secret = [0x77u8; 32]; + let nonce = 42u64; + let initiator_addr: std::net::SocketAddr = "192.168.1.1:51820".parse().unwrap(); + let cookie = crate::cookies::generate_cookie(&cookie_secret, nonce, &initiator_addr); + + let nonce_secret = [0x77u8; 32]; + let cookie_nonce = 1234u128; + let cookie_reply = crate::cookies::send_cookie_reply( + &nonce_secret, + cookie_nonce, + &SerializedPublicKey::from(&responder_static_public), + init_sender_index, + init_mac1.as_ref(), + &cookie, + ) + .unwrap(); + + let cookie_reply_trace = CookieReplyTrace { + raw_bytes: to_hex(cookie_reply.as_bytes()), + receiver_index: cookie_reply.receiver_index.get(), + nonce: to_hex(cookie_reply.nonce.as_ref()), + encrypted_cookie: to_hex(&cookie_reply.encrypted_cookie), + }; + + let mut cookie_reply_mut = cookie_reply; + let extracted_cookie = crate::cookies::accept_cookie_reply( + &SerializedPublicKey::from(&responder_static_public), + &mut cookie_reply_mut, + init_mac1.as_ref(), + ) + .unwrap(); + + let (init_msg_with_cookie, _init_state_with_cookie) = send_handshake_init( + &mut rng, + system_time, + initiator_index + 1, + &initiator_static_private, + &SerializedPublicKey::from(&initiator_static_public), + &SerializedPublicKey::from(&responder_static_public), + Some(&extracted_cookie), + ) + .unwrap(); + + let init_with_cookie_trace = extract_init_message_trace(&init_msg_with_cookie); + + let mut init_msg_mut = init_msg_with_cookie; + let (responder_state, _timestamp) = accept_handshake_init( + &responder_static_private, + &SerializedPublicKey::from(&responder_static_public), + &mut init_msg_mut, + ) + .unwrap(); + + let mut responder_state_mut = responder_state; + let psk = [0u8; 32]; + let (resp_msg, _responder_keys) = send_handshake_response( + &mut rng, + responder_index, + &mut responder_state_mut, + &psk, + Some(&cookie), + ) + .unwrap(); + + let response_with_cookie_trace = extract_response_trace(&resp_msg); + + let trace = CookieProtocolTrace { + test_name: "cookie_handshake".to_string(), + seed, + cookie_secret: to_hex(&cookie_secret), + cookie_value: to_hex(&cookie), + init_without_cookie: init_no_cookie_trace, + cookie_reply: cookie_reply_trace, + init_with_cookie: init_with_cookie_trace, + response_with_cookie: response_with_cookie_trace, + }; + + insta::assert_yaml_snapshot!("cookie_handshake_trace", trace); + } + + #[test] + fn test_data_encryption_trace() { + let seed = 44u64; + let mut rng = StdRng::seed_from_u64(seed); + let system_time = SystemTime::UNIX_EPOCH + Duration::from_secs(1700000000); + + let (initiator_static_public, initiator_static_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + let (responder_static_public, responder_static_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + + let initiator_index = 102u32; + let responder_index = 202u32; + + let psk = [0u8; 32]; + let (init_msg, init_state) = send_handshake_init( + &mut rng, + system_time, + initiator_index, + &initiator_static_private, + &SerializedPublicKey::from(&initiator_static_public), + &SerializedPublicKey::from(&responder_static_public), + None, + ) + .unwrap(); + + let mut init_msg_mut = init_msg; + let (responder_state, _timestamp) = accept_handshake_init( + &responder_static_private, + &SerializedPublicKey::from(&responder_static_public), + &mut init_msg_mut, + ) + .unwrap(); + + let mut responder_state_mut = responder_state; + let (resp_msg, responder_transport_keys) = send_handshake_response( + &mut rng, + responder_index, + &mut responder_state_mut, + &psk, + None, + ) + .unwrap(); + + let mut resp_msg_mut = resp_msg; + let mut init_state_mut = init_state; + let initiator_transport_keys = accept_handshake_response( + &initiator_static_private, + &SerializedPublicKey::from(&initiator_static_public), + &mut resp_msg_mut, + &mut init_state_mut, + &psk, + ) + .unwrap(); + + // Encrypt data from initiator to responder + let test_message = b"Hello, encrypted world!"; + let mut encrypted_message = test_message.to_vec(); + let nonce = crate::common::CipherNonce::from(1u64); + let tag = crate::crypto::encrypt_in_place( + &initiator_transport_keys.send_key, + &nonce, + &mut encrypted_message, + &[], + ); + + let mut data_packet = Vec::new(); + data_packet.extend_from_slice(&[0x04]); // Data message type + data_packet.extend_from_slice(&responder_index.to_le_bytes()); + data_packet.extend_from_slice(nonce.as_ref()); + data_packet.extend_from_slice(&encrypted_message); + data_packet.extend_from_slice(&tag); + + let initiator_to_responder = EncryptedDataTrace { + test_name: "initiator_to_responder".to_string(), + plaintext: String::from_utf8_lossy(test_message).to_string(), + nonce: to_hex(nonce.as_ref()), + sender_index: initiator_index, + receiver_index: responder_index, + encrypted_payload: to_hex(&encrypted_message), + auth_tag: to_hex(&tag), + complete_packet: to_hex(&data_packet), + }; + + // Encrypt data from responder to initiator + let test_message_2 = b"Response from responder!"; + let mut encrypted_message_2 = test_message_2.to_vec(); + let nonce_2 = crate::common::CipherNonce::from(1u64); + let tag_2 = crate::crypto::encrypt_in_place( + &responder_transport_keys.send_key, + &nonce_2, + &mut encrypted_message_2, + &[], + ); + + let mut data_packet_2 = Vec::new(); + data_packet_2.extend_from_slice(&[0x04]); // Data message type + data_packet_2.extend_from_slice(&initiator_index.to_le_bytes()); + data_packet_2.extend_from_slice(nonce_2.as_ref()); + data_packet_2.extend_from_slice(&encrypted_message_2); + data_packet_2.extend_from_slice(&tag_2); + + let responder_to_initiator = EncryptedDataTrace { + test_name: "responder_to_initiator".to_string(), + plaintext: String::from_utf8_lossy(test_message_2).to_string(), + nonce: to_hex(nonce_2.as_ref()), + sender_index: responder_index, + receiver_index: initiator_index, + encrypted_payload: to_hex(&encrypted_message_2), + auth_tag: to_hex(&tag_2), + complete_packet: to_hex(&data_packet_2), + }; + + insta::assert_yaml_snapshot!( + "data_encryption_traces", + vec![initiator_to_responder, responder_to_initiator,] + ); + } + + #[test] + fn test_multiple_protocol_vectors() { + let seeds = [100u64, 200u64, 300u64]; + let mut all_traces = Vec::new(); + + for seed in &seeds { + let mut rng = StdRng::seed_from_u64(*seed); + let timestamp = 1700000000u64 + seed; + let system_time = SystemTime::UNIX_EPOCH + Duration::from_secs(timestamp); + + let (initiator_static_public, initiator_static_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + let (responder_static_public, responder_static_private) = + crate::crypto::generate_keypair(&mut rng).unwrap(); + + let initiator_index = rng.random::(); + let responder_index = rng.random::(); + + let (init_msg, init_state) = send_handshake_init( + &mut rng, + system_time, + initiator_index, + &initiator_static_private, + &SerializedPublicKey::from(&initiator_static_public), + &SerializedPublicKey::from(&responder_static_public), + None, + ) + .unwrap(); + + let init_trace = extract_init_message_trace(&init_msg); + + let mut init_msg_mut = init_msg; + let (responder_state, _timestamp) = accept_handshake_init( + &responder_static_private, + &SerializedPublicKey::from(&responder_static_public), + &mut init_msg_mut, + ) + .unwrap(); + + let mut responder_state_mut = responder_state; + let psk = [0u8; 32]; + let (resp_msg, responder_keys) = send_handshake_response( + &mut rng, + responder_index, + &mut responder_state_mut, + &psk, + None, + ) + .unwrap(); + + let response_trace = extract_response_trace(&resp_msg); + + let mut resp_msg_mut = resp_msg; + let mut init_state_mut = init_state; + let initiator_keys = accept_handshake_response( + &initiator_static_private, + &SerializedPublicKey::from(&initiator_static_public), + &mut resp_msg_mut, + &mut init_state_mut, + &psk, + ) + .unwrap(); + + all_traces.push(ProtocolTrace { + test_name: format!("vector_{}", seed), + seed: *seed, + timestamp, + initiator_static_private: to_hex(&initiator_static_private.inner().secret_bytes()), + initiator_static_public: to_hex(&<[u8; 33]>::from(&initiator_static_public)), + responder_static_private: to_hex(&responder_static_private.inner().secret_bytes()), + responder_static_public: to_hex(&<[u8; 33]>::from(&responder_static_public)), + initiator_session_index: initiator_index, + responder_session_index: responder_index, + init_message: init_trace, + response_message: response_trace, + transport_keys: TransportKeysTrace { + initiator_send_key: to_hex(initiator_keys.send_key.as_ref()), + initiator_recv_key: to_hex(initiator_keys.recv_key.as_ref()), + responder_send_key: to_hex(responder_keys.send_key.as_ref()), + responder_recv_key: to_hex(responder_keys.recv_key.as_ref()), + }, + }); + } + + insta::assert_yaml_snapshot!("multiple_protocol_vectors", all_traces); + } +} diff --git a/monad-wireauth-protocol/src/lib.rs b/monad-wireauth-protocol/src/lib.rs new file mode 100644 index 0000000000..22b6abc5c4 --- /dev/null +++ b/monad-wireauth-protocol/src/lib.rs @@ -0,0 +1,8 @@ +pub mod common; +pub mod cookies; +pub mod crypto; +pub mod errors; +pub mod handshake; +pub mod messages; + +pub use common::SessionIndex; diff --git a/monad-wireauth-protocol/src/messages.rs b/monad-wireauth-protocol/src/messages.rs new file mode 100644 index 0000000000..0251686ff3 --- /dev/null +++ b/monad-wireauth-protocol/src/messages.rs @@ -0,0 +1,572 @@ +use std::convert::TryFrom; + +use bytes::Bytes; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, LE, U32, U64}; + +use crate::{common::*, errors::MessageError}; + +/// Trait for messages that have MAC1 and MAC2 fields +pub trait MacMessage: IntoBytes { + fn mac1(&self) -> &MacTag; + fn mac2(&self) -> &MacTag; + fn mac1_input(&self) -> &[u8]; + fn mac2_input(&self) -> &[u8]; +} + +pub const TYPE_HANDSHAKE_INITIATION: u8 = 1; +pub const TYPE_HANDSHAKE_RESPONSE: u8 = 2; +pub const TYPE_COOKIE_REPLY: u8 = 3; +pub const TYPE_DATA: u8 = 4; + +pub const TIMESTAMP_SIZE: usize = 12; + +#[repr(C, packed)] +#[derive(FromBytes, IntoBytes, Immutable, KnownLayout, Clone)] +pub struct HandshakeInitiation { + pub message_type: u8, + pub reserved: [u8; 3], + pub sender_index: U32, + pub ephemeral_public: SerializedPublicKey, + pub encrypted_static: SerializedPublicKey, + pub encrypted_static_tag: [u8; CIPHER_TAG_SIZE], + pub encrypted_timestamp: [u8; TIMESTAMP_SIZE], + pub encrypted_timestamp_tag: [u8; CIPHER_TAG_SIZE], + pub mac1: MacTag, + pub mac2: MacTag, +} + +impl Default for HandshakeInitiation { + fn default() -> Self { + unsafe { + let mut msg: Self = core::mem::zeroed(); + msg.message_type = TYPE_HANDSHAKE_INITIATION; + msg + } + } +} + +impl MacMessage for HandshakeInitiation { + fn mac1(&self) -> &MacTag { + &self.mac1 + } + + fn mac2(&self) -> &MacTag { + &self.mac2 + } + + fn mac1_input(&self) -> &[u8] { + self.as_bytes()[..Self::MAC1_OFFSET].as_ref() + } + + fn mac2_input(&self) -> &[u8] { + self.as_bytes()[..Self::MAC2_OFFSET].as_ref() + } +} + +impl HandshakeInitiation { + pub const SIZE: usize = 4 + + 4 + + PUBLIC_KEY_SIZE + + PUBLIC_KEY_SIZE + + CIPHER_TAG_SIZE + + TIMESTAMP_SIZE + + CIPHER_TAG_SIZE + + MAC_TAG_SIZE + + MAC_TAG_SIZE; + + pub const MAC1_OFFSET: usize = 4 + + 4 + + PUBLIC_KEY_SIZE + + PUBLIC_KEY_SIZE + + CIPHER_TAG_SIZE + + TIMESTAMP_SIZE + + CIPHER_TAG_SIZE; + + pub const MAC2_OFFSET: usize = Self::MAC1_OFFSET + MAC_TAG_SIZE; +} + +impl<'a> TryFrom<&'a [u8]> for &'a HandshakeInitiation { + type Error = MessageError; + + fn try_from(bytes: &'a [u8]) -> Result { + if bytes.len() != HandshakeInitiation::SIZE { + return Err(MessageError::BufferTooSmall { + required: HandshakeInitiation::SIZE, + actual: bytes.len(), + }); + } + HandshakeInitiation::ref_from_bytes(bytes).map_err(|_| MessageError::InvalidHeader) + } +} + +impl<'a> TryFrom<&'a mut [u8]> for &'a mut HandshakeInitiation { + type Error = MessageError; + + fn try_from(bytes: &'a mut [u8]) -> Result { + if bytes.len() != HandshakeInitiation::SIZE { + return Err(MessageError::BufferTooSmall { + required: HandshakeInitiation::SIZE, + actual: bytes.len(), + }); + } + HandshakeInitiation::mut_from_bytes(bytes).map_err(|_| MessageError::InvalidHeader) + } +} + +impl From for Bytes { + fn from(msg: HandshakeInitiation) -> Self { + Bytes::copy_from_slice(msg.as_bytes()) + } +} + +#[repr(C, packed)] +#[derive(FromBytes, IntoBytes, Immutable, KnownLayout, Clone)] +pub struct HandshakeResponse { + pub message_type: u8, + pub reserved: [u8; 3], + pub sender_index: U32, + pub receiver_index: U32, + pub ephemeral_public: SerializedPublicKey, + pub encrypted_nothing_tag: [u8; CIPHER_TAG_SIZE], + pub mac1: MacTag, + pub mac2: MacTag, +} + +impl Default for HandshakeResponse { + fn default() -> Self { + unsafe { + let mut msg: Self = core::mem::zeroed(); + msg.message_type = TYPE_HANDSHAKE_RESPONSE; + msg + } + } +} + +impl MacMessage for HandshakeResponse { + fn mac1(&self) -> &MacTag { + &self.mac1 + } + + fn mac2(&self) -> &MacTag { + &self.mac2 + } + + fn mac1_input(&self) -> &[u8] { + self.as_bytes()[..Self::MAC1_OFFSET].as_ref() + } + + fn mac2_input(&self) -> &[u8] { + self.as_bytes()[..Self::MAC2_OFFSET].as_ref() + } +} + +impl HandshakeResponse { + pub const SIZE: usize = + 4 + 4 + 4 + PUBLIC_KEY_SIZE + CIPHER_TAG_SIZE + MAC_TAG_SIZE + MAC_TAG_SIZE; + + pub const MAC1_OFFSET: usize = 4 + 4 + 4 + PUBLIC_KEY_SIZE + CIPHER_TAG_SIZE; + + pub const MAC2_OFFSET: usize = Self::MAC1_OFFSET + MAC_TAG_SIZE; +} + +impl<'a> TryFrom<&'a [u8]> for &'a HandshakeResponse { + type Error = MessageError; + + fn try_from(bytes: &'a [u8]) -> Result { + if bytes.len() != HandshakeResponse::SIZE { + return Err(MessageError::BufferTooSmall { + required: HandshakeResponse::SIZE, + actual: bytes.len(), + }); + } + HandshakeResponse::ref_from_bytes(bytes).map_err(|_| MessageError::InvalidHeader) + } +} + +impl<'a> TryFrom<&'a mut [u8]> for &'a mut HandshakeResponse { + type Error = MessageError; + + fn try_from(bytes: &'a mut [u8]) -> Result { + if bytes.len() != HandshakeResponse::SIZE { + return Err(MessageError::BufferTooSmall { + required: HandshakeResponse::SIZE, + actual: bytes.len(), + }); + } + HandshakeResponse::mut_from_bytes(bytes).map_err(|_| MessageError::InvalidHeader) + } +} + +impl From for Bytes { + fn from(msg: HandshakeResponse) -> Self { + Bytes::copy_from_slice(msg.as_bytes()) + } +} + +#[repr(C, packed)] +#[derive(FromBytes, IntoBytes, Immutable, KnownLayout, Clone)] +pub struct CookieReply { + pub message_type: u8, + pub reserved: [u8; 3], + pub receiver_index: U32, + pub nonce: CipherNonce, + pub encrypted_cookie: [u8; 16], + pub encrypted_cookie_tag: [u8; CIPHER_TAG_SIZE], +} + +impl Default for CookieReply { + fn default() -> Self { + unsafe { + let mut msg: Self = core::mem::zeroed(); + msg.message_type = TYPE_COOKIE_REPLY; + msg + } + } +} + +impl CookieReply { + pub const SIZE: usize = 4 + 4 + 16 + 16 + CIPHER_TAG_SIZE; +} + +impl<'a> TryFrom<&'a [u8]> for &'a CookieReply { + type Error = MessageError; + + fn try_from(bytes: &'a [u8]) -> Result { + if bytes.len() != CookieReply::SIZE { + return Err(MessageError::BufferTooSmall { + required: CookieReply::SIZE, + actual: bytes.len(), + }); + } + CookieReply::ref_from_bytes(bytes).map_err(|_| MessageError::InvalidHeader) + } +} + +impl<'a> TryFrom<&'a mut [u8]> for &'a mut CookieReply { + type Error = MessageError; + + fn try_from(bytes: &'a mut [u8]) -> Result { + if bytes.len() != CookieReply::SIZE { + return Err(MessageError::BufferTooSmall { + required: CookieReply::SIZE, + actual: bytes.len(), + }); + } + CookieReply::mut_from_bytes(bytes).map_err(|_| MessageError::InvalidHeader) + } +} + +impl From for Bytes { + fn from(msg: CookieReply) -> Self { + Bytes::copy_from_slice(msg.as_bytes()) + } +} + +#[repr(C, packed)] +#[derive(FromBytes, IntoBytes, Immutable, KnownLayout, Clone)] +pub struct DataPacketHeader { + pub message_type: u8, + pub reserved: [u8; 3], + pub receiver_index: U32, + pub counter: U64, + pub tag: [u8; CIPHER_TAG_SIZE], +} + +impl Default for DataPacketHeader { + fn default() -> Self { + unsafe { + let mut msg: Self = core::mem::zeroed(); + msg.message_type = TYPE_DATA; + msg + } + } +} + +impl DataPacketHeader { + pub const SIZE: usize = 4 + 4 + 8 + CIPHER_TAG_SIZE; +} + +impl<'a> TryFrom<&'a [u8]> for &'a DataPacketHeader { + type Error = MessageError; + + fn try_from(bytes: &'a [u8]) -> Result { + if bytes.len() < DataPacketHeader::SIZE { + return Err(MessageError::BufferTooSmall { + required: DataPacketHeader::SIZE, + actual: bytes.len(), + }); + } + DataPacketHeader::ref_from_bytes(&bytes[..DataPacketHeader::SIZE]) + .map_err(|_| MessageError::InvalidDataPacketHeader) + } +} + +impl<'a> TryFrom<&'a mut [u8]> for &'a mut DataPacketHeader { + type Error = MessageError; + + fn try_from(bytes: &'a mut [u8]) -> Result { + if bytes.len() < DataPacketHeader::SIZE { + return Err(MessageError::BufferTooSmall { + required: DataPacketHeader::SIZE, + actual: bytes.len(), + }); + } + let (header_bytes, _) = bytes.split_at_mut(DataPacketHeader::SIZE); + DataPacketHeader::mut_from_bytes(header_bytes) + .map_err(|_| MessageError::InvalidDataPacketHeader) + } +} + +impl From for Bytes { + fn from(header: DataPacketHeader) -> Self { + Bytes::copy_from_slice(header.as_bytes()) + } +} + +pub struct DataPacket<'a> { + pub header: &'a DataPacketHeader, + pub plaintext: &'a mut [u8], +} + +impl<'a> TryFrom<&'a mut [u8]> for DataPacket<'a> { + type Error = MessageError; + + fn try_from(bytes: &'a mut [u8]) -> Result { + if bytes.len() < DataPacketHeader::SIZE { + return Err(MessageError::BufferTooSmall { + required: DataPacketHeader::SIZE, + actual: bytes.len(), + }); + } + + let (header_bytes, plaintext) = bytes.split_at_mut(DataPacketHeader::SIZE); + let header = DataPacketHeader::ref_from_bytes(header_bytes) + .map_err(|_| MessageError::InvalidDataPacketHeader)?; + + Ok(DataPacket { header, plaintext }) + } +} + +#[cfg(test)] +mod tests { + use std::convert::TryFrom; + + use super::*; + use crate::errors::MessageError; + + #[test] + fn test_handshake_initiation_default() { + let msg = HandshakeInitiation::default(); + assert_eq!(msg.message_type, TYPE_HANDSHAKE_INITIATION); + } + + #[test] + fn test_handshake_initiation_mac1_input() { + let msg = HandshakeInitiation::default(); + let mac1_input = msg.mac1_input(); + assert_eq!(mac1_input.len(), HandshakeInitiation::MAC1_OFFSET); + } + + #[test] + fn test_handshake_initiation_mac2_input() { + let msg = HandshakeInitiation::default(); + let mac2_input = msg.mac2_input(); + assert_eq!(mac2_input.len(), HandshakeInitiation::MAC2_OFFSET); + } + + #[test] + fn test_handshake_initiation_from_bytes() { + let mut bytes = [0u8; HandshakeInitiation::SIZE]; + bytes[0] = TYPE_HANDSHAKE_INITIATION; + + let msg = <&HandshakeInitiation>::try_from(&bytes[..]).unwrap(); + assert_eq!(msg.message_type, TYPE_HANDSHAKE_INITIATION); + } + + #[test] + fn test_handshake_initiation_from_bytes_invalid_size() { + let bytes = [0u8; HandshakeInitiation::SIZE - 1]; + let result = <&HandshakeInitiation>::try_from(&bytes[..]); + assert!(matches!(result, Err(MessageError::BufferTooSmall { .. }))); + } + + #[test] + fn test_handshake_initiation_from_mut_bytes_invalid_size() { + let mut bytes = [0u8; HandshakeInitiation::SIZE - 1]; + let result = <&mut HandshakeInitiation>::try_from(&mut bytes[..]); + assert!(matches!(result, Err(MessageError::BufferTooSmall { .. }))); + } + + #[test] + fn test_handshake_response_default() { + let msg = HandshakeResponse::default(); + assert_eq!(msg.message_type, TYPE_HANDSHAKE_RESPONSE); + } + + #[test] + fn test_handshake_response_mac1_input() { + let msg = HandshakeResponse::default(); + let mac1_input = msg.mac1_input(); + assert_eq!(mac1_input.len(), HandshakeResponse::MAC1_OFFSET); + } + + #[test] + fn test_handshake_response_mac2_input() { + let msg = HandshakeResponse::default(); + let mac2_input = msg.mac2_input(); + assert_eq!(mac2_input.len(), HandshakeResponse::MAC2_OFFSET); + } + + #[test] + fn test_handshake_response_from_bytes() { + let mut bytes = [0u8; HandshakeResponse::SIZE]; + bytes[0] = TYPE_HANDSHAKE_RESPONSE; + + let msg = <&HandshakeResponse>::try_from(&bytes[..]).unwrap(); + assert_eq!(msg.message_type, TYPE_HANDSHAKE_RESPONSE); + } + + #[test] + fn test_handshake_response_from_bytes_invalid_size() { + let bytes = [0u8; HandshakeResponse::SIZE - 1]; + let result = <&HandshakeResponse>::try_from(&bytes[..]); + assert!(matches!(result, Err(MessageError::BufferTooSmall { .. }))); + } + + #[test] + fn test_handshake_response_from_mut_bytes() { + let mut bytes = [0u8; HandshakeResponse::SIZE]; + bytes[0] = TYPE_HANDSHAKE_RESPONSE; + + let msg = <&mut HandshakeResponse>::try_from(&mut bytes[..]).unwrap(); + assert_eq!(msg.message_type, TYPE_HANDSHAKE_RESPONSE); + } + + #[test] + fn test_handshake_response_from_mut_bytes_invalid_size() { + let mut bytes = [0u8; HandshakeResponse::SIZE - 1]; + let result = <&mut HandshakeResponse>::try_from(&mut bytes[..]); + assert!(matches!(result, Err(MessageError::BufferTooSmall { .. }))); + } + + #[test] + fn test_cookie_reply_default() { + let msg = CookieReply::default(); + assert_eq!(msg.message_type, TYPE_COOKIE_REPLY); + } + + #[test] + fn test_cookie_reply_from_bytes() { + let mut bytes = [0u8; CookieReply::SIZE]; + bytes[0] = TYPE_COOKIE_REPLY; + + let msg = <&CookieReply>::try_from(&bytes[..]).unwrap(); + assert_eq!(msg.message_type, TYPE_COOKIE_REPLY); + } + + #[test] + fn test_cookie_reply_from_bytes_invalid_size() { + let bytes = [0u8; CookieReply::SIZE - 1]; + let result = <&CookieReply>::try_from(&bytes[..]); + assert!(matches!(result, Err(MessageError::BufferTooSmall { .. }))); + } + + #[test] + fn test_cookie_reply_from_mut_bytes() { + let mut bytes = [0u8; CookieReply::SIZE]; + bytes[0] = TYPE_COOKIE_REPLY; + + let msg = <&mut CookieReply>::try_from(&mut bytes[..]).unwrap(); + assert_eq!(msg.message_type, TYPE_COOKIE_REPLY); + } + + #[test] + fn test_cookie_reply_from_mut_bytes_invalid_size() { + let mut bytes = [0u8; CookieReply::SIZE - 1]; + let result = <&mut CookieReply>::try_from(&mut bytes[..]); + assert!(matches!(result, Err(MessageError::BufferTooSmall { .. }))); + } + + #[test] + fn test_data_packet_header_default() { + let msg = DataPacketHeader::default(); + assert_eq!(msg.message_type, TYPE_DATA); + } + + #[test] + fn test_data_packet_header_from_bytes() { + let mut bytes = [0u8; DataPacketHeader::SIZE]; + bytes[0] = TYPE_DATA; + + let msg = <&DataPacketHeader>::try_from(&bytes[..]).unwrap(); + assert_eq!(msg.message_type, TYPE_DATA); + } + + #[test] + fn test_data_packet_header_from_bytes_with_extra() { + let mut bytes = [0u8; DataPacketHeader::SIZE + 100]; + bytes[0] = TYPE_DATA; + + let msg = <&DataPacketHeader>::try_from(&bytes[..]).unwrap(); + assert_eq!(msg.message_type, TYPE_DATA); + } + + #[test] + fn test_data_packet_header_from_bytes_invalid_size() { + let bytes = [0u8; DataPacketHeader::SIZE - 1]; + let result = <&DataPacketHeader>::try_from(&bytes[..]); + assert!(matches!(result, Err(MessageError::BufferTooSmall { .. }))); + } + + #[test] + fn test_data_packet_header_from_mut_bytes() { + let mut bytes = [0u8; DataPacketHeader::SIZE]; + bytes[0] = TYPE_DATA; + + let msg = <&mut DataPacketHeader>::try_from(&mut bytes[..]).unwrap(); + assert_eq!(msg.message_type, TYPE_DATA); + } + + #[test] + fn test_data_packet_header_from_mut_bytes_with_extra() { + let mut bytes = [0u8; DataPacketHeader::SIZE + 100]; + bytes[0] = TYPE_DATA; + + let msg = <&mut DataPacketHeader>::try_from(&mut bytes[..]).unwrap(); + assert_eq!(msg.message_type, TYPE_DATA); + } + + #[test] + fn test_data_packet_header_from_mut_bytes_invalid_size() { + let mut bytes = [0u8; DataPacketHeader::SIZE - 1]; + let result = <&mut DataPacketHeader>::try_from(&mut bytes[..]); + assert!(matches!(result, Err(MessageError::BufferTooSmall { .. }))); + } + + #[test] + fn test_data_packet_from_bytes() { + let mut bytes = [0u8; DataPacketHeader::SIZE + 100]; + bytes[0] = TYPE_DATA; + + let packet = DataPacket::try_from(&mut bytes[..]).unwrap(); + assert_eq!(packet.header.message_type, TYPE_DATA); + assert_eq!(packet.plaintext.len(), 100); + } + + #[test] + fn test_data_packet_from_bytes_no_payload() { + let mut bytes = [0u8; DataPacketHeader::SIZE]; + bytes[0] = TYPE_DATA; + + let packet = DataPacket::try_from(&mut bytes[..]).unwrap(); + assert_eq!(packet.header.message_type, TYPE_DATA); + assert_eq!(packet.plaintext.len(), 0); + } + + #[test] + fn test_data_packet_from_bytes_invalid_size() { + let mut bytes = [0u8; DataPacketHeader::SIZE - 1]; + let result = DataPacket::try_from(&mut bytes[..]); + assert!(matches!(result, Err(MessageError::BufferTooSmall { .. }))); + } +} diff --git a/monad-wireauth-protocol/src/snapshots/monad_wireauth_protocol__handshake__tests__complete_handshake_trace.snap b/monad-wireauth-protocol/src/snapshots/monad_wireauth_protocol__handshake__tests__complete_handshake_trace.snap new file mode 100644 index 0000000000..1bf607d60f --- /dev/null +++ b/monad-wireauth-protocol/src/snapshots/monad_wireauth_protocol__handshake__tests__complete_handshake_trace.snap @@ -0,0 +1,42 @@ +--- +source: monad-wireauth-protocol/src/handshake.rs +expression: trace +--- +test_name: standard_handshake +seed: 42 +timestamp: 1700000000 +initiator_static_private: a22427226377cc867d51ad3f130af08ad13451de7160efa2b23076fd782de967 +initiator_static_public: 0351177dde89242d9121d787a681bd2a0bd6013428a6b83e684a253815db96d8b3 +responder_static_private: ea9f11f8dfb0ca08a8810f9ea39c3a6afb780859e8d8c7bc37b78e2f9b8d68d9 +responder_static_public: 035b0ef8c9bd756af433edc3129975888f6f18b8185b2afbaabc8bb3029a00cf81 +initiator_session_index: 100 +responder_session_index: 200 +init_message: + raw_bytes: 010000006400000002f8112234026e68b1e4d1565540d7b791ced1b64c5f30525cbe14f21dd7aa8c7867edd9f17422725dee2667cf726d0ed50f398048d01fe572838c93b934e8be07a7ff68535b15d42f2107b9acf8061b904542b61ddd9d795c8edecef3345508b3270675521141d54ce2d824f23c27a87935cca82c9284c13b7d569957c400000000000000000000000000000000 + sender_index: 100 + receiver_index: 0 + ephemeral_public: 02f8112234026e68b1e4d1565540d7b791ced1b64c5f30525cbe14f21dd7aa8c78 + encrypted_static: 67edd9f17422725dee2667cf726d0ed50f398048d01fe572838c93b934e8be07a7 + encrypted_static_tag: ff68535b15d42f2107b9acf8061b9045 + encrypted_timestamp: 42b61ddd9d795c8edecef334 + encrypted_timestamp_tag: 5508b3270675521141d54ce2d824f23c + encrypted_nothing_tag: ~ + mac1: 27a87935cca82c9284c13b7d569957c4 + mac2: "00000000000000000000000000000000" +response_message: + raw_bytes: 02000000c80000006400000003c48b53afac0d2d5169cd6848f03f67a21db6f506f8b8fc2dbef2552b6a7dc111c8829fbc22114e65bf26be96959d83a7797408dbb8ddfede5c1b4c795a3c4ef400000000000000000000000000000000 + sender_index: 200 + receiver_index: 100 + ephemeral_public: 03c48b53afac0d2d5169cd6848f03f67a21db6f506f8b8fc2dbef2552b6a7dc111 + encrypted_static: "" + encrypted_static_tag: "" + encrypted_timestamp: "" + encrypted_timestamp_tag: "" + encrypted_nothing_tag: c8829fbc22114e65bf26be96959d83a7 + mac1: 797408dbb8ddfede5c1b4c795a3c4ef4 + mac2: "00000000000000000000000000000000" +transport_keys: + initiator_send_key: fc1746686aad0bd4ee7789dc2dfa37b1 + initiator_recv_key: af8122429bbc4fc39749e2032b17691e + responder_send_key: af8122429bbc4fc39749e2032b17691e + responder_recv_key: fc1746686aad0bd4ee7789dc2dfa37b1 diff --git a/monad-wireauth-protocol/src/snapshots/monad_wireauth_protocol__handshake__tests__cookie_handshake_trace.snap b/monad-wireauth-protocol/src/snapshots/monad_wireauth_protocol__handshake__tests__cookie_handshake_trace.snap new file mode 100644 index 0000000000..19f2e385ef --- /dev/null +++ b/monad-wireauth-protocol/src/snapshots/monad_wireauth_protocol__handshake__tests__cookie_handshake_trace.snap @@ -0,0 +1,49 @@ +--- +source: monad-wireauth-protocol/src/handshake.rs +expression: trace +--- +test_name: cookie_handshake +seed: 43 +cookie_secret: "7777777777777777777777777777777777777777777777777777777777777777" +cookie_value: 7bb20b9a4d121660f35c669bec4f69df +init_without_cookie: + raw_bytes: 01000000650000000265e25c471caa29f8611c3d20cee58f933fc1a3e4e6fb24cd5b732c2459d320ec0d47119242bc575cef24279f9296f8fbe29110dd82bca2e9d10b3d3664cb263ecb06c3b44f5353e673abb7507dd6629257dc6a2719b942fb54bef8342d9fba6dd46091005730dbea9740df0bdc2e9cdafbc5021053b947a9858ea7a63f00000000000000000000000000000000 + sender_index: 101 + receiver_index: 0 + ephemeral_public: 0265e25c471caa29f8611c3d20cee58f933fc1a3e4e6fb24cd5b732c2459d320ec + encrypted_static: 0d47119242bc575cef24279f9296f8fbe29110dd82bca2e9d10b3d3664cb263ecb + encrypted_static_tag: 06c3b44f5353e673abb7507dd6629257 + encrypted_timestamp: dc6a2719b942fb54bef8342d + encrypted_timestamp_tag: 9fba6dd46091005730dbea9740df0bdc + encrypted_nothing_tag: ~ + mac1: 2e9cdafbc5021053b947a9858ea7a63f + mac2: "00000000000000000000000000000000" +cookie_reply: + raw_bytes: 0300000065000000e18475b5c5b8d03d7294feb1f37c88238a428d90f132081601674e46e5ba75772170a13a17bdaf24bc88e51b8a8479ce + receiver_index: 101 + nonce: e18475b5c5b8d03d7294feb1f37c8823 + encrypted_cookie: 8a428d90f132081601674e46e5ba7577 +init_with_cookie: + raw_bytes: 0100000066000000023864b260a21d2edb2b6d4b375febac55c06f675258d0d7cf67f98bb85a724b15e4452952ac222bd57e88f8fdbdcecdf0b5133626f88ac7eeb281514b61ef23d345e74df2394df4accafe3a818b5ed2b79f8a74acbf3fdbf49dfe58a7035609379c08b14d2dfbece9f62fdfd2effe7f0f5d881e18c065aa17e98f4167e8d13c103fa867ad1b83b0e5bc76fcceb5 + sender_index: 102 + receiver_index: 0 + ephemeral_public: 023864b260a21d2edb2b6d4b375febac55c06f675258d0d7cf67f98bb85a724b15 + encrypted_static: e4452952ac222bd57e88f8fdbdcecdf0b5133626f88ac7eeb281514b61ef23d345 + encrypted_static_tag: e74df2394df4accafe3a818b5ed2b79f + encrypted_timestamp: 8a74acbf3fdbf49dfe58a703 + encrypted_timestamp_tag: 5609379c08b14d2dfbece9f62fdfd2ef + encrypted_nothing_tag: ~ + mac1: fe7f0f5d881e18c065aa17e98f4167e8 + mac2: d13c103fa867ad1b83b0e5bc76fcceb5 +response_with_cookie: + raw_bytes: 02000000c900000066000000038625117887c0985ae9c3383da4d1aae4ac406c6523a37e70317c66f4987c210479e80bae2dff6699a68d160e25c8edf70635769467adb99b839cfdf5aa11cf52faa8bd118f88f6265b0eb242fcd004ae + sender_index: 201 + receiver_index: 102 + ephemeral_public: 038625117887c0985ae9c3383da4d1aae4ac406c6523a37e70317c66f4987c2104 + encrypted_static: "" + encrypted_static_tag: "" + encrypted_timestamp: "" + encrypted_timestamp_tag: "" + encrypted_nothing_tag: 79e80bae2dff6699a68d160e25c8edf7 + mac1: 0635769467adb99b839cfdf5aa11cf52 + mac2: faa8bd118f88f6265b0eb242fcd004ae diff --git a/monad-wireauth-protocol/src/snapshots/monad_wireauth_protocol__handshake__tests__data_encryption_traces.snap b/monad-wireauth-protocol/src/snapshots/monad_wireauth_protocol__handshake__tests__data_encryption_traces.snap new file mode 100644 index 0000000000..60629d87ba --- /dev/null +++ b/monad-wireauth-protocol/src/snapshots/monad_wireauth_protocol__handshake__tests__data_encryption_traces.snap @@ -0,0 +1,20 @@ +--- +source: monad-wireauth-protocol/src/handshake.rs +expression: "vec![initiator_to_responder, responder_to_initiator,]" +--- +- test_name: initiator_to_responder + plaintext: "Hello, encrypted world!" + nonce: "01000000000000000000000000000000" + sender_index: 102 + receiver_index: 202 + encrypted_payload: 854bf0e0c5fbb0c976465c6645920a485c19bd544b4eb9 + auth_tag: bbef4ee6d67de4e12f05ff8d455f411b + complete_packet: 04ca00000001000000000000000000000000000000854bf0e0c5fbb0c976465c6645920a485c19bd544b4eb9bbef4ee6d67de4e12f05ff8d455f411b +- test_name: responder_to_initiator + plaintext: Response from responder! + nonce: "01000000000000000000000000000000" + sender_index: 202 + receiver_index: 102 + encrypted_payload: e3d9bad5a97edb95122e1c59dee16ada81e941ec66fec578 + auth_tag: f0d4e02ea40bf601b84923069abd933e + complete_packet: 046600000001000000000000000000000000000000e3d9bad5a97edb95122e1c59dee16ada81e941ec66fec578f0d4e02ea40bf601b84923069abd933e diff --git a/monad-wireauth-protocol/src/snapshots/monad_wireauth_protocol__handshake__tests__multiple_protocol_vectors.snap b/monad-wireauth-protocol/src/snapshots/monad_wireauth_protocol__handshake__tests__multiple_protocol_vectors.snap new file mode 100644 index 0000000000..4ee7c29357 --- /dev/null +++ b/monad-wireauth-protocol/src/snapshots/monad_wireauth_protocol__handshake__tests__multiple_protocol_vectors.snap @@ -0,0 +1,118 @@ +--- +source: monad-wireauth-protocol/src/handshake.rs +expression: all_traces +--- +- test_name: vector_100 + seed: 100 + timestamp: 1700000100 + initiator_static_private: 99dd7fc1ad584d9b174275ef9de7bda04fc61e38899fdce22fd31a49f3fc47d6 + initiator_static_public: 02a3548483481d4f63dd1eceb6468b1ea59069e5a6799722390b448e3d5b8de8ee + responder_static_private: 126242834c575d4ccd51fa7081775c09746305ab0889844fff09c2018a5548bd + responder_static_public: 0216a1745cb30cb453c65f69627c82cf28f215a31d90c86fbab3f9db75f9126540 + initiator_session_index: 4057598532 + responder_session_index: 3810930514 + init_message: + raw_bytes: 01000000440adaf103a8bee2d6eb6aa7e6342133bfd70fcbccd4ab5df857b3d8738bb4af70805bdaabd9e6ba356d9a487a65ec70f94148b0afe135e8a0544e71ebd89baf7922c9b385d6d3ab80b6e78f8a1b25cb8f56e384ef464448eb894fbd9bcdaa5cb6e80fb5e6d29d8a5d0be1b9f35566d60dbbb2cc6d77e74115ab407c8fecf47500f500000000000000000000000000000000 + sender_index: 4057598532 + receiver_index: 0 + ephemeral_public: 03a8bee2d6eb6aa7e6342133bfd70fcbccd4ab5df857b3d8738bb4af70805bdaab + encrypted_static: d9e6ba356d9a487a65ec70f94148b0afe135e8a0544e71ebd89baf7922c9b385d6 + encrypted_static_tag: d3ab80b6e78f8a1b25cb8f56e384ef46 + encrypted_timestamp: 4448eb894fbd9bcdaa5cb6e8 + encrypted_timestamp_tag: 0fb5e6d29d8a5d0be1b9f35566d60dbb + encrypted_nothing_tag: ~ + mac1: b2cc6d77e74115ab407c8fecf47500f5 + mac2: "00000000000000000000000000000000" + response_message: + raw_bytes: 02000000522f26e3440adaf102d53ba4d43ff563a77771973b51231948b171c5a2a68e5e0065badd8ccdafbbc0cd2081fb6b80666deba04c952b32a5ccce2afb4044a4d3616f2b87ba2de3cbbf00000000000000000000000000000000 + sender_index: 3810930514 + receiver_index: 4057598532 + ephemeral_public: 02d53ba4d43ff563a77771973b51231948b171c5a2a68e5e0065badd8ccdafbbc0 + encrypted_static: "" + encrypted_static_tag: "" + encrypted_timestamp: "" + encrypted_timestamp_tag: "" + encrypted_nothing_tag: cd2081fb6b80666deba04c952b32a5cc + mac1: ce2afb4044a4d3616f2b87ba2de3cbbf + mac2: "00000000000000000000000000000000" + transport_keys: + initiator_send_key: ebf7a4ed2503c39bc000eb5a4b8d3925 + initiator_recv_key: 1740a6c9c52768d8037372cc20bf56e9 + responder_send_key: 1740a6c9c52768d8037372cc20bf56e9 + responder_recv_key: ebf7a4ed2503c39bc000eb5a4b8d3925 +- test_name: vector_200 + seed: 200 + timestamp: 1700000200 + initiator_static_private: 68472ad1a4d2a888dce7571ac9b8080d9c405ea1e1889372504aa50628975c7a + initiator_static_public: 038620c6628bd80c1148c9a90833c06e5df5d31150392b35d3de7e3e2561b135dc + responder_static_private: c8d2362ecaaf13d084cdc2bd3133ba16b6f48408673e28ac61b1dbfbf54365ed + responder_static_public: 02c69fa1f3ff47b7f962ff001716b12bbec1faced974756b7f6f25658ef537806d + initiator_session_index: 2527227798 + responder_session_index: 818229675 + init_message: + raw_bytes: 01000000966fa29603f8c7291f6fe8ad581e735812a7060b472a6d62a18e7854ccd85d33b37a34d9713d426a627fc8b172ccd9f18d7a91565a394e270d09991719075305d7c7842c8db1844874b1547df0165a185173024daaf5c270d89095de7aa0cbace46e69eaae783bdd6005da9703ba6e375d2b89f6a61c4fbc5b67cd8f12bb5d82762f00000000000000000000000000000000 + sender_index: 2527227798 + receiver_index: 0 + ephemeral_public: 03f8c7291f6fe8ad581e735812a7060b472a6d62a18e7854ccd85d33b37a34d971 + encrypted_static: 3d426a627fc8b172ccd9f18d7a91565a394e270d09991719075305d7c7842c8db1 + encrypted_static_tag: 844874b1547df0165a185173024daaf5 + encrypted_timestamp: c270d89095de7aa0cbace46e + encrypted_timestamp_tag: 69eaae783bdd6005da9703ba6e375d2b + encrypted_nothing_tag: ~ + mac1: 89f6a61c4fbc5b67cd8f12bb5d82762f + mac2: "00000000000000000000000000000000" + response_message: + raw_bytes: 02000000ab31c530966fa2960216db8b479b16b416e6dbc721ac6dc70570f8e8b3ca38a171ddd282828f09d898a65d56bb2d62c7c218773e1c9055e06c0b39019982cf89ede80170c719018a4900000000000000000000000000000000 + sender_index: 818229675 + receiver_index: 2527227798 + ephemeral_public: 0216db8b479b16b416e6dbc721ac6dc70570f8e8b3ca38a171ddd282828f09d898 + encrypted_static: "" + encrypted_static_tag: "" + encrypted_timestamp: "" + encrypted_timestamp_tag: "" + encrypted_nothing_tag: a65d56bb2d62c7c218773e1c9055e06c + mac1: 0b39019982cf89ede80170c719018a49 + mac2: "00000000000000000000000000000000" + transport_keys: + initiator_send_key: 14d9484cf797536e9a1adcef72dc1203 + initiator_recv_key: d9698f68143b505facce394ef60682a1 + responder_send_key: d9698f68143b505facce394ef60682a1 + responder_recv_key: 14d9484cf797536e9a1adcef72dc1203 +- test_name: vector_300 + seed: 300 + timestamp: 1700000300 + initiator_static_private: e797f9258332c8d45a477eb922f0c7f51396d65409765311ddcef605caed25b5 + initiator_static_public: 036700cffdb1df1364cf80ed98ff6192ae3ad04e40c323cbfc28a73bcdd26b1691 + responder_static_private: d24f8034ba487a79ffa042a1080bf2bf8344e88151808c83fe325954378aa13a + responder_static_public: 0391bc8676e4d9ed9da88a599e351f69141c41f9b15ddbd5c02688f8c03f1d0dec + initiator_session_index: 3458659921 + responder_session_index: 3444663251 + init_message: + raw_bytes: 0100000051f626ce02a5ce77b5a2841ea69b23070e208e947a1047056d2af60c121a3a88d4a7f472dd2ec954ebf1a07c98bc2f6bd92a73829a91a277b970cfe4f4ee41cd8a7883f3a5cb13fe7b6c382553f34c5b90d8fc318c1a170cc4fff56b95d47f32789f0e6e3161495c349f9672828cb91321b3639deb6b19a91f12eeaf509781fcb1c900000000000000000000000000000000 + sender_index: 3458659921 + receiver_index: 0 + ephemeral_public: 02a5ce77b5a2841ea69b23070e208e947a1047056d2af60c121a3a88d4a7f472dd + encrypted_static: 2ec954ebf1a07c98bc2f6bd92a73829a91a277b970cfe4f4ee41cd8a7883f3a5cb + encrypted_static_tag: 13fe7b6c382553f34c5b90d8fc318c1a + encrypted_timestamp: 170cc4fff56b95d47f32789f + encrypted_timestamp_tag: 0e6e3161495c349f9672828cb91321b3 + encrypted_nothing_tag: ~ + mac1: 639deb6b19a91f12eeaf509781fcb1c9 + mac2: "00000000000000000000000000000000" + response_message: + raw_bytes: 02000000d36351cd51f626ce029d8409c38893233727bb8176ae3e16750d8a1eccd128d110166a07c4d18318e3c99295dafcf42a45562cb8657c14f83cee08f9f4d1b6e3aa55f11b132035f7a200000000000000000000000000000000 + sender_index: 3444663251 + receiver_index: 3458659921 + ephemeral_public: 029d8409c38893233727bb8176ae3e16750d8a1eccd128d110166a07c4d18318e3 + encrypted_static: "" + encrypted_static_tag: "" + encrypted_timestamp: "" + encrypted_timestamp_tag: "" + encrypted_nothing_tag: c99295dafcf42a45562cb8657c14f83c + mac1: ee08f9f4d1b6e3aa55f11b132035f7a2 + mac2: "00000000000000000000000000000000" + transport_keys: + initiator_send_key: 19d860737ab92abb275bb2b239bdd1f7 + initiator_recv_key: bd457fab82ec100f4749de8787fcd2fb + responder_send_key: bd457fab82ec100f4749de8787fcd2fb + responder_recv_key: 19d860737ab92abb275bb2b239bdd1f7 diff --git a/monad-wireauth-session/Cargo.toml b/monad-wireauth-session/Cargo.toml new file mode 100644 index 0000000000..2f9783abba --- /dev/null +++ b/monad-wireauth-session/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "monad-wireauth-session" +version = "0.1.0" +edition = "2021" + +[dependencies] +monad-wireauth-protocol.workspace = true +bytes.workspace = true +rand.workspace = true +thiserror.workspace = true +secp256k1 = { workspace = true, features = ["global-context"] } +zerocopy.workspace = true +zeroize.workspace = true +tracing.workspace = true +hex.workspace = true + +[dev-dependencies] +tracing-subscriber = { workspace = true, features = ["env-filter"] } +rstest.workspace = true +proptest.workspace = true + +[features] +bench = [] \ No newline at end of file diff --git a/monad-wireauth-session/README.md b/monad-wireauth-session/README.md new file mode 100644 index 0000000000..cdd9aad46b --- /dev/null +++ b/monad-wireauth-session/README.md @@ -0,0 +1,72 @@ +# Session + +## Initiator + +```mermaid +flowchart TD + Start([new]) --> Init[Initiator] + Cookie{{CookieReply}} + Init --> Cookie --> Init + Init -->|HandshakeResponse| Transport[Transport] + Init -->|timeout| Retry{retry > 0?} + Retry -->|yes| Rekey[emit RekeyEvent] + Retry -->|no| End([terminated]) + Rekey --> End +``` + +| State | Event | Next State | Actions | +|-------|-------|------------|---------| +| - | new() | Init | send HandshakeInit, start session_timeout | +| Init | CookieReply | Init | store cookie | +| Init | HandshakeResponse | Transport | validate, establish transport, start session_timeout/rekey/max_session_duration, send empty DataPacket | +| Init | timeout (retry > 0) | terminated | decrement retry, emit RekeyEvent | +| Init | timeout (retry == 0) | terminated | - | + +## Responder + +```mermaid +flowchart TD + Start([validate_init + new]) --> Resp[Responder] + Cookie{{CookieReply}} + Resp --> Cookie --> Resp + Resp -->|DataPacket| Transport[Transport] + Resp -->|timeout| End([terminated]) +``` + +| State | Event | Next State | Actions | +|-------|-------|------------|---------| +| - | validate_init() + new() | Resp | send HandshakeResponse, start session_timeout | +| Resp | CookieReply | Resp | store cookie | +| Resp | DataPacket | Transport | decrypt, establish transport, start session_timeout/keepalive/max_session_duration | +| Resp | timeout | terminated | - | + +## Transport + +```mermaid +flowchart TD + Start([establish]) --> Active[Active] + Encrypt{{encrypt}} + Decrypt{{decrypt}} + Keepalive{{keepalive timer}} + RekeyTimer{{rekey timer}} + Active --> Encrypt --> Active + Active --> Decrypt --> Active + Active --> Keepalive --> Active + Active --> RekeyTimer --> Active + Active -->|timeout| Check{initiator && retry > 0?} + Active -->|max_session_duration| End([terminated]) + Check -->|yes| Rekey[emit RekeyEvent] + Check -->|no| End + Rekey --> End +``` + +| State | Event | Next State | Actions | +|-------|-------|------------|---------| +| - | establish | Active | - | +| Active | encrypt() | Active | increment send_nonce, reset keepalive | +| Active | decrypt() | Active | check replay, reset session_timeout | +| Active | keepalive timer | Active | send empty DataPacket, reset keepalive | +| Active | rekey timer | Active | emit RekeyEvent | +| Active | timeout (initiator, retry > 0) | terminated | decrement retry, emit RekeyEvent | +| Active | timeout (other) | terminated | - | +| Active | max_session_duration | terminated | - | diff --git a/monad-wireauth-session/src/common.rs b/monad-wireauth-session/src/common.rs new file mode 100644 index 0000000000..08545ca398 --- /dev/null +++ b/monad-wireauth-session/src/common.rs @@ -0,0 +1,299 @@ +use std::{ + net::SocketAddr, + time::{Duration, SystemTime}, +}; + +use monad_wireauth_protocol::{common::*, cookies}; +use tracing::debug; +use zeroize::Zeroizing; + +pub const RETRY_ALWAYS: u64 = u64::MAX; +pub const DEFAULT_RETRY_ATTEMPTS: u64 = 3; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SessionStatus { + Initiating, + Open, + Closed, +} + +#[derive(Debug, Clone)] +pub struct EstablishedEvent { + pub remote_public_key: PublicKey, + pub remote_addr: SocketAddr, + pub is_initiator: bool, + pub created: Duration, + pub local_index: SessionIndex, +} + +#[derive(Debug, Clone)] +pub struct TerminatedEvent { + pub remote_public_key: PublicKey, + pub remote_addr: SocketAddr, +} + +#[derive(Debug, Clone)] +pub struct SessionTimeoutResult { + pub terminated: TerminatedEvent, + pub rekey: Option, +} + +#[derive(Debug, Clone)] +pub struct RekeyEvent { + pub remote_public_key: PublicKey, + pub remote_addr: SocketAddr, + pub retry_attempts: u64, + pub stored_cookie: Option<[u8; 16]>, +} + +#[derive(Clone)] +pub struct MessageEvent { + pub remote_addr: SocketAddr, + pub header: monad_wireauth_protocol::messages::DataPacketHeader, +} + +#[derive(Clone)] +pub struct Config { + pub session_timeout: Duration, + pub session_timeout_jitter: Duration, + pub keepalive_interval: Duration, + pub keepalive_jitter: Duration, + pub rekey_interval: Duration, + pub rekey_jitter: Duration, + pub max_session_duration: Duration, + pub handshake_rate_limit: u64, + pub handshake_rate_reset_interval: Duration, + pub cookie_refresh_duration: Duration, + pub low_watermark_sessions: usize, + pub high_watermark_sessions: usize, + pub max_sessions_per_ip: usize, + pub ip_rate_limit_window: Duration, + pub max_requests_per_ip: usize, + pub ip_history_capacity: usize, + pub psk: Zeroizing<[u8; 32]>, +} + +impl Default for Config { + fn default() -> Self { + Self { + session_timeout: Duration::from_secs(10), + session_timeout_jitter: Duration::from_secs(1), + keepalive_interval: Duration::from_secs(3), + keepalive_jitter: Duration::from_secs(1), + rekey_interval: Duration::from_secs(6 * 60 * 60), + rekey_jitter: Duration::from_secs(60), + max_session_duration: Duration::from_secs(7 * 60 * 60), + handshake_rate_limit: 2000, + handshake_rate_reset_interval: Duration::from_secs(1), + cookie_refresh_duration: Duration::from_secs(120), + low_watermark_sessions: 10_000, + high_watermark_sessions: 100_000, + max_sessions_per_ip: 10, + ip_rate_limit_window: Duration::from_secs(10), + max_requests_per_ip: 10, + ip_history_capacity: 1_000_000, + psk: Zeroizing::new([0u8; 32]), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum SessionError { + #[error("handshake validation failed: {0}")] + InvalidHandshake(#[source] monad_wireauth_protocol::errors::ProtocolError), + #[error("session not established: attempted operation on non-existent or expired session")] + NotEstablished, + #[error("invalid packet format: {0}")] + InvalidPacket(#[source] monad_wireauth_protocol::errors::MessageError), + #[error("cryptographic operation failed: {0}")] + CryptoError(#[source] monad_wireauth_protocol::errors::CryptoError), + #[error("MAC verification failed: {0}")] + InvalidMac(#[source] monad_wireauth_protocol::errors::CryptoError), + #[error("cookie validation failed: {0}")] + InvalidCookie(#[source] monad_wireauth_protocol::errors::CookieError), + #[error("replay attack detected: packet counter {counter} already seen")] + ReplayAttack { counter: u64 }, + #[error("timestamp replay detected: timestamp not newer than last seen")] + TimestampReplay, + #[error("session timed out: exceeded maximum duration or idle time")] + SessionTimeout, +} + +pub struct SessionState { + pub keepalive_deadline: Option, + pub rekey_deadline: Option, + pub session_timeout_deadline: Option, + pub max_session_duration_deadline: Option, + pub stored_cookie: Option<[u8; 16]>, + pub last_handshake_mac1: Option<[u8; 16]>, + pub retry_attempts: u64, + pub initiator_system_time: Option, + pub remote_addr: SocketAddr, + pub remote_public_key: PublicKey, + pub local_index: SessionIndex, + pub created: Duration, + pub is_initiator: bool, +} + +impl SessionState { + pub fn new( + remote_addr: SocketAddr, + remote_public_key: PublicKey, + local_index: SessionIndex, + created: Duration, + retry_attempts: u64, + initiator_system_time: Option, + is_initiator: bool, + ) -> Self { + Self { + keepalive_deadline: None, + rekey_deadline: None, + session_timeout_deadline: None, + max_session_duration_deadline: None, + stored_cookie: None, + last_handshake_mac1: None, + retry_attempts, + initiator_system_time, + remote_addr, + remote_public_key, + local_index, + created, + is_initiator, + } + } + + pub fn reset_keepalive(&mut self, duration_since_start: Duration, timer_duration: Duration) { + self.keepalive_deadline = Some(duration_since_start + timer_duration); + } + + pub fn reset_rekey(&mut self, duration_since_start: Duration, timer_duration: Duration) { + self.rekey_deadline = Some(duration_since_start + timer_duration); + } + + pub fn reset_session_timeout( + &mut self, + duration_since_start: Duration, + timer_duration: Duration, + ) { + self.session_timeout_deadline = Some(duration_since_start + timer_duration); + } + + pub fn clear_keepalive(&mut self) { + self.keepalive_deadline = None; + } + + pub fn clear_rekey(&mut self) { + self.rekey_deadline = None; + } + + pub fn clear_session_timeout(&mut self) { + self.session_timeout_deadline = None; + } + + pub fn set_max_session_duration( + &mut self, + duration_since_start: Duration, + timer_duration: Duration, + ) { + self.max_session_duration_deadline = Some(duration_since_start + timer_duration); + } + + pub fn clear_max_session_duration(&mut self) { + self.max_session_duration_deadline = None; + } + + pub fn get_next_deadline(&self) -> Option { + [ + self.keepalive_deadline, + self.rekey_deadline, + self.session_timeout_deadline, + self.max_session_duration_deadline, + ] + .iter() + .filter_map(|&timer| timer) + .min() + } + + pub fn stored_cookie(&self) -> Option<[u8; 16]> { + self.stored_cookie + } + + pub fn initiator_system_time(&self) -> Option { + self.initiator_system_time + } + + pub fn handle_cookie( + &mut self, + cookie_reply: &mut monad_wireauth_protocol::messages::CookieReply, + ) -> Result<(), SessionError> { + let Some(mac1) = self.last_handshake_mac1 else { + debug!("no last_handshake_mac1 stored"); + return Err(SessionError::InvalidCookie( + monad_wireauth_protocol::errors::CookieError::InvalidCookieMac( + monad_wireauth_protocol::errors::CryptoError::MacVerificationFailed, + ), + )); + }; + + let cookie = cookies::accept_cookie_reply( + &SerializedPublicKey::from(&self.remote_public_key), + cookie_reply, + &mac1, + ) + .map_err(|e| { + debug!(error=?e, "failed to accept cookie reply"); + use monad_wireauth_protocol::errors::ProtocolError; + match e { + ProtocolError::Cookie(c) => SessionError::InvalidCookie(c), + ProtocolError::Crypto(c) => SessionError::CryptoError(c), + _ => SessionError::InvalidHandshake(e), + } + })?; + + self.stored_cookie = Some(cookie); + debug!("cookie stored successfully"); + Ok(()) + } + + pub fn handle_session_timeout(&mut self) -> (TerminatedEvent, Option) { + debug!( + retry_attempts = self.retry_attempts, + remote_addr = ?self.remote_addr, + is_initiator = self.is_initiator, + "handling session timeout" + ); + + let terminated = TerminatedEvent { + remote_public_key: self.remote_public_key.clone(), + remote_addr: self.remote_addr, + }; + + if !self.is_initiator { + return (terminated, None); + } + + let should_retry = self.retry_attempts > 0 || self.retry_attempts == RETRY_ALWAYS; + if self.retry_attempts > 0 && self.retry_attempts != RETRY_ALWAYS { + self.retry_attempts -= 1; + } + + let rekey = should_retry.then(|| RekeyEvent { + remote_public_key: self.remote_public_key.clone(), + remote_addr: self.remote_addr, + retry_attempts: self.retry_attempts, + stored_cookie: self.stored_cookie, + }); + + (terminated, rekey) + } +} + +pub(crate) fn add_jitter( + rng: &mut R, + base: Duration, + jitter: Duration, +) -> Duration { + let jitter_millis = jitter.as_millis() as u64; + let random_jitter = rng.next_u64() % (jitter_millis + 1); + base + Duration::from_millis(random_jitter) +} diff --git a/monad-wireauth-session/src/initiator.rs b/monad-wireauth-session/src/initiator.rs new file mode 100644 index 0000000000..5714696292 --- /dev/null +++ b/monad-wireauth-session/src/initiator.rs @@ -0,0 +1,173 @@ +use std::{ + net::SocketAddr, + ops::{Deref, DerefMut}, + time::{Duration, SystemTime}, +}; + +use monad_wireauth_protocol::{ + common::*, + handshake::{self}, + messages::{CookieReply, DataPacketHeader, HandshakeInitiation, HandshakeResponse}, +}; + +use crate::{ + common::{add_jitter, Config, SessionError, SessionState, SessionTimeoutResult}, + transport::TransportState, +}; + +pub struct ValidatedHandshakeResponse { + transport_keys: monad_wireauth_protocol::common::TransportKeys, + remote_index: SessionIndex, +} + +pub struct InitiatorState { + handshake_state: handshake::HandshakeState, + common: SessionState, +} + +impl InitiatorState { + #[allow(clippy::too_many_arguments)] + pub fn new( + rng: &mut R, + system_time: SystemTime, + duration_since_start: Duration, + config: &Config, + local_session_index: SessionIndex, + local_static_key: &PrivateKey, + local_static_public: PublicKey, + remote_static_key: PublicKey, + remote_addr: SocketAddr, + cookie_secret: Option<[u8; 16]>, + retry_attempts: u64, + ) -> Result<(Self, (Duration, HandshakeInitiation)), SessionError> { + let (init_msg, handshake_state) = handshake::send_handshake_init( + rng, + system_time, + local_session_index.as_u32(), + local_static_key, + &SerializedPublicKey::from(&local_static_public), + &SerializedPublicKey::from(&remote_static_key), + cookie_secret.as_ref(), + ) + .map_err(SessionError::InvalidHandshake)?; + + let mac1 = init_msg.mac1.0; + let mut common = SessionState::new( + remote_addr, + remote_static_key, + local_session_index, + duration_since_start, + retry_attempts, + None, + true, + ); + common.stored_cookie = cookie_secret; + common.last_handshake_mac1 = Some(mac1); + + let mut session = InitiatorState { + handshake_state, + common, + }; + + let timeout_with_jitter = + add_jitter(rng, config.session_timeout, config.session_timeout_jitter); + session + .common + .reset_session_timeout(duration_since_start, timeout_with_jitter); + + let timer = session + .common + .get_next_deadline() + .expect("expected at least one timer to be set"); + + Ok((session, (timer, init_msg))) + } + + pub fn validate_response( + &mut self, + config: &Config, + local_static_key: &PrivateKey, + local_static_public: &PublicKey, + msg: &mut HandshakeResponse, + ) -> Result { + let transport_keys = handshake::accept_handshake_response( + local_static_key, + &SerializedPublicKey::from(local_static_public), + msg, + &mut self.handshake_state, + &config.psk, + ) + .map_err(SessionError::InvalidHandshake)?; + + Ok(ValidatedHandshakeResponse { + transport_keys, + remote_index: self.handshake_state.receiver_index.into(), + }) + } + + pub fn establish( + mut self, + rng: &mut R, + config: &Config, + duration_since_start: Duration, + validated_response: ValidatedHandshakeResponse, + _remote_addr: SocketAddr, + ) -> (TransportState, Duration, DataPacketHeader) { + self.common.reset_session_timeout( + duration_since_start, + add_jitter(rng, config.session_timeout, config.session_timeout_jitter), + ); + self.common.reset_rekey( + duration_since_start, + add_jitter(rng, config.rekey_interval, config.rekey_jitter), + ); + self.common + .set_max_session_duration(duration_since_start, config.max_session_duration); + + let mut transport = TransportState::new( + validated_response.remote_index, + validated_response.transport_keys.send_key, + validated_response.transport_keys.recv_key, + self.common, + ); + let (header, timer) = transport.encrypt(config, duration_since_start, &mut []); + (transport, timer, header) + } + + pub fn handle_cookie(&mut self, cookie_reply: &mut CookieReply) -> Result<(), SessionError> { + self.common.handle_cookie(cookie_reply) + } + + pub fn tick( + &mut self, + duration_since_start: Duration, + ) -> Option<(Option, SessionTimeoutResult)> { + let session_timeout_expired = self + .common + .session_timeout_deadline + .is_some_and(|deadline| deadline <= duration_since_start); + + if !session_timeout_expired { + return None; + } + + self.common.clear_session_timeout(); + let (terminated, rekey) = self.handle_session_timeout(); + let timer = self.common.get_next_deadline(); + Some((timer, SessionTimeoutResult { terminated, rekey })) + } +} + +impl Deref for InitiatorState { + type Target = SessionState; + + fn deref(&self) -> &Self::Target { + &self.common + } +} + +impl DerefMut for InitiatorState { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.common + } +} diff --git a/monad-wireauth-session/src/lib.rs b/monad-wireauth-session/src/lib.rs new file mode 100644 index 0000000000..6c4438349b --- /dev/null +++ b/monad-wireauth-session/src/lib.rs @@ -0,0 +1,13 @@ +mod common; +mod replay_filter; + +pub mod initiator; +pub mod responder; +pub mod transport; + +pub use common::*; +pub use initiator::{InitiatorState, ValidatedHandshakeResponse}; +pub use monad_wireauth_protocol::SessionIndex; +pub use replay_filter::ReplayFilter; +pub use responder::{ResponderState, ValidatedHandshakeInit}; +pub use transport::TransportState; diff --git a/monad-wireauth-session/src/replay_filter.rs b/monad-wireauth-session/src/replay_filter.rs new file mode 100644 index 0000000000..858676d80f --- /dev/null +++ b/monad-wireauth-session/src/replay_filter.rs @@ -0,0 +1,447 @@ +use crate::SessionError; + +const REPLAY_WINDOW_SIZE: usize = 32; +const REPLAY_WINDOW_BITS: usize = REPLAY_WINDOW_SIZE * 64; + +pub struct ReplayFilter { + next: u64, + bitmap: [u64; REPLAY_WINDOW_SIZE], +} + +impl Default for ReplayFilter { + fn default() -> Self { + Self::new() + } +} + +impl ReplayFilter { + pub fn new() -> Self { + Self { + next: 0, + bitmap: [0; REPLAY_WINDOW_SIZE], + } + } + + pub fn check(&self, counter: u64) -> Result<(), SessionError> { + if counter >= self.next { + return Ok(()); + } + + if counter.saturating_add(REPLAY_WINDOW_BITS as u64) <= self.next { + return Err(SessionError::ReplayAttack { counter }); + } + + if self.is_set(counter) { + return Err(SessionError::ReplayAttack { counter }); + } + + Ok(()) + } + + pub fn update(&mut self, counter: u64) { + if counter >= self.next { + let gap = counter.saturating_sub(self.next); + if gap >= REPLAY_WINDOW_BITS as u64 { + self.bitmap.iter_mut().for_each(|word| *word = 0); + } else { + (self.next..counter).for_each(|i| self.clear(i)); + } + self.next = counter.saturating_add(1); + } + + self.set(counter); + } + + fn is_set(&self, counter: u64) -> bool { + let bit_idx = counter % REPLAY_WINDOW_BITS as u64; + let word = (bit_idx / 64) as usize; + let bit = (bit_idx % 64) as usize; + ((self.bitmap[word] >> bit) & 1) == 1 + } + + fn set(&mut self, counter: u64) { + let bit_idx = counter % REPLAY_WINDOW_BITS as u64; + let word = (bit_idx / 64) as usize; + let bit = (bit_idx % 64) as usize; + self.bitmap[word] |= 1u64 << bit; + } + + fn clear(&mut self, counter: u64) { + let bit_idx = counter % REPLAY_WINDOW_BITS as u64; + let word = (bit_idx / 64) as usize; + let bit = (bit_idx % 64) as usize; + self.bitmap[word] &= !(1u64 << bit); + } +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + use rstest::rstest; + + use super::*; + + #[rstest] + #[case(0)] + #[case(1)] + #[case(100)] + #[case(1000)] + fn test_sequential_counters(#[case] start: u64) { + let mut filter = ReplayFilter::new(); + + for i in 0..100 { + let counter = start.saturating_add(i); + assert!(filter.check(counter).is_ok()); + filter.update(counter); + assert_eq!(filter.next, counter + 1); + } + } + + #[rstest] + #[case(0)] + #[case(10)] + #[case(100)] + #[case(1000)] + fn test_duplicate_detection(#[case] counter: u64) { + let mut filter = ReplayFilter::new(); + + assert!(filter.check(counter).is_ok()); + filter.update(counter); + assert!(filter.check(counter).is_err()); + assert!(matches!( + filter.check(counter), + Err(SessionError::ReplayAttack { .. }) + )); + } + + #[rstest] + #[case(vec![0, 2, 4, 6, 8])] + #[case(vec![10, 5, 15, 7, 12])] + #[case(vec![100, 50, 150, 75, 125])] + fn test_out_of_order_within_window(#[case] counters: Vec) { + let mut filter = ReplayFilter::new(); + let mut seen = std::collections::HashSet::new(); + + for counter in counters { + let is_new = seen.insert(counter); + let result = filter.check(counter); + assert_eq!(result.is_ok(), is_new); + if is_new { + filter.update(counter); + } + } + } + + #[rstest] + fn test_window_boundaries() { + let mut filter = ReplayFilter::new(); + + assert!(filter.check(0).is_ok()); + filter.update(0); + assert!(filter.check(REPLAY_WINDOW_BITS as u64 - 1).is_ok()); + filter.update(REPLAY_WINDOW_BITS as u64 - 1); + + assert!(filter.check(0).is_err()); + + assert!(filter.check(REPLAY_WINDOW_BITS as u64).is_ok()); + filter.update(REPLAY_WINDOW_BITS as u64); + + assert!(filter.check(0).is_err()); + } + + #[rstest] + fn test_far_future_counter() { + let mut filter = ReplayFilter::new(); + + assert!(filter.check(0).is_ok()); + filter.update(0); + + let far_future = REPLAY_WINDOW_BITS as u64 * 2; + assert!(filter.check(far_future).is_ok()); + filter.update(far_future); + + assert!(filter.check(0).is_err()); + assert!(filter.check(1).is_err()); + assert!(filter + .check(far_future - REPLAY_WINDOW_BITS as u64) + .is_err()); + + assert!(filter + .check(far_future - REPLAY_WINDOW_BITS as u64 + 1) + .is_err()); + } + + #[rstest] + #[case(0, 1)] + #[case(0, 10)] + #[case(0, REPLAY_WINDOW_BITS as u64 - 1)] + #[case(100, 100 + REPLAY_WINDOW_BITS as u64 - 1)] + fn test_gap_handling(#[case] start: u64, #[case] end: u64) { + let mut filter = ReplayFilter::new(); + + assert!(filter.check(start).is_ok()); + filter.update(start); + assert!(filter.check(end).is_ok()); + filter.update(end); + + (start..=end).for_each(|i| { + let should_accept = i != start && i != end && end - i < REPLAY_WINDOW_BITS as u64; + let result = filter.check(i); + assert_eq!(result.is_ok(), should_accept); + if should_accept { + filter.update(i); + } + }); + } + + #[rstest] + fn test_bitmap_clearing_on_large_jump() { + let mut filter = ReplayFilter::new(); + + for i in 0..10 { + assert!(filter.check(i).is_ok()); + filter.update(i); + } + + let large_jump = REPLAY_WINDOW_BITS as u64 * 2; + assert!(filter.check(large_jump).is_ok()); + filter.update(large_jump); + + let expected = if (large_jump % REPLAY_WINDOW_BITS as u64) < 64 { + 1u64 << (large_jump % REPLAY_WINDOW_BITS as u64) + } else { + 0 + }; + if let Some(word) = filter.bitmap.iter().find(|&&word| word != 0) { + assert_eq!(*word, expected); + } + } + + #[rstest] + #[case(vec![5, 3, 7, 2, 8, 1, 9, 0, 6, 4])] + #[case(vec![100, 95, 105, 90, 110, 85, 115, 80, 120, 75])] + fn test_complex_out_of_order_sequence(#[case] sequence: Vec) { + let mut filter = ReplayFilter::new(); + let mut processed = vec![]; + + for counter in &sequence { + let result = filter.check(*counter); + assert!(result.is_ok(), "Counter {} should be accepted", counter); + filter.update(*counter); + processed.push(*counter); + + for &prev in &processed { + assert!( + filter.check(prev).is_err(), + "Previously seen counter {} should be rejected", + prev + ); + } + } + } + + #[rstest] + fn test_window_wraparound() { + let mut filter = ReplayFilter::new(); + + for i in 0..REPLAY_WINDOW_BITS * 3 { + let counter = i as u64; + assert!(filter.check(counter).is_ok()); + filter.update(counter); + + let should_check_old = i > 0 + && counter >= REPLAY_WINDOW_BITS as u64 + && (i - 1) as u64 + REPLAY_WINDOW_BITS as u64 <= counter; + + if should_check_old { + let old_counter = (i - 1) as u64; + assert!(filter.check(old_counter).is_err()); + } + } + } + + #[rstest] + #[case(0, REPLAY_WINDOW_BITS as u64)] + #[case(1000, 1000 + REPLAY_WINDOW_BITS as u64)] + fn test_edge_of_window(#[case] start: u64, #[case] boundary: u64) { + let mut filter = ReplayFilter::new(); + + assert!(filter.check(boundary).is_ok()); + filter.update(boundary); + + let should_reject_start = start + REPLAY_WINDOW_BITS as u64 <= boundary; + if should_reject_start { + assert!(filter.check(start).is_err()); + } + + let should_check_within = boundary > 0 && boundary >= REPLAY_WINDOW_BITS as u64; + if should_check_within { + let within_window = boundary - REPLAY_WINDOW_BITS as u64 + 1; + assert!(filter.check(within_window).is_err()); + } + } + + proptest! { + #[test] + fn prop_no_duplicate_acceptance(counters in prop::collection::vec(0u64..10000, 1..100)) { + let mut filter = ReplayFilter::new(); + let mut seen = std::collections::HashSet::new(); + + for counter in counters { + let result = filter.check(counter); + let is_duplicate = seen.contains(&counter); + + if is_duplicate { + prop_assert!(result.is_err(), "Duplicate counter {} should be rejected", counter); + continue; + } + + if result.is_ok() { + filter.update(counter); + seen.insert(counter); + } + + for &prev in &seen { + prop_assert!(filter.check(prev).is_err(), "Previously seen counter {} should be rejected", prev); + } + } + } + + #[test] + fn prop_sequential_always_accepted(start in any::(), len in 1usize..1000) { + let mut filter = ReplayFilter::new(); + + for i in 0..len { + let counter = start.saturating_add(i as u64); + prop_assert!(filter.check(counter).is_ok()); + filter.update(counter); + } + } + + #[test] + fn prop_old_counters_rejected( + current in 1000u64..u64::MAX - REPLAY_WINDOW_BITS as u64, + old_offset in (REPLAY_WINDOW_BITS as u64 + 1)..10000u64 + ) { + let mut filter = ReplayFilter::new(); + + prop_assert!(filter.check(current).is_ok()); + filter.update(current); + + let old_counter = current.saturating_sub(old_offset); + prop_assert!(filter.check(old_counter).is_err()); + } + + #[test] + fn prop_window_consistency(counters in prop::collection::vec(0u64..1000, 10..100)) { + let mut filter = ReplayFilter::new(); + let mut accepted = vec![]; + + for counter in counters { + if filter.check(counter).is_ok() { + filter.update(counter); + accepted.push(counter); + + let max_accepted = *accepted.iter().max().unwrap(); + for &prev in &accepted { + if prev + REPLAY_WINDOW_BITS as u64 > max_accepted { + prop_assert!(filter.check(prev).is_err()); + } + } + } + } + } + + #[test] + fn prop_bitmap_integrity(operations in prop::collection::vec((0u64..REPLAY_WINDOW_BITS as u64 * 2, any::()), 1..200)) { + let mut filter = ReplayFilter::new(); + let mut expected_set = std::collections::HashSet::new(); + + for (counter, should_accept) in operations { + let is_new = !expected_set.contains(&counter); + if !should_accept || !is_new { + continue; + } + + if counter >= REPLAY_WINDOW_BITS as u64 { + expected_set.clear(); + } + let result = filter.check(counter); + if result.is_ok() { + filter.update(counter); + expected_set.insert(counter); + } + } + + for counter in expected_set.iter() { + if *counter + REPLAY_WINDOW_BITS as u64 > filter.next { + prop_assert!(filter.is_set(*counter)); + } + } + } + + #[test] + fn prop_monotonic_next_counter(mut counters in prop::collection::vec(0u64..10000, 1..100)) { + counters.sort_unstable(); + counters.dedup(); + + let mut filter = ReplayFilter::new(); + let mut prev_next = 0u64; + + for counter in counters { + if filter.check(counter).is_ok() { + filter.update(counter); + prop_assert!(filter.next > prev_next); + prop_assert!(filter.next > counter); + prev_next = filter.next; + } + } + } + + #[test] + fn prop_random_sequence_consistency(sequence in prop::collection::vec(0u64..1000, 1..500)) { + let mut filter1 = ReplayFilter::new(); + let mut filter2 = ReplayFilter::new(); + + let mut accepted1 = vec![]; + for &counter in &sequence { + if filter1.check(counter).is_ok() { + filter1.update(counter); + accepted1.push(counter); + } + } + + let mut accepted2 = vec![]; + for &counter in &sequence { + if filter2.check(counter).is_ok() { + filter2.update(counter); + accepted2.push(counter); + } + } + + prop_assert_eq!(accepted1, accepted2); + } + + #[test] + fn prop_large_gap_clears_bitmap( + start in 0u64..1000, + gap_multiplier in 2u64..10 + ) { + let mut filter = ReplayFilter::new(); + + for i in 0..10 { + prop_assert!(filter.check(start + i).is_ok()); + filter.update(start + i); + } + + let large_jump = start + REPLAY_WINDOW_BITS as u64 * gap_multiplier; + prop_assert!(filter.check(large_jump).is_ok()); + filter.update(large_jump); + + for i in 0..10 { + let old_counter = start + i; + prop_assert!(filter.check(old_counter).is_err()); + } + } + } +} diff --git a/monad-wireauth-session/src/responder.rs b/monad-wireauth-session/src/responder.rs new file mode 100644 index 0000000000..ad934faf88 --- /dev/null +++ b/monad-wireauth-session/src/responder.rs @@ -0,0 +1,184 @@ +use std::{ + net::SocketAddr, + ops::{Deref, DerefMut}, + time::Duration, +}; + +use monad_wireauth_protocol::{ + common::*, + handshake::{self}, + messages::{DataPacket, HandshakeInitiation, HandshakeResponse}, +}; + +use crate::{ + common::{add_jitter, Config, SessionError, SessionState, SessionTimeoutResult}, + transport::TransportState, +}; + +pub struct ValidatedHandshakeInit { + pub handshake_state: monad_wireauth_protocol::handshake::HandshakeState, + pub remote_public_key: PublicKey, + pub system_time: std::time::SystemTime, + pub remote_index: SessionIndex, +} + +pub struct ResponderState { + transport: TransportState, +} + +impl Deref for ResponderState { + type Target = SessionState; + + fn deref(&self) -> &Self::Target { + &self.transport + } +} + +impl DerefMut for ResponderState { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.transport + } +} + +impl ResponderState { + pub fn validate_init( + local_static_key: &PrivateKey, + local_static_public: &PublicKey, + handshake_packet: &mut HandshakeInitiation, + ) -> Result { + let (handshake_state, system_time) = handshake::accept_handshake_init( + local_static_key, + &SerializedPublicKey::from(local_static_public), + handshake_packet, + ) + .map_err(SessionError::InvalidHandshake)?; + + let remote_public_key: PublicKey = handshake_state + .remote_static + .as_ref() + .expect("remote static key must be set") + .try_into() + .map_err(|e: monad_wireauth_protocol::errors::CryptoError| { + SessionError::InvalidHandshake(e.into()) + })?; + + let remote_index = handshake_state.receiver_index.into(); + + Ok(ValidatedHandshakeInit { + handshake_state, + remote_public_key, + system_time, + remote_index, + }) + } + + pub fn new( + rng: &mut R, + duration_since_start: Duration, + config: &Config, + local_session_index: SessionIndex, + stored_cookie: Option<&[u8; 16]>, + validated_init: ValidatedHandshakeInit, + remote_addr: SocketAddr, + ) -> Result<(ResponderState, Duration, HandshakeResponse), SessionError> { + let mut handshake_state = validated_init.handshake_state; + let (response_msg, transport_keys) = handshake::send_handshake_response( + rng, + local_session_index.as_u32(), + &mut handshake_state, + &config.psk, + stored_cookie, + ) + .map_err(SessionError::InvalidHandshake)?; + + let response_mac1 = response_msg.mac1.0; + + let mut common = SessionState::new( + remote_addr, + validated_init.remote_public_key, + local_session_index, + duration_since_start, + 0, + Some(validated_init.system_time), + false, + ); + common.last_handshake_mac1 = Some(response_mac1); + + let timeout_with_jitter = + add_jitter(rng, config.session_timeout, config.session_timeout_jitter); + common.reset_session_timeout(duration_since_start, timeout_with_jitter); + + let timer = common + .get_next_deadline() + .expect("expected at least one timer to be set"); + + let transport = TransportState::new( + handshake_state.receiver_index.into(), + transport_keys.send_key, + transport_keys.recv_key, + common, + ); + Ok((ResponderState { transport }, timer, response_msg)) + } + + pub fn decrypt( + &mut self, + config: &Config, + duration_since_start: Duration, + data_packet: DataPacket, + ) -> Result { + self.transport + .decrypt(config, duration_since_start, data_packet) + } + + pub fn establish( + mut self, + rng: &mut R, + config: &Config, + duration_since_start: Duration, + ) -> (TransportState, Duration) { + self.transport + .common + .reset_session_timeout(duration_since_start, config.session_timeout); + self.transport.common.reset_keepalive( + duration_since_start, + add_jitter(rng, config.keepalive_interval, config.keepalive_jitter), + ); + self.transport + .common + .set_max_session_duration(duration_since_start, config.max_session_duration); + + let timer = self + .transport + .common + .get_next_deadline() + .expect("expected at least one timer to be set"); + + (self.transport, timer) + } + + pub fn tick( + &mut self, + duration_since_start: Duration, + ) -> Option<(Option, SessionTimeoutResult)> { + let session_timeout_expired = self + .session_timeout_deadline + .is_some_and(|deadline| deadline <= duration_since_start); + + if !session_timeout_expired { + return None; + } + + self.clear_session_timeout(); + let (terminated, rekey) = self.handle_session_timeout(); + let timer = self.get_next_deadline(); + Some((timer, SessionTimeoutResult { terminated, rekey })) + } + + pub fn handle_cookie( + &mut self, + cookie_reply: &mut monad_wireauth_protocol::messages::CookieReply, + ) -> Result<(), SessionError> { + self.transport.common.handle_cookie(cookie_reply) + } +} diff --git a/monad-wireauth-session/src/transport.rs b/monad-wireauth-session/src/transport.rs new file mode 100644 index 0000000000..99b7c2d11f --- /dev/null +++ b/monad-wireauth-session/src/transport.rs @@ -0,0 +1,206 @@ +use std::{ + ops::{Deref, DerefMut}, + time::Duration, +}; + +use monad_wireauth_protocol::{ + common::{CipherKey, SessionIndex}, + messages::{DataPacket, DataPacketHeader}, +}; +use tracing::debug; + +use crate::{ + common::{Config, MessageEvent, RekeyEvent, SessionError, SessionState, TerminatedEvent}, + replay_filter::ReplayFilter, +}; + +pub struct TransportState { + pub remote_index: SessionIndex, + pub send_key: CipherKey, + pub send_nonce: u64, + pub recv_key: CipherKey, + pub replay_filter: ReplayFilter, + pub common: SessionState, +} + +impl TransportState { + pub fn new( + remote_index: SessionIndex, + send_key: CipherKey, + recv_key: CipherKey, + common: SessionState, + ) -> Self { + TransportState { + remote_index, + send_key, + send_nonce: 0, + recv_key, + replay_filter: ReplayFilter::new(), + common, + } + } + + pub fn encrypt( + &mut self, + config: &Config, + duration_since_start: Duration, + plaintext: &mut [u8], + ) -> (DataPacketHeader, Duration) { + use monad_wireauth_protocol::crypto; + + let header = DataPacketHeader { + receiver_index: self.remote_index.as_u32().into(), + counter: self.send_nonce.into(), + tag: crypto::encrypt_in_place(&self.send_key, &self.send_nonce.into(), plaintext, &[]), + ..Default::default() + }; + + self.send_nonce += 1; + + self.common + .reset_keepalive(duration_since_start, config.keepalive_interval); + let timer = self + .common + .get_next_deadline() + .expect("expected at least one timer to be set"); + (header, timer) + } + + pub fn decrypt( + &mut self, + config: &Config, + duration_since_start: Duration, + data_packet: DataPacket, + ) -> Result { + use monad_wireauth_protocol::crypto; + + self.replay_filter.check(data_packet.header.counter.get())?; + + crypto::decrypt_in_place( + &self.recv_key, + &data_packet.header.counter.get().into(), + data_packet.plaintext, + &data_packet.header.tag, + &[], + ) + .map_err(SessionError::InvalidMac)?; + + self.replay_filter.update(data_packet.header.counter.get()); + + self.common + .reset_session_timeout(duration_since_start, config.session_timeout); + let timer = self + .common + .get_next_deadline() + .expect("expected at least one timer to be set"); + Ok(timer) + } + + #[allow(clippy::type_complexity)] + pub fn tick( + &mut self, + config: &Config, + duration_since_start: Duration, + ) -> ( + Option, + Option, + Option, + Option, + ) { + let mut message = None; + let mut rekey = None; + let mut terminated = None; + + let keepalive_expired = self + .common + .keepalive_deadline + .is_some_and(|deadline| deadline <= duration_since_start); + if keepalive_expired { + self.common.clear_keepalive(); + debug!( + duration_since_start = ?duration_since_start, + remote_addr = ?self.common.remote_addr, + "sending keepalive packet" + ); + let (header, _) = self.encrypt(config, duration_since_start, &mut []); + message = Some(MessageEvent { + remote_addr: self.common.remote_addr, + header, + }); + } + + let rekey_expired = self + .common + .rekey_deadline + .is_some_and(|deadline| deadline <= duration_since_start); + if rekey_expired { + self.common.clear_rekey(); + debug!( + remote_addr = ?self.common.remote_addr, + "rekey timer expired" + ); + rekey = Some(RekeyEvent { + remote_public_key: self.common.remote_public_key.clone(), + remote_addr: self.common.remote_addr, + retry_attempts: self.common.retry_attempts, + stored_cookie: self.common.stored_cookie, + }); + } + + let session_timeout_expired = self + .common + .session_timeout_deadline + .is_some_and(|deadline| deadline <= duration_since_start); + if session_timeout_expired { + self.common.clear_session_timeout(); + + debug!( + remote_addr = ?self.common.remote_addr, + "session timeout expired" + ); + + let (terminated_event, rekey_event) = self.common.handle_session_timeout(); + terminated = Some(terminated_event); + rekey = rekey.or(rekey_event); + } + + let max_session_duration_expired = self + .common + .max_session_duration_deadline + .is_some_and(|deadline| deadline <= duration_since_start); + if max_session_duration_expired { + self.common.clear_max_session_duration(); + + debug!( + remote_addr = ?self.common.remote_addr, + "max session duration expired" + ); + + let (terminated_event, _) = self.common.handle_session_timeout(); + terminated = Some(terminated_event); + rekey = None; + } + + let next_timer = self.common.get_next_deadline(); + (next_timer, message, rekey, terminated) + } + + #[cfg(any(test, feature = "bench"))] + pub fn reset_replay_filter(&mut self) { + self.replay_filter = ReplayFilter::new(); + } +} + +impl Deref for TransportState { + type Target = SessionState; + + fn deref(&self) -> &Self::Target { + &self.common + } +} + +impl DerefMut for TransportState { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.common + } +} diff --git a/monad-wireauth-session/tests/tests.rs b/monad-wireauth-session/tests/tests.rs new file mode 100644 index 0000000000..d0bff195b6 --- /dev/null +++ b/monad-wireauth-session/tests/tests.rs @@ -0,0 +1,699 @@ +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::{Duration, SystemTime}, +}; + +use monad_wireauth_protocol::{ + common::{PrivateKey, PublicKey, SerializedPublicKey, SessionIndex}, + cookies, + messages::DataPacket, +}; +use monad_wireauth_session::*; +use secp256k1::rand::{rngs::StdRng, SeedableRng}; + +struct TestEnv { + rng: StdRng, + time: Duration, + system_time: SystemTime, + config: Config, +} + +impl TestEnv { + fn new() -> Self { + Self { + rng: StdRng::seed_from_u64(42), + time: Duration::ZERO, + system_time: SystemTime::UNIX_EPOCH, + config: Config { + session_timeout: Duration::from_secs(10), + session_timeout_jitter: Duration::ZERO, + keepalive_interval: Duration::from_secs(3), + keepalive_jitter: Duration::ZERO, + rekey_interval: Duration::from_secs(60), + rekey_jitter: Duration::ZERO, + max_session_duration: Duration::from_secs(70), + ..Default::default() + }, + } + } + + fn advance(&mut self, duration: Duration) { + self.time += duration; + self.system_time += duration; + } +} + +struct Peer { + static_key: PrivateKey, + static_public: PublicKey, + addr: SocketAddr, + session_index: SessionIndex, +} + +impl Peer { + fn new(port: u16, index: u32) -> Self { + let mut rng = StdRng::seed_from_u64(port as u64); + let (static_public, static_key) = + monad_wireauth_protocol::crypto::generate_keypair(&mut rng).unwrap(); + Self { + static_key, + static_public, + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port), + session_index: SessionIndex::new(index), + } + } +} + +//1. alice initiates handshake +//2. bob validates initiation +//3. bob sends response +//4. alice validates response +//5. alice establishes transport +//6. bob establishes transport +//7. alice encrypts data +//8. bob decrypts data +#[test] +fn test_handshake_and_data_exchange() { + let mut env = TestEnv::new(); + let alice = Peer::new(8001, 1); + let bob = Peer::new(8002, 2); + + let (mut initiator, (_timer, mut init_msg)) = InitiatorState::new( + &mut env.rng, + env.system_time, + env.time, + &env.config, + alice.session_index, + &alice.static_key, + alice.static_public.clone(), + bob.static_public.clone(), + bob.addr, + None, + DEFAULT_RETRY_ATTEMPTS, + ) + .unwrap(); + + let validated_init = + ResponderState::validate_init(&bob.static_key, &bob.static_public, &mut init_msg).unwrap(); + + let (responder, _timer, mut resp_msg) = ResponderState::new( + &mut env.rng, + env.time, + &env.config, + bob.session_index, + None, + validated_init, + alice.addr, + ) + .unwrap(); + + let validated_resp = initiator + .validate_response( + &env.config, + &alice.static_key, + &alice.static_public, + &mut resp_msg, + ) + .unwrap(); + + let (mut initiator_transport, _timer, _header) = initiator.establish( + &mut env.rng, + &env.config, + env.time, + validated_resp, + bob.addr, + ); + + let (mut responder_transport, _timer) = + responder.establish(&mut env.rng, &env.config, env.time); + + let mut plaintext = b"hello world".to_vec(); + let (header, _timer) = initiator_transport.encrypt(&env.config, env.time, &mut plaintext); + + let data_packet = DataPacket { + header: &header, + plaintext: &mut plaintext, + }; + + let _timer = responder_transport + .decrypt(&env.config, env.time, data_packet) + .unwrap(); + assert_eq!(&plaintext, b"hello world"); +} + +//1. alice initiates handshake +//2. bob validates initiation +//3. bob sends response +//4. alice validates response +//5. alice establishes transport +//6. bob establishes transport +//7. advance time by keepalive interval -> bob tick triggers first keepalive +//9. advance time by keepalive interval -> bob tick triggers second keepalive +#[test] +fn test_keepalive_sends_twice() { + let mut env = TestEnv::new(); + let alice = Peer::new(8001, 1); + let bob = Peer::new(8002, 2); + + let (mut initiator, (_timer, mut init_msg)) = InitiatorState::new( + &mut env.rng, + env.system_time, + env.time, + &env.config, + alice.session_index, + &alice.static_key, + alice.static_public.clone(), + bob.static_public.clone(), + bob.addr, + None, + DEFAULT_RETRY_ATTEMPTS, + ) + .unwrap(); + + let validated_init = + ResponderState::validate_init(&bob.static_key, &bob.static_public, &mut init_msg).unwrap(); + + let (responder, _timer, mut resp_msg) = ResponderState::new( + &mut env.rng, + env.time, + &env.config, + bob.session_index, + None, + validated_init, + alice.addr, + ) + .unwrap(); + + let validated_resp = initiator + .validate_response( + &env.config, + &alice.static_key, + &alice.static_public, + &mut resp_msg, + ) + .unwrap(); + + let (_initiator_transport, _timer, _header) = initiator.establish( + &mut env.rng, + &env.config, + env.time, + validated_resp, + bob.addr, + ); + + let (mut responder_transport, _timer) = + responder.establish(&mut env.rng, &env.config, env.time); + + env.advance(env.config.keepalive_interval); + + let result = responder_transport.tick(&env.config, env.time); + assert!(result.1.is_some()); + let message_event = result.1.unwrap(); + assert_eq!(message_event.remote_addr, alice.addr); + + env.advance(env.config.keepalive_interval); + + let result = responder_transport.tick(&env.config, env.time); + assert!(result.1.is_some()); + let message_event = result.1.unwrap(); + assert_eq!(message_event.remote_addr, alice.addr); +} + +//1. alice initiates handshake +//2. bob validates initiation +//3. bob sends response +//4. alice validates response +//5. alice establishes transport +//6. bob establishes transport +//7. advance time by keepalive interval -> bob tick triggers keepalive +//8. advance time by session timeout -> alice tick triggers session timeout and rekey +#[test] +fn test_session_timeout_triggers_rekey_after_keepalive() { + let mut env = TestEnv::new(); + let alice = Peer::new(8001, 1); + let bob = Peer::new(8002, 2); + + let (mut initiator, (_timer, mut init_msg)) = InitiatorState::new( + &mut env.rng, + env.system_time, + env.time, + &env.config, + alice.session_index, + &alice.static_key, + alice.static_public.clone(), + bob.static_public.clone(), + bob.addr, + None, + DEFAULT_RETRY_ATTEMPTS, + ) + .unwrap(); + + let validated_init = + ResponderState::validate_init(&bob.static_key, &bob.static_public, &mut init_msg).unwrap(); + + let (responder, _timer, mut resp_msg) = ResponderState::new( + &mut env.rng, + env.time, + &env.config, + bob.session_index, + None, + validated_init, + alice.addr, + ) + .unwrap(); + + let validated_resp = initiator + .validate_response( + &env.config, + &alice.static_key, + &alice.static_public, + &mut resp_msg, + ) + .unwrap(); + + let (mut initiator_transport, _timer, _header) = initiator.establish( + &mut env.rng, + &env.config, + env.time, + validated_resp, + bob.addr, + ); + + let (mut responder_transport, _timer) = + responder.establish(&mut env.rng, &env.config, env.time); + + env.advance(env.config.keepalive_interval); + + let result = responder_transport.tick(&env.config, env.time); + assert!(result.1.is_some()); + + env.advance(env.config.session_timeout); + + let result = initiator_transport.tick(&env.config, env.time); + assert!(result.2.is_some()); + let rekey_event = result.2.unwrap(); + assert_eq!(rekey_event.remote_addr, bob.addr); +} + +//1. alice initiates handshake +//2. advance time by session timeout -> alice tick triggers rekey with retry attempts decremented +#[test] +fn test_initiator_timeout_triggers_rekey() { + let mut env = TestEnv::new(); + let alice = Peer::new(8001, 1); + let bob = Peer::new(8002, 2); + + let (mut initiator, (_timer, _init_msg)) = InitiatorState::new( + &mut env.rng, + env.system_time, + env.time, + &env.config, + alice.session_index, + &alice.static_key, + alice.static_public.clone(), + bob.static_public.clone(), + bob.addr, + None, + DEFAULT_RETRY_ATTEMPTS, + ) + .unwrap(); + + env.advance(env.config.session_timeout); + + let result = initiator.tick(env.time).unwrap(); + assert!(result.1.rekey.is_some()); + let rekey_event = result.1.rekey.unwrap(); + assert_eq!(rekey_event.remote_addr, bob.addr); + assert_eq!(rekey_event.retry_attempts, DEFAULT_RETRY_ATTEMPTS - 1); +} + +//1. alice initiates handshake +//2. bob validates initiation and creates responder +//3. advance time by session timeout -> bob tick triggers termination without rekey +#[test] +fn test_responder_timeout_terminates_session() { + let mut env = TestEnv::new(); + let alice = Peer::new(8001, 1); + let bob = Peer::new(8002, 2); + + let (_initiator, (_timer, mut init_msg)) = InitiatorState::new( + &mut env.rng, + env.system_time, + env.time, + &env.config, + alice.session_index, + &alice.static_key, + alice.static_public.clone(), + bob.static_public.clone(), + bob.addr, + None, + DEFAULT_RETRY_ATTEMPTS, + ) + .unwrap(); + + let validated_init = + ResponderState::validate_init(&bob.static_key, &bob.static_public, &mut init_msg).unwrap(); + + let (mut responder, _timer, _resp_msg) = ResponderState::new( + &mut env.rng, + env.time, + &env.config, + bob.session_index, + None, + validated_init, + alice.addr, + ) + .unwrap(); + + env.advance(env.config.session_timeout); + + let result = responder.tick(env.time).unwrap(); + assert!(result.1.rekey.is_none()); + assert_eq!(result.1.terminated.remote_addr, alice.addr); +} + +//1. alice initiates handshake +//2. alice receives and stores cookie reply +//3. advance time by session timeout -> alice tick triggers rekey with stored cookie +//4. alice re-initiates with stored cookie +//5. bob validates initiation and sends response +//6. alice validates response and establishes transport +#[test] +fn test_cookie_stored_and_reused_after_timeout() { + let mut env = TestEnv::new(); + let alice = Peer::new(8001, 1); + let bob = Peer::new(8002, 2); + + let (mut initiator, (_timer, init_msg)) = InitiatorState::new( + &mut env.rng, + env.system_time, + env.time, + &env.config, + alice.session_index, + &alice.static_key, + alice.static_public.clone(), + bob.static_public.clone(), + bob.addr, + None, + DEFAULT_RETRY_ATTEMPTS, + ) + .unwrap(); + + let mac1 = init_msg.mac1.0; + let cookie_secret = [0u8; 32]; + let nonce = 0u64; + let cookie = cookies::generate_cookie(&cookie_secret, nonce, &alice.addr); + + let nonce_secret = [0u8; 32]; + let nonce_counter = 0u128; + let mut cookie_reply = cookies::send_cookie_reply( + &nonce_secret, + nonce_counter, + &SerializedPublicKey::from(&bob.static_public), + alice.session_index.as_u32(), + &mac1, + &cookie, + ) + .unwrap(); + + initiator.handle_cookie(&mut cookie_reply).unwrap(); + let stored_cookie = initiator.stored_cookie().unwrap(); + + env.advance(env.config.session_timeout); + + let result = initiator.tick(env.time).unwrap(); + assert!(result.1.rekey.is_some()); + let rekey_event = result.1.rekey.unwrap(); + assert_eq!(rekey_event.stored_cookie, Some(stored_cookie)); + + let (mut initiator2, (_timer, mut init_msg2)) = InitiatorState::new( + &mut env.rng, + env.system_time, + env.time, + &env.config, + alice.session_index, + &alice.static_key, + alice.static_public.clone(), + bob.static_public.clone(), + bob.addr, + Some(stored_cookie), + DEFAULT_RETRY_ATTEMPTS - 1, + ) + .unwrap(); + + let validated_init = + ResponderState::validate_init(&bob.static_key, &bob.static_public, &mut init_msg2).unwrap(); + + let (_responder, _timer, mut resp_msg) = ResponderState::new( + &mut env.rng, + env.time, + &env.config, + bob.session_index, + None, + validated_init, + alice.addr, + ) + .unwrap(); + + let validated_resp = initiator2 + .validate_response( + &env.config, + &alice.static_key, + &alice.static_public, + &mut resp_msg, + ) + .unwrap(); + + let (_initiator_transport, _timer, _header) = initiator2.establish( + &mut env.rng, + &env.config, + env.time, + validated_resp, + bob.addr, + ); +} + +//1. alice initiates handshake with cookie +//2. bob creates responder with stored cookie +//3. verify response has non-zero mac2 +#[test] +fn test_rekey_without_cookie_receives_cookie_reply() { + let mut env = TestEnv::new(); + let alice = Peer::new(8001, 1); + let bob = Peer::new(8002, 2); + + let cookie_secret = [0u8; 32]; + let nonce = 0u64; + let cookie = cookies::generate_cookie(&cookie_secret, nonce, &alice.addr); + + let (_initiator, (_timer, mut init_msg)) = InitiatorState::new( + &mut env.rng, + env.system_time, + env.time, + &env.config, + alice.session_index, + &alice.static_key, + alice.static_public.clone(), + bob.static_public.clone(), + bob.addr, + Some(cookie), + DEFAULT_RETRY_ATTEMPTS, + ) + .unwrap(); + + let validated_init = + ResponderState::validate_init(&bob.static_key, &bob.static_public, &mut init_msg).unwrap(); + + let (_responder, _timer, resp_msg) = ResponderState::new( + &mut env.rng, + env.time, + &env.config, + bob.session_index, + Some(&cookie), + validated_init, + alice.addr, + ) + .unwrap(); + + assert_ne!(resp_msg.mac2.0, [0u8; 16]); +} + +//1. alice initiates handshake with 0 retries +//2. bob validates initiation and sends response +//3. alice validates response and establishes transport +//4. bob establishes transport +//5. advance time to trigger multiple keepalives +//6. advance time to rekey interval -> alice tick triggers rekey without terminate +#[test] +fn test_rekey_interval_with_zero_retries() { + let mut env = TestEnv::new(); + let alice = Peer::new(8001, 1); + let bob = Peer::new(8002, 2); + + let (mut initiator, (_timer, mut init_msg)) = InitiatorState::new( + &mut env.rng, + env.system_time, + env.time, + &env.config, + alice.session_index, + &alice.static_key, + alice.static_public.clone(), + bob.static_public.clone(), + bob.addr, + None, + 0, + ) + .unwrap(); + + let validated_init = + ResponderState::validate_init(&bob.static_key, &bob.static_public, &mut init_msg).unwrap(); + + let (responder, _timer, mut resp_msg) = ResponderState::new( + &mut env.rng, + env.time, + &env.config, + bob.session_index, + None, + validated_init, + alice.addr, + ) + .unwrap(); + + let validated_resp = initiator + .validate_response( + &env.config, + &alice.static_key, + &alice.static_public, + &mut resp_msg, + ) + .unwrap(); + + let (mut initiator_transport, _timer, _header) = initiator.establish( + &mut env.rng, + &env.config, + env.time, + validated_resp, + bob.addr, + ); + + let (mut responder_transport, _timer) = + responder.establish(&mut env.rng, &env.config, env.time); + + let keepalive_count = + (env.config.rekey_interval.as_secs() / env.config.keepalive_interval.as_secs()) as usize; + for _ in 0..keepalive_count { + env.advance(env.config.keepalive_interval); + let result = responder_transport.tick(&env.config, env.time); + assert!(result.1.is_some()); + let message_event = result.1.unwrap(); + let mut plaintext = vec![]; + let data_packet = DataPacket { + header: &message_event.header, + plaintext: &mut plaintext, + }; + let _timer = initiator_transport + .decrypt(&env.config, env.time, data_packet) + .unwrap(); + } + + let result = initiator_transport.tick(&env.config, env.time); + assert!(result.2.is_some()); + assert!(result.3.is_none()); + let rekey_event = result.2.unwrap(); + assert_eq!(rekey_event.remote_addr, bob.addr); + assert_eq!(rekey_event.retry_attempts, 0); +} + +//1. alice initiates handshake +//2. bob validates initiation and sends response +//3. alice validates response and establishes transport +//4. bob establishes transport +//5. advance time to rekey interval -> alice tick triggers rekey +//6. send keepalives during post-rekey period +//7. advance time to max session duration -> alice tick triggers termination despite keapalives +#[test] +fn test_max_session_duration_terminates_after_rekey() { + let mut env = TestEnv::new(); + let alice = Peer::new(8001, 1); + let bob = Peer::new(8002, 2); + + let (mut initiator, (_timer, mut init_msg)) = InitiatorState::new( + &mut env.rng, + env.system_time, + env.time, + &env.config, + alice.session_index, + &alice.static_key, + alice.static_public.clone(), + bob.static_public.clone(), + bob.addr, + None, + DEFAULT_RETRY_ATTEMPTS, + ) + .unwrap(); + + let validated_init = + ResponderState::validate_init(&bob.static_key, &bob.static_public, &mut init_msg).unwrap(); + + let (responder, _timer, mut resp_msg) = ResponderState::new( + &mut env.rng, + env.time, + &env.config, + bob.session_index, + None, + validated_init, + alice.addr, + ) + .unwrap(); + + let validated_resp = initiator + .validate_response( + &env.config, + &alice.static_key, + &alice.static_public, + &mut resp_msg, + ) + .unwrap(); + + let (mut initiator_transport, _timer, _header) = initiator.establish( + &mut env.rng, + &env.config, + env.time, + validated_resp, + bob.addr, + ); + + let (mut responder_transport, _timer) = + responder.establish(&mut env.rng, &env.config, env.time); + + env.advance(env.config.rekey_interval); + + let result = initiator_transport.tick(&env.config, env.time); + assert!(result.2.is_some()); + + let remaining_time = env.config.max_session_duration - env.config.rekey_interval; + let keepalive_count = remaining_time.as_secs() / env.config.keepalive_interval.as_secs(); + + for _ in 0..keepalive_count { + env.advance(env.config.keepalive_interval); + let result = responder_transport.tick(&env.config, env.time); + assert!(result.1.is_some()); + let message_event = result.1.unwrap(); + let mut plaintext = vec![]; + let data_packet = DataPacket { + header: &message_event.header, + plaintext: &mut plaintext, + }; + let _timer = initiator_transport + .decrypt(&env.config, env.time, data_packet) + .unwrap(); + } + + env.advance(Duration::from_secs(1)); + + let result = initiator_transport.tick(&env.config, env.time); + assert!(result.2.is_none()); + assert!(result.3.is_some()); + let terminated_event = result.3.unwrap(); + assert_eq!(terminated_event.remote_addr, bob.addr); +}