Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 172 additions & 2 deletions btp/src/dechunk.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{Header, CHUNK_DATA_SIZE, HEADER_SIZE};
use std::io::{self, Write};

#[derive(Debug, thiserror::Error)]
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum DecodeError {
#[error("chunk too small, expected at least {}", HEADER_SIZE)]
HeaderTooSmall,
Expand Down Expand Up @@ -97,7 +98,7 @@ struct MessageInfo {
}

#[derive(Debug, Clone)]
struct RawChunk {
pub(crate) struct RawChunk {
data: [u8; CHUNK_DATA_SIZE],
len: u8,
}
Expand Down Expand Up @@ -207,3 +208,172 @@ impl Dechunker {
Some(result)
}
}

#[derive(Debug)]
struct DechunkerSlot {
dechunker: Dechunker,
last_used: u64,
}

pub struct MasterDechunker<const N: usize = 10> {
dechunkers: [Option<DechunkerSlot>; N],
counter: u64,
}

impl<const N: usize> Default for MasterDechunker<N> {
fn default() -> Self {
Self {
dechunkers: std::array::from_fn(|_| None),
counter: 0,
}
}
}

impl<const N: usize> MasterDechunker<N> {
pub fn insert_chunk(&mut self, chunk: Chunk) -> Option<Vec<u8>> {
let message_id = chunk.header.message_id;

for decoder_slot in &mut self.dechunkers {
if let Some(ref mut slot) = decoder_slot {
if slot.dechunker.message_id() == Some(message_id) {
self.counter += 1;
slot.last_used = self.counter;
slot.dechunker.insert_chunk(chunk).unwrap();

return if slot.dechunker.is_complete() {
decoder_slot.take().unwrap().dechunker.data()
} else {
None
};
}
}
}

let target_slot =
if let Some(empty_slot) = self.dechunkers.iter_mut().find(|slot| slot.is_none()) {
empty_slot
} else {
let lru_index = self
.dechunkers
.iter()
.enumerate()
.filter_map(|(i, slot)| slot.as_ref().map(|s| (i, s.last_used)))
.min_by_key(|(_, last_used)| *last_used)
.map(|(i, _)| i)
.expect("should find slot");

&mut self.dechunkers[lru_index]
};

let mut decoder = Dechunker::new();
decoder.insert_chunk(chunk).unwrap();

if decoder.is_complete() {
decoder.data()
} else {
self.counter += 1;
*target_slot = Some(DechunkerSlot {
dechunker: decoder,
last_used: self.counter,
});
None
}
}
}

#[derive(Debug, thiserror::Error)]
pub enum StreamError {
#[error(transparent)]
Io(#[from] io::Error),
#[error(transparent)]
MessageId(#[from] MessageIdError),
}

pub struct StreamDechunker<W: Write> {
writer: W,
pub(crate) chunks: Vec<Option<RawChunk>>,
info: Option<MessageInfo>,
next_chunk_to_write: u16,
bytes_written: u64,
}

impl<W: Write> StreamDechunker<W> {
pub fn new(writer: W) -> Self {
Self {
writer,
chunks: Vec::new(),
info: None,
next_chunk_to_write: 0,
bytes_written: 0,
}
}

pub fn insert_chunk(&mut self, chunk: Chunk) -> Result<bool, StreamError> {
let header = &chunk.header;

match self.info {
None => {
self.info = Some(MessageInfo {
message_id: header.message_id,
total_chunks: header.total_chunks,
chunks_received: 0,
});
self.chunks.resize(header.total_chunks as usize, None);
}
Some(info) if info.message_id != header.message_id => {
return Err(StreamError::MessageId(MessageIdError {
expected: info.message_id,
actual: header.message_id,
}));
}
_ => {}
}

if self.chunks[header.index as usize].is_none() {
self.chunks[header.index as usize] = Some(RawChunk {
len: header.data_len,
data: chunk.chunk,
});

if let Some(ref mut info) = self.info {
info.chunks_received += 1;
}
}

while (self.next_chunk_to_write as usize) < self.chunks.len() {
if let Some(chunk) = self.chunks[self.next_chunk_to_write as usize].take() {
self.writer.write_all(chunk.as_slice())?;
self.next_chunk_to_write += 1;
self.bytes_written += chunk.len as u64;
} else {
break;
}
}

Ok(self.is_complete())
}

pub fn is_complete(&self) -> bool {
self.info
.map(|info| info.chunks_received == info.total_chunks)
.unwrap_or(false)
}

pub fn message_id(&self) -> Option<u16> {
self.info.map(|info| info.message_id)
}

pub fn bytes_written(&self) -> u64 {
self.bytes_written
}

pub fn progress(&self) -> f32 {
self.info
.map(|info| info.chunks_received as f32 / info.total_chunks as f32)
.unwrap_or(0.0)
}

pub fn into_writer(self) -> W {
self.writer
}
}
Loading