|
1 | 1 | use crate::{Header, CHUNK_DATA_SIZE, HEADER_SIZE}; |
| 2 | +use std::io::{self, Write}; |
2 | 3 |
|
3 | | -#[derive(Debug, thiserror::Error)] |
| 4 | +#[derive(Debug, PartialEq, thiserror::Error)] |
4 | 5 | pub enum DecodeError { |
5 | 6 | #[error("chunk too small, expected at least {}", HEADER_SIZE)] |
6 | 7 | HeaderTooSmall, |
@@ -97,7 +98,7 @@ struct MessageInfo { |
97 | 98 | } |
98 | 99 |
|
99 | 100 | #[derive(Debug, Clone)] |
100 | | -struct RawChunk { |
| 101 | +pub(crate) struct RawChunk { |
101 | 102 | data: [u8; CHUNK_DATA_SIZE], |
102 | 103 | len: u8, |
103 | 104 | } |
@@ -207,3 +208,172 @@ impl Dechunker { |
207 | 208 | Some(result) |
208 | 209 | } |
209 | 210 | } |
| 211 | + |
| 212 | +#[derive(Debug)] |
| 213 | +struct DechunkerSlot { |
| 214 | + dechunker: Dechunker, |
| 215 | + last_used: u64, |
| 216 | +} |
| 217 | + |
| 218 | +pub struct MasterDechunker<const N: usize = 10> { |
| 219 | + dechunkers: [Option<DechunkerSlot>; N], |
| 220 | + counter: u64, |
| 221 | +} |
| 222 | + |
| 223 | +impl<const N: usize> Default for MasterDechunker<N> { |
| 224 | + fn default() -> Self { |
| 225 | + Self { |
| 226 | + dechunkers: std::array::from_fn(|_| None), |
| 227 | + counter: 0, |
| 228 | + } |
| 229 | + } |
| 230 | +} |
| 231 | + |
| 232 | +impl<const N: usize> MasterDechunker<N> { |
| 233 | + pub fn insert_chunk(&mut self, chunk: Chunk) -> Option<Vec<u8>> { |
| 234 | + let message_id = chunk.header.message_id; |
| 235 | + |
| 236 | + for decoder_slot in &mut self.dechunkers { |
| 237 | + if let Some(ref mut slot) = decoder_slot { |
| 238 | + if slot.dechunker.message_id() == Some(message_id) { |
| 239 | + self.counter += 1; |
| 240 | + slot.last_used = self.counter; |
| 241 | + slot.dechunker.insert_chunk(chunk).unwrap(); |
| 242 | + |
| 243 | + return if slot.dechunker.is_complete() { |
| 244 | + decoder_slot.take().unwrap().dechunker.data() |
| 245 | + } else { |
| 246 | + None |
| 247 | + }; |
| 248 | + } |
| 249 | + } |
| 250 | + } |
| 251 | + |
| 252 | + let target_slot = |
| 253 | + if let Some(empty_slot) = self.dechunkers.iter_mut().find(|slot| slot.is_none()) { |
| 254 | + empty_slot |
| 255 | + } else { |
| 256 | + let lru_index = self |
| 257 | + .dechunkers |
| 258 | + .iter() |
| 259 | + .enumerate() |
| 260 | + .filter_map(|(i, slot)| slot.as_ref().map(|s| (i, s.last_used))) |
| 261 | + .min_by_key(|(_, last_used)| *last_used) |
| 262 | + .map(|(i, _)| i) |
| 263 | + .expect("should find slot"); |
| 264 | + |
| 265 | + &mut self.dechunkers[lru_index] |
| 266 | + }; |
| 267 | + |
| 268 | + let mut decoder = Dechunker::new(); |
| 269 | + decoder.insert_chunk(chunk).unwrap(); |
| 270 | + |
| 271 | + if decoder.is_complete() { |
| 272 | + decoder.data() |
| 273 | + } else { |
| 274 | + self.counter += 1; |
| 275 | + *target_slot = Some(DechunkerSlot { |
| 276 | + dechunker: decoder, |
| 277 | + last_used: self.counter, |
| 278 | + }); |
| 279 | + None |
| 280 | + } |
| 281 | + } |
| 282 | +} |
| 283 | + |
| 284 | +#[derive(Debug, thiserror::Error)] |
| 285 | +pub enum StreamError { |
| 286 | + #[error(transparent)] |
| 287 | + Io(#[from] io::Error), |
| 288 | + #[error(transparent)] |
| 289 | + MessageId(#[from] MessageIdError), |
| 290 | +} |
| 291 | + |
| 292 | +pub struct StreamDechunker<W: Write> { |
| 293 | + writer: W, |
| 294 | + pub(crate) chunks: Vec<Option<RawChunk>>, |
| 295 | + info: Option<MessageInfo>, |
| 296 | + next_chunk_to_write: u16, |
| 297 | + bytes_written: u64, |
| 298 | +} |
| 299 | + |
| 300 | +impl<W: Write> StreamDechunker<W> { |
| 301 | + pub fn new(writer: W) -> Self { |
| 302 | + Self { |
| 303 | + writer, |
| 304 | + chunks: Vec::new(), |
| 305 | + info: None, |
| 306 | + next_chunk_to_write: 0, |
| 307 | + bytes_written: 0, |
| 308 | + } |
| 309 | + } |
| 310 | + |
| 311 | + pub fn insert_chunk(&mut self, chunk: Chunk) -> Result<bool, StreamError> { |
| 312 | + let header = &chunk.header; |
| 313 | + |
| 314 | + match self.info { |
| 315 | + None => { |
| 316 | + self.info = Some(MessageInfo { |
| 317 | + message_id: header.message_id, |
| 318 | + total_chunks: header.total_chunks, |
| 319 | + chunks_received: 0, |
| 320 | + }); |
| 321 | + self.chunks.resize(header.total_chunks as usize, None); |
| 322 | + } |
| 323 | + Some(info) if info.message_id != header.message_id => { |
| 324 | + return Err(StreamError::MessageId(MessageIdError { |
| 325 | + expected: info.message_id, |
| 326 | + actual: header.message_id, |
| 327 | + })); |
| 328 | + } |
| 329 | + _ => {} |
| 330 | + } |
| 331 | + |
| 332 | + if self.chunks[header.index as usize].is_none() { |
| 333 | + self.chunks[header.index as usize] = Some(RawChunk { |
| 334 | + len: header.data_len, |
| 335 | + data: chunk.chunk, |
| 336 | + }); |
| 337 | + |
| 338 | + if let Some(ref mut info) = self.info { |
| 339 | + info.chunks_received += 1; |
| 340 | + } |
| 341 | + } |
| 342 | + |
| 343 | + while (self.next_chunk_to_write as usize) < self.chunks.len() { |
| 344 | + if let Some(chunk) = self.chunks[self.next_chunk_to_write as usize].take() { |
| 345 | + self.writer.write_all(chunk.as_slice())?; |
| 346 | + self.next_chunk_to_write += 1; |
| 347 | + self.bytes_written += chunk.len as u64; |
| 348 | + } else { |
| 349 | + break; |
| 350 | + } |
| 351 | + } |
| 352 | + |
| 353 | + Ok(self.is_complete()) |
| 354 | + } |
| 355 | + |
| 356 | + pub fn is_complete(&self) -> bool { |
| 357 | + self.info |
| 358 | + .map(|info| info.chunks_received == info.total_chunks) |
| 359 | + .unwrap_or(false) |
| 360 | + } |
| 361 | + |
| 362 | + pub fn message_id(&self) -> Option<u16> { |
| 363 | + self.info.map(|info| info.message_id) |
| 364 | + } |
| 365 | + |
| 366 | + pub fn bytes_written(&self) -> u64 { |
| 367 | + self.bytes_written |
| 368 | + } |
| 369 | + |
| 370 | + pub fn progress(&self) -> f32 { |
| 371 | + self.info |
| 372 | + .map(|info| info.chunks_received as f32 / info.total_chunks as f32) |
| 373 | + .unwrap_or(0.0) |
| 374 | + } |
| 375 | + |
| 376 | + pub fn into_writer(self) -> W { |
| 377 | + self.writer |
| 378 | + } |
| 379 | +} |
0 commit comments