diff --git a/Cargo.lock b/Cargo.lock index 982eb95..e18c80f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -207,7 +207,7 @@ dependencies = [ "rand 0.8.5", "secp256k1", "sha2", - "thiserror", + "thiserror 1.0.69", "x25519-dalek", ] @@ -228,7 +228,7 @@ dependencies = [ "itertools 0.11.0", "paste", "ssh-key", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -253,7 +253,7 @@ checksum = "5ae105d819b82cb8a68e7c3a9c63a2080401f6f18be48e5fc0095a2e1db274e4" dependencies = [ "bc-crypto", "bc-rand", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -264,7 +264,7 @@ checksum = "0eced5ba4f1321a74faac8a346d1d527a76ecec9fda7d746da22a43e923201a3" dependencies = [ "anyhow", "dcbor", - "thiserror", + "thiserror 1.0.69", "ur", ] @@ -355,8 +355,10 @@ dependencies = [ name = "btp" version = "0.1.0" dependencies = [ + "bytemuck", "consts", "rand 0.9.1", + "thiserror 2.0.12", ] [[package]] @@ -376,6 +378,20 @@ name = "bytemuck" version = "1.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6b1fc10dbac614ebc03540c9dbd60e83887fda27794998c6528f1782047d540" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ecc273b49b3205b83d648f0690daa588925572cc5063745bfe547fe7ec8e1a1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.98", +] [[package]] name = "byteorder" @@ -601,7 +617,7 @@ dependencies = [ "chrono", "half", "hex", - "thiserror", + "thiserror 1.0.69", "unicode-normalization", ] @@ -857,7 +873,7 @@ dependencies = [ "gstp", "indoc", "nu-ansi-term", - "thiserror", + "thiserror 1.0.69", "tokio", ] @@ -2317,7 +2333,7 @@ checksum = "c2bb82d84810c92ae464197c90f38f2bad403032d5a6b55ec746ce06d19eb867" dependencies = [ "bc-rand", "bc-shamir", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -2377,7 +2393,16 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +dependencies = [ + "thiserror-impl 2.0.12", ] [[package]] @@ -2391,6 +2416,17 @@ dependencies = [ "syn 2.0.98", ] +[[package]] +name = "thiserror-impl" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.98", +] + [[package]] name = "threadpool" version = "1.8.1" diff --git a/abstracted/src/abstracted/abstract_bluetooth.rs b/abstracted/src/abstracted/abstract_bluetooth.rs index f8549db..dd53006 100644 --- a/abstracted/src/abstracted/abstract_bluetooth.rs +++ b/abstracted/src/abstracted/abstract_bluetooth.rs @@ -25,7 +25,7 @@ pub trait AbstractBluetoothChannel { let cbor = envelope.to_cbor_data(); for chunk in chunk(&cbor) { - self.send(chunk).await.expect("couldn't send"); + self.send(chunk).await?; } Ok(()) @@ -33,14 +33,9 @@ pub trait AbstractBluetoothChannel { async fn receive_envelope(&self, timeout: Duration) -> Result { let mut unchunker = Dechunker::new(); - loop { + while !unchunker.is_complete() { let bytes = self.receive(timeout).await?; - println!("Received {} bytes over BLE", bytes.len()); unchunker.receive(&bytes)?; - - if unchunker.is_complete() { - break; - } } let message = unchunker.data().expect("data is complete"); diff --git a/btp/Cargo.toml b/btp/Cargo.toml index 46de867..1d0a7c3 100644 --- a/btp/Cargo.toml +++ b/btp/Cargo.toml @@ -5,5 +5,9 @@ edition = "2021" homepage.workspace = true [dependencies] -consts = { git = "https://github.com/Foundation-Devices/prime-ble-firmware.git", features = ["dle"] } +bytemuck = { version = "1", features = 
["derive"] }
+consts = { git = "https://github.com/Foundation-Devices/prime-ble-firmware.git", features = [
+    "dle",
+] }
 rand = { workspace = true }
+thiserror = "2"
diff --git a/btp/src/chunk.rs b/btp/src/chunk.rs
new file mode 100644
index 0000000..a411db8
--- /dev/null
+++ b/btp/src/chunk.rs
@@ -0,0 +1,51 @@
+use consts::APP_MTU;
+
+use crate::{Header, CHUNK_DATA_SIZE, HEADER_SIZE};
+
+pub struct Chunker<'a> {
+    data: &'a [u8],
+    message_id: u16,
+    current_index: u16,
+    total_chunks: u16,
+}
+
+impl<'a> Iterator for Chunker<'a> {
+    type Item = [u8; APP_MTU];
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let start_idx = self.current_index as usize * CHUNK_DATA_SIZE;
+        if start_idx >= self.data.len() {
+            return None;
+        }
+
+        let end_idx = (start_idx + CHUNK_DATA_SIZE).min(self.data.len());
+        let chunk_data = &self.data[start_idx..end_idx];
+
+        let header = Header::new(
+            self.message_id,
+            self.current_index,
+            self.total_chunks,
+            chunk_data.len() as u8,
+        );
+
+        let mut buffer = [0u8; APP_MTU];
+        buffer[..HEADER_SIZE].copy_from_slice(header.as_bytes());
+        buffer[HEADER_SIZE..HEADER_SIZE + chunk_data.len()].copy_from_slice(chunk_data);
+        self.current_index += 1;
+
+        Some(buffer)
+    }
+}
+
+/// Splits data into chunks for transmission
+pub fn chunk(data: &[u8]) -> Chunker<'_> {
+    let message_id = rand::Rng::random::<u16>(&mut rand::rng());
+    let total_chunks = data.len().div_ceil(CHUNK_DATA_SIZE) as u16;
+
+    Chunker {
+        data,
+        message_id,
+        current_index: 0,
+        total_chunks,
+    }
+}
diff --git a/btp/src/dechunk.rs b/btp/src/dechunk.rs
new file mode 100644
index 0000000..0af963d
--- /dev/null
+++ b/btp/src/dechunk.rs
@@ -0,0 +1,209 @@
+use crate::{Header, CHUNK_DATA_SIZE, HEADER_SIZE};
+
+#[derive(Debug, thiserror::Error)]
+pub enum DecodeError {
+    #[error("chunk too small, expected at least {}", HEADER_SIZE)]
+    HeaderTooSmall,
+
+    #[error("invalid chunk header")]
+    InvalidHeader,
+
+    #[error("chunk data too small: header claims {expected} bytes, but only {actual} available")]
+
    ChunkTooSmall { expected: usize, actual: usize },
+
+    #[error("invalid chunk index: {index} >= {total_chunks}")]
+    InvalidChunkIndex { index: u16, total_chunks: u16 },
+
+    #[error("chunk data length exceeds maximum chunk size {}", CHUNK_DATA_SIZE)]
+    ChunkTooLarge,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("wrong message id: expected {expected}, actual {actual}")]
+pub struct MessageIdError {
+    expected: u16,
+    actual: u16,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum ReceiveError {
+    #[error(transparent)]
+    Decode(#[from] DecodeError),
+    #[error(transparent)]
+    MessageId(#[from] MessageIdError),
+}
+
+#[derive(Clone, Copy)]
+pub struct Chunk {
+    pub header: Header,
+    pub chunk: [u8; CHUNK_DATA_SIZE],
+}
+
+impl Chunk {
+    /// Returns chunk data as slice
+    pub fn as_slice(&self) -> &[u8] {
+        &self.chunk[..self.header.data_len as usize]
+    }
+
+    /// Parses raw bytes into a chunk
+    pub fn parse(data: &[u8]) -> Result<Self, DecodeError> {
+        let (header_data, chunk_data) = data
+            .split_at_checked(HEADER_SIZE)
+            .ok_or(DecodeError::HeaderTooSmall)?;
+
+        let header = Header::from_bytes(header_data).ok_or(DecodeError::InvalidHeader)?;
+
+        if header.index >= header.total_chunks {
+            return Err(DecodeError::InvalidChunkIndex {
+                index: header.index,
+                total_chunks: header.total_chunks,
+            });
+        }
+
+        let data_len = header.data_len as usize;
+
+        if data_len > CHUNK_DATA_SIZE {
+            return Err(DecodeError::ChunkTooLarge);
+        }
+
+        if chunk_data.len() < data_len {
+            return Err(DecodeError::ChunkTooSmall {
+                expected: data_len,
+                actual: chunk_data.len(),
+            });
+        }
+
+        let mut chunk = [0u8; CHUNK_DATA_SIZE];
+        chunk[..data_len].copy_from_slice(&chunk_data[..data_len]);
+
+        Ok(Chunk {
+            header: *header,
+            chunk,
+        })
+    }
+}
+
+#[derive(Debug, Default)]
+pub struct Dechunker {
+    chunks: Vec<Option<RawChunk>>,
+    info: Option<MessageInfo>,
+}
+
+#[derive(Debug, Clone, Copy)]
+struct MessageInfo {
+    message_id: u16,
+    total_chunks: u16,
+    chunks_received: u16,
+}
+
+#[derive(Debug, Clone)]
+struct RawChunk {
+    data: [u8; 
CHUNK_DATA_SIZE], + len: u8, +} + +impl RawChunk { + fn as_slice(&self) -> &[u8] { + &self.data[..self.len as usize] + } +} + +impl Dechunker { + /// Creates a new dechunker + pub fn new() -> Self { + Self::default() + } + + /// Returns true if all chunks received + pub fn is_complete(&self) -> bool { + self.info + .map(|info| info.chunks_received == info.total_chunks) + .unwrap_or(false) + } + + /// Clears all chunks and resets state + pub fn clear(&mut self) { + self.chunks.clear(); + self.info = None; + } + + /// Returns progress as fraction (0.0 to 1.0) + pub fn progress(&self) -> f32 { + self.info + .map(|info| info.chunks_received as f32 / info.total_chunks as f32) + .unwrap_or(0.0) + } + + /// Inserts a parsed chunk. Use this for multiple concurrent messages. + /// First parse with [`Chunk::parse()`], lookup decoder by message ID, then insert. + pub fn insert_chunk(&mut self, chunk: Chunk) -> Result<(), MessageIdError> { + let header = &chunk.header; + + match self.info { + None => { + self.info = Some(MessageInfo { + message_id: header.message_id, + total_chunks: header.total_chunks, + chunks_received: 0, + }); + self.chunks.resize(header.total_chunks as usize, None); + } + Some(info) if info.message_id != header.message_id => { + return Err(MessageIdError { + expected: info.message_id, + actual: header.message_id, + }); + } + _ => {} + } + + // store chunk if not already received + if self.chunks[header.index as usize].is_none() { + self.chunks[header.index as usize] = Some(RawChunk { + len: header.data_len, + data: chunk.chunk, + }); + + if let Some(ref mut info) = self.info { + info.chunks_received += 1; + } + } + + Ok(()) + } + + /// Parses and inserts raw chunk data. Use this for single message at a time. + /// For multiple concurrent messages, use [`Chunk::parse()`] then [`Dechunker::insert_chunk()`]. 
+    pub fn receive(&mut self, data: &[u8]) -> Result<(), ReceiveError> {
+        let chunk_with_header = Chunk::parse(data)?;
+        self.insert_chunk(chunk_with_header)?;
+        Ok(())
+    }
+
+    /// Returns the message ID if we've received a chunk
+    pub fn message_id(&self) -> Option<u16> {
+        self.info.map(|info| info.message_id)
+    }
+
+    /// Returns reassembled data if complete
+    pub fn data(&self) -> Option<Vec<u8>> {
+        if !self.is_complete() {
+            return None;
+        }
+
+        // unwraps are now ok
+
+        let mut result = Vec::with_capacity(
+            self.chunks
+                .iter()
+                .map(|chunk| chunk.as_ref().unwrap().len as usize)
+                .sum(),
+        );
+
+        for chunk in &self.chunks {
+            result.extend_from_slice(chunk.as_ref().unwrap().as_slice());
+        }
+
+        Some(result)
+    }
+}
diff --git a/btp/src/lib.rs b/btp/src/lib.rs
index 584592f..c7698f7 100644
--- a/btp/src/lib.rs
+++ b/btp/src/lib.rs
@@ -1,252 +1,46 @@
-use consts::APP_MTU;
-use rand::Rng;
+pub use chunk::*;
+pub use dechunk::*;
+mod chunk;
+mod dechunk;
 #[cfg(test)]
 mod tests;
-const HEADER_SIZE: usize = std::mem::size_of::
<Header>();
+use bytemuck::{Pod, Zeroable};
+use consts::APP_MTU;
+
+pub const HEADER_SIZE: usize = std::mem::size_of::<Header>
(); +pub const CHUNK_DATA_SIZE: usize = APP_MTU - HEADER_SIZE; -#[derive(Debug, Clone, Copy)] -struct Header { - message_id: u16, - index: u16, - total_chunks: u16, - data_len: u8, - is_last: bool, +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +#[repr(C)] +pub struct Header { + pub message_id: u16, + pub index: u16, + pub total_chunks: u16, + pub data_len: u8, + pub _padding: u8, } impl Header { - fn to_bytes(self) -> [u8; HEADER_SIZE] { - let mut bytes = [0; HEADER_SIZE]; - bytes[0..2].copy_from_slice(&self.message_id.to_be_bytes()); - bytes[2..4].copy_from_slice(&self.index.to_be_bytes()); - bytes[4..6].copy_from_slice(&self.total_chunks.to_be_bytes()); - bytes[6] = self.data_len; - bytes[7] = if self.is_last { 1 } else { 0 }; - bytes - } - - fn from_bytes(bytes: &[u8]) -> Option { - if bytes.len() < HEADER_SIZE { - return None; - } - let message_id = u16::from_be_bytes([bytes[0], bytes[1]]); - let index = u16::from_be_bytes([bytes[2], bytes[3]]); - let total_chunks = u16::from_be_bytes([bytes[4], bytes[5]]); - let data_len = bytes[6]; - let is_last = bytes[7] != 0; - Some(Self { + #[inline] + fn new(message_id: u16, index: u16, total_chunks: u16, data_len: u8) -> Self { + Self { message_id, index, total_chunks, data_len, - is_last, - }) - } -} - -pub struct Chunker<'a> { - data: &'a [u8], - message_id: u16, - current_index: u16, - total_chunks: u16, - data_per_chunk: usize, -} - -impl<'a> Iterator for Chunker<'a> { - type Item = [u8; APP_MTU]; - - fn next(&mut self) -> Option { - let start_idx = self.current_index as usize * self.data_per_chunk; - if start_idx >= self.data.len() { - return None; - } - - let mut buffer = [0u8; APP_MTU]; - - let end_idx = (start_idx + self.data_per_chunk).min(self.data.len()); - let chunk_data = &self.data[start_idx..end_idx]; - let is_last = end_idx >= self.data.len(); - - let header = Header { - message_id: self.message_id, - index: self.current_index, - total_chunks: self.total_chunks, - data_len: chunk_data.len() as u8, - 
is_last, - }; - - buffer[..HEADER_SIZE].copy_from_slice(&header.to_bytes()); - buffer[HEADER_SIZE..HEADER_SIZE + chunk_data.len()].copy_from_slice(chunk_data); - self.current_index += 1; - - Some(buffer) - } -} - -pub fn chunk(data: &[u8]) -> Chunker<'_> { - let message_id = rand::rng().random::(); - let data_per_chunk = APP_MTU - HEADER_SIZE; - let total_chunks = data.len().div_ceil(data_per_chunk) as u16; - - Chunker { - data, - message_id, - current_index: 0, - total_chunks, - data_per_chunk, - } -} - -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum DecodeError { - PacketTooSmall { size: usize }, - InvalidHeader, - WrongMessageId { expected: u16, received: u16 }, -} - -impl std::fmt::Display for DecodeError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - DecodeError::PacketTooSmall { size } => write!(f, "Packet too small: {size} bytes"), - DecodeError::InvalidHeader => write!(f, "Invalid header"), - DecodeError::WrongMessageId { expected, received } => { - write!( - f, - "Wrong message ID: expected {expected}, received {received}" - ) - } - } - } -} - -impl std::error::Error for DecodeError {} - -const CHUNK_DATA_SIZE: usize = APP_MTU - HEADER_SIZE; - -#[derive(Clone, Copy)] -struct Chunk { - data: [u8; CHUNK_DATA_SIZE], - len: u8, -} - -impl Chunk { - fn as_slice(&self) -> &[u8] { - &self.data[..self.len as usize] - } -} - -pub struct Dechunker { - chunks: Vec>, - message_id: Option, - total_chunks: Option, - is_complete: bool, -} - -impl Default for Dechunker { - fn default() -> Self { - Self::new() - } -} - -impl Dechunker { - pub fn new() -> Self { - Self { - chunks: Vec::new(), - message_id: None, - total_chunks: None, - is_complete: false, + _padding: 0, } } - pub fn is_complete(&self) -> bool { - self.is_complete - } - - pub fn clear(&mut self) { - self.chunks.clear(); - self.message_id = None; - self.total_chunks = None; - self.is_complete = false; + #[inline] + fn as_bytes(&self) -> &[u8] { + 
bytemuck::bytes_of(self) } - pub fn progress(&self) -> f32 { - match self.total_chunks { - Some(total) if total > 0 => { - let received = self.chunks.iter().filter(|c| c.is_some()).count(); - received as f32 / total as f32 - } - _ => 0.0, - } - } - - pub fn receive(&mut self, data: &[u8]) -> Result>, DecodeError> { - let Some((header_data, chunk_data)) = data.split_at_checked(HEADER_SIZE) else { - return Err(DecodeError::PacketTooSmall { size: data.len() }); - }; - - let header = Header::from_bytes(header_data).ok_or(DecodeError::InvalidHeader)?; - - match self.message_id { - None => { - self.message_id = Some(header.message_id); - self.total_chunks = Some(header.total_chunks); - self.chunks.resize(header.total_chunks as usize, None); - } - Some(id) if id != header.message_id => { - return Err(DecodeError::WrongMessageId { - expected: id, - received: header.message_id, - }); - } - _ => {} - } - - let data_len = header.data_len as usize; - - // store chunk if not already received - // should this be an error? - if self.chunks[header.index as usize].is_none() { - let mut data = [0u8; CHUNK_DATA_SIZE]; - data[..data_len].copy_from_slice(&chunk_data[..data_len]); - self.chunks[header.index as usize] = Some(Chunk { - data, - len: data_len as u8, - }); - } - - if header.is_last { - self.is_complete = true; - } - - // attempt to complete the message - if self.is_complete { - let data = self.data(); - return Ok(data); - } - - Ok(None) - } - - pub fn message_id(&self) -> Option { - self.message_id - } - - pub fn data(&self) -> Option> { - if !self.is_complete { - return None; - } - - let mut result = Vec::new(); - let total = self.total_chunks? as usize; - - for i in 0..total { - match self.chunks.get(i).and_then(|chunk| chunk.as_ref()) { - Some(chunk) => result.extend_from_slice(chunk.as_slice()), - None => return None, - } - } - - Some(result) + #[inline] + fn from_bytes(bytes: &[u8]) -> Option<&Self> { + bytemuck::try_from_bytes::
<Header>(bytes).ok()
     }
 }
diff --git a/btp/src/tests.rs b/btp/src/tests.rs
index 27cb77a..7b4f141 100644
--- a/btp/src/tests.rs
+++ b/btp/src/tests.rs
@@ -1,15 +1,18 @@
-use crate::{chunk, Dechunker, APP_MTU};
+use crate::{
+    chunk, Chunk, Dechunker, DecodeError, MessageIdError, APP_MTU, CHUNK_DATA_SIZE, HEADER_SIZE,
+};
+use rand::{seq::SliceRandom, Rng, RngCore};
+
+static TEST_STR: &[u8] = b"
+This is some example data to be chunked.This is some example data to be chunked.This is some example data to be chunked.
+This is some example data to be chunked.This is some example data to be chunked.This is some example data to be chunked.
+This is some example data to be chunked.This is some example data to be chunked.This is some example data to be chunked.
+This is some example data to be chunked.This is some example data to be chunked.This is some example data to be chunked.
+";

 #[test]
 fn end_to_end() {
-    let data = b"
-    This is some example data to be chunked.This is some example data to be chunked.This is some example data to be chunked.
-    This is some example data to be chunked.This is some example data to be chunked.This is some example data to be chunked.
-    This is some example data to be chunked.This is some example data to be chunked.This is some example data to be chunked.
-    This is some example data to be chunked.This is some example data to be chunked.This is some example data to be chunked.
- ".to_vec(); - - let chunked_data: Vec<[u8; APP_MTU]> = chunk(&data).collect(); + let chunked_data: Vec<[u8; APP_MTU]> = chunk(TEST_STR).collect(); assert_eq!(chunked_data.len(), 3); @@ -21,61 +24,61 @@ fn end_to_end() { .expect("Failed to receive chunk"); } - assert_eq!(unchunker.data(), Some(data)); + assert_eq!(unchunker.data(), Some(TEST_STR.to_vec())); assert!(unchunker.is_complete()); } #[test] fn end_to_end_ooo() { - let data = vec![0u8; 100000]; + for _ in 0..10 { + let mut rng = rand::rng(); + let size = rng.random_range(50000..200000); + let mut data = vec![0u8; size]; + rng.fill_bytes(&mut data); - let mut chunked_data: Vec<[u8; APP_MTU]> = chunk(&data).collect(); + let mut chunks: Vec<_> = chunk(&data).collect(); + chunks.shuffle(&mut rng); - chunked_data.swap(0, 2); - chunked_data.swap(4, 3); + let mut dechunker = Dechunker::new(); - let mut dechunker = Dechunker::new(); + for (i, chunk) in chunks.iter().enumerate() { + dechunker.receive(chunk.as_ref()).unwrap(); - for chunk in chunked_data.iter() { - match dechunker.receive(chunk.as_ref()) { - Ok(result) => { - if let Some(reassembled) = result { - assert_eq!(reassembled.len(), data.len(),); - assert!(data.eq(&reassembled),); - } - } - Err(e) => panic!("Error receiving chunk: {e}"), + let expected_progress = (i + 1) as f32 / chunks.len() as f32; + assert!( + (dechunker.progress() - expected_progress).abs() < 0.01, + "Progress should match chunks received" + ); } - } - assert_eq!(dechunker.data(), Some(data)); + assert_eq!(dechunker.data(), Some(data)); + } } #[test] -fn test_single_chunk() { +fn single_chunk() { let data = b"Small data".to_vec(); let chunks: Vec<_> = chunk(&data).collect(); assert_eq!(chunks.len(), 1); let mut dechunker = Dechunker::new(); - let result = dechunker.receive(&chunks[0]).unwrap(); + dechunker.receive(&chunks[0]).unwrap(); - assert_eq!(result, Some(data)); + assert_eq!(dechunker.data(), Some(data)); assert!(dechunker.is_complete()); } #[test] -fn test_empty_data() { +fn 
empty_data() { let data = b""; let chunks: Vec<_> = chunk(data).collect(); assert_eq!(chunks.len(), 0, "Empty data should produce no chunks"); } #[test] -fn test_exact_chunk_boundary() { - let data_per_chunk = APP_MTU - crate::HEADER_SIZE; - let data = vec![42u8; data_per_chunk * 3]; +fn exact_chunk_boundary() { + let data = vec![42u8; CHUNK_DATA_SIZE * 3]; let chunks: Vec<_> = chunk(&data).collect(); assert_eq!( @@ -86,13 +89,8 @@ fn test_exact_chunk_boundary() { let mut dechunker = Dechunker::new(); for (i, chunk) in chunks.iter().enumerate() { - let result = dechunker.receive(chunk).unwrap(); - if i < chunks.len() - 1 { - assert!( - result.is_none(), - "non-last chunks should not complete the message" - ); - } else { + dechunker.receive(chunk).unwrap(); + if i == chunks.len() - 1 { assert_eq!( dechunker.data().as_ref(), Some(&data), @@ -104,7 +102,7 @@ fn test_exact_chunk_boundary() { } #[test] -fn test_different_message_ids() { +fn different_message_ids() { let data1 = b"Message 1".to_vec(); let data2 = b"Message 2".to_vec(); @@ -112,26 +110,20 @@ fn test_different_message_ids() { let chunks2: Vec<_> = chunk(&data2).collect(); let mut dechunker1 = Dechunker::new(); - let mut dechunker2 = Dechunker::new(); - dechunker1.receive(&chunks1[0]).unwrap(); let result = dechunker1.receive(&chunks2[0]); assert!( - matches!(result, Err(crate::DecodeError::WrongMessageId { .. })), + matches!( + result, + Err(crate::ReceiveError::MessageId(crate::MessageIdError { .. 
})) + ), "Chunk from different message should be rejected" ); - - let result = dechunker2.receive(&chunks2[0]).unwrap(); - assert!( - result.is_some(), - "Single chunk message should complete immediately" - ); - assert_eq!(result, Some(data2)); } #[test] -fn test_progress_tracking() { +fn progress_tracking() { let data = vec![1u8; 10000]; let chunks: Vec<_> = chunk(&data).collect(); @@ -164,29 +156,25 @@ fn test_progress_tracking() { } #[test] -fn test_duplicate_chunks() { +fn dechunker_decode_duplicate() { let data = b"Test duplicate handling".to_vec(); let chunks: Vec<_> = chunk(&data).collect(); let mut dechunker = Dechunker::new(); + dechunker.receive(&chunks[0]).unwrap(); dechunker.receive(&chunks[0]).unwrap(); dechunker.receive(&chunks[0]).unwrap(); - let result = dechunker.receive(&chunks[0]).unwrap(); - assert!( - result.is_some(), - "Duplicate chunks should not prevent completion" - ); assert_eq!( - result.unwrap(), - data, + dechunker.data(), + Some(data), "Data should be correctly reassembled despite duplicates" ); } #[test] -fn test_missing_middle_chunk() { +fn missing_middle_chunk() { let data = vec![1u8; 1000]; let chunks: Vec<_> = chunk(&data).collect(); @@ -209,14 +197,13 @@ fn test_missing_middle_chunk() { "Message should not complete with middle chunk still missing" ); - let result = dechunker.receive(&chunks[middle]).unwrap(); + dechunker.receive(&chunks[middle]).unwrap(); - assert_eq!(result, Some(data)); - assert!(dechunker.is_complete()); + assert_eq!(dechunker.data(), Some(data)); } #[test] -fn test_data_with_zeros() { +fn data_with_zeros() { let mut data = vec![0u8; 500]; data[100] = 1; data[200] = 2; @@ -227,50 +214,182 @@ fn test_data_with_zeros() { let mut dechunker = Dechunker::new(); for chunk in chunks { - if let Some(result) = dechunker.receive(&chunk).unwrap() { - assert_eq!( - result.len(), - data.len(), - "Data with zeros should maintain correct length" - ); - assert_eq!(result, data, "Data with zeros should be preserved exactly"); 
- } + dechunker.receive(&chunk).unwrap(); } - assert!( - dechunker.is_complete(), + assert_eq!( + dechunker.data(), + Some(data), "Dechunker should complete successfully with zero-containing data" ); } #[test] -fn test_reverse_order_decoding() { - let data = b" - This is some example data to be chunked.This is some example data to be chunked.This is some example data to be chunked. - This is some example data to be chunked.This is some example data to be chunked.This is some example data to be chunked. - This is some example data to be chunked.This is some example data to be chunked.This is some example data to be chunked. - This is some example data to be chunked.This is some example data to be chunked.This is some example data to be chunked. - ".to_vec(); - let chunks: Vec<_> = chunk(&data).collect(); +fn reverse_order_decoding() { + let chunks: Vec<_> = chunk(TEST_STR).collect(); let mut dechunker = Dechunker::new(); for chunk in chunks.iter().rev() { - let result = dechunker.receive(chunk).unwrap(); + dechunker.receive(chunk).unwrap(); + } - if result.is_some() { - assert_eq!( - result.unwrap(), - data, - "Data should be correctly reassembled" - ); - } + assert_eq!(dechunker.data(), Some(TEST_STR.to_vec())); +} + +#[test] +fn chunk_parse_and_insert() { + use crate::Chunk; + + let data = b"Test data for parse and push"; + let chunks: Vec<_> = chunk(data).collect(); + + let mut dechunker = Dechunker::new(); + + for raw_chunk in &chunks { + let parsed = Chunk::parse(raw_chunk).unwrap(); + dechunker.insert_chunk(parsed).unwrap(); + } + + assert_eq!(dechunker.data(), Some(data.to_vec())); +} + +#[test] +fn chunk_parse_errors() { + let small_data = vec![0u8; HEADER_SIZE - 1]; + let result = Chunk::parse(&small_data); + assert!(matches!(result, Err(DecodeError::HeaderTooSmall))); +} + +#[test] +fn chunk_too_small() { + let header = crate::Header::new(1234, 0, 1, 100); + let mut raw_chunk = vec![0u8; HEADER_SIZE + 5]; + 
raw_chunk[..HEADER_SIZE].copy_from_slice(header.as_bytes()); + + let result = Chunk::parse(&raw_chunk); + assert!( + matches!( + result, + Err(DecodeError::ChunkTooSmall { + expected: 100, + actual: 5 + }) + ), + "should fail if actual data is less than header claims" + ); +} + +#[test] +fn chunk_too_large() { + let header = crate::Header::new(1234, 0, 1, 255); + let mut raw_chunk = vec![0u8; HEADER_SIZE + 255]; + raw_chunk[..HEADER_SIZE].copy_from_slice(header.as_bytes()); + + let result = Chunk::parse(&raw_chunk); + assert!( + matches!(result, Err(DecodeError::ChunkTooLarge)), + "should fail when data_len exceeds maximum chunk size" + ); +} + +#[test] +fn chunk_parse_invalid_index() { + use crate::{Chunk, DecodeError, HEADER_SIZE}; + + let header = crate::Header::new(1234, 10, 1, 5); + let mut raw_chunk = vec![0u8; HEADER_SIZE + 5]; + raw_chunk[..HEADER_SIZE].copy_from_slice(header.as_bytes()); + + let result = Chunk::parse(&raw_chunk); + assert!( + matches!( + result, + Err(DecodeError::InvalidChunkIndex { + index: 10, + total_chunks: 1 + }) + ), + "invalid index" + ); +} + +#[test] +fn insert_chunk_wrong_message_id() { + let data1 = b"First message"; + let data2 = b"Second message"; + + let chunks1: Vec<_> = chunk(data1).collect(); + let chunks2: Vec<_> = chunk(data2).collect(); + + let mut dechunker = Dechunker::new(); + + let chunk1 = Chunk::parse(&chunks1[0]).unwrap(); + dechunker.insert_chunk(chunk1).unwrap(); + + let chunk2 = Chunk::parse(&chunks2[0]).unwrap(); + let result = dechunker.insert_chunk(chunk2); + + assert!( + matches!(result, Err(MessageIdError { .. 
})), + "message id mismatch" + ); +} + +#[test] +fn insert_chunk_out_of_order() { + let mut rng = rand::rng(); + let size = rng.random_range(5000..20000); + let mut data = vec![0u8; size]; + rng.fill_bytes(&mut data); + + let mut chunks: Vec<_> = chunk(&data).collect(); + let original_count = chunks.len(); + + chunks.shuffle(&mut rng); + + let mut dechunker = Dechunker::new(); + + for (i, chunk) in chunks.iter().enumerate() { + let parsed = Chunk::parse(chunk).unwrap(); + dechunker.insert_chunk(parsed).unwrap(); + + let expected_progress = (i + 1) as f32 / original_count as f32; + assert!( + (dechunker.progress() - expected_progress).abs() < 0.01, + "Progress should match chunks inserted" + ); } - assert!(dechunker.is_complete(), "Dechunker should be complete"); assert_eq!( dechunker.data(), Some(data), - "Final data should match original" + "Data should match after out of order reassembly" + ); +} + +#[test] +fn insert_duplicate_chunks() { + let data = b"Test duplicate handling"; + let chunks: Vec<_> = chunk(data).collect(); + + let mut dechunker = Dechunker::new(); + + let chunk0 = Chunk::parse(&chunks[0]).unwrap(); + + dechunker.insert_chunk(chunk0).unwrap(); + dechunker.insert_chunk(chunk0).unwrap(); + + assert_eq!( + dechunker.progress(), + 1.0 / chunks.len() as f32, + "Progress should only count unique chunks" ); + + for chunk in &chunks[1..] { + let parsed = Chunk::parse(chunk).unwrap(); + dechunker.insert_chunk(parsed).unwrap(); + } + + assert_eq!(dechunker.data(), Some(data.to_vec())); }