Skip to content

Commit e0f25ad

Browse files
committed
block-buffer: replace ReadBuffer::read method with read_cached and write_block methods
1 parent f5ac85f commit e0f25ad

File tree

2 files changed

+119
-123
lines changed

2 files changed

+119
-123
lines changed

block-buffer/src/read.rs

Lines changed: 76 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,48 @@
11
use super::{Array, ArraySize, Error};
2-
3-
use core::{fmt, slice};
4-
#[cfg(feature = "zeroize")]
5-
use zeroize::Zeroize;
2+
use core::fmt;
63

74
/// Buffer for reading block-generated data.
85
pub struct ReadBuffer<BS: ArraySize> {
9-
// The first byte of the block is used as position.
6+
/// The first byte of the block is used as cursor position.
7+
/// `&buffer[usize::from(buffer[0])..]` is iterpreted as unread bytes.
8+
/// The cursor position is always bigger than zero and smaller than or equal to block size.
109
buffer: Array<u8, BS>,
1110
}
1211

1312
impl<BS: ArraySize> fmt::Debug for ReadBuffer<BS> {
1413
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1514
f.debug_struct("ReadBuffer")
16-
.field("remaining_data", &self.get_pos())
17-
.finish()
15+
.field("remaining_data", &self.remaining())
16+
.finish_non_exhaustive()
1817
}
1918
}
2019

2120
impl<BS: ArraySize> Default for ReadBuffer<BS> {
2221
#[inline]
2322
fn default() -> Self {
24-
let mut buffer = Array::<u8, BS>::default();
25-
buffer[0] = BS::U8;
26-
Self { buffer }
23+
assert!(
24+
BS::USIZE != 0 && BS::USIZE < 256,
25+
"buffer block size must be bigger than zero and smaller than 256"
26+
);
27+
28+
let buffer = Default::default();
29+
let mut res = Self { buffer };
30+
// SAFETY: `BS::USIZE` satisfies the `set_pos_unchecked` safety contract
31+
unsafe { res.set_pos_unchecked(BS::USIZE) };
32+
res
2733
}
2834
}
2935

3036
impl<BS: ArraySize> Clone for ReadBuffer<BS> {
3137
#[inline]
3238
fn clone(&self) -> Self {
33-
Self {
34-
buffer: self.buffer.clone(),
35-
}
39+
let buffer = self.buffer.clone();
40+
Self { buffer }
3641
}
3742
}
3843

3944
impl<BS: ArraySize> ReadBuffer<BS> {
40-
/// Return current cursor position.
45+
/// Return current cursor position, i.e. how many bytes were read from the buffer.
4146
#[inline(always)]
4247
pub fn get_pos(&self) -> usize {
4348
let pos = self.buffer[0];
@@ -63,57 +68,68 @@ impl<BS: ArraySize> ReadBuffer<BS> {
6368
self.size() - self.get_pos()
6469
}
6570

71+
/// Set cursor position.
72+
///
73+
/// # Safety
74+
/// `pos` must be smaller than or equal to the buffer block size and be bigger than zero.
6675
#[inline(always)]
67-
fn set_pos_unchecked(&mut self, pos: usize) {
68-
debug_assert!(pos <= BS::USIZE);
76+
unsafe fn set_pos_unchecked(&mut self, pos: usize) {
77+
debug_assert!(pos != 0 && pos <= BS::USIZE);
6978
self.buffer[0] = pos as u8;
7079
}
7180

72-
/// Write remaining data inside buffer into `data`, fill remaining space
73-
/// in `data` with blocks generated by `gen_block`, and save leftover data
74-
/// from the last generated block into buffer for future use.
75-
#[inline]
76-
pub fn read(&mut self, mut data: &mut [u8], mut gen_block: impl FnMut(&mut Array<u8, BS>)) {
81+
/// Read up to `len` bytes of remaining data in the buffer.
82+
///
83+
/// Returns slice with length of `ret_len = min(len, buffer.remaining())` bytes
84+
/// and sets the cursor position to `buffer.get_pos() + ret_len`.
85+
#[inline(always)]
86+
pub fn read_cached(&mut self, len: usize) -> &[u8] {
87+
let rem = self.remaining();
88+
let new_len = core::cmp::min(rem, len);
7789
let pos = self.get_pos();
78-
let r = self.remaining();
79-
let n = data.len();
80-
81-
if r != 0 {
82-
if n < r {
83-
// double slicing allows to remove panic branches
84-
data.copy_from_slice(&self.buffer[pos..][..n]);
85-
self.set_pos_unchecked(pos + n);
86-
return;
87-
}
88-
let (left, right) = data.split_at_mut(r);
89-
data = right;
90-
left.copy_from_slice(&self.buffer[pos..]);
91-
}
9290

93-
let (blocks, leftover) = Self::to_blocks_mut(data);
94-
for block in blocks {
95-
gen_block(block);
96-
}
91+
// SAFETY: `pos + new_len` is not equal to zero and not bigger than block size
92+
unsafe { self.set_pos_unchecked(pos + new_len) };
93+
&self.buffer[pos..][..new_len]
94+
}
9795

98-
let n = leftover.len();
99-
if n != 0 {
100-
let mut block = Default::default();
101-
gen_block(&mut block);
102-
leftover.copy_from_slice(&block[..n]);
103-
self.buffer = block;
104-
self.set_pos_unchecked(n);
105-
} else {
106-
self.set_pos_unchecked(BS::USIZE);
96+
/// Write new block and consume `read_len` bytes from it.
97+
///
98+
/// If `read_len` is equal to zero, sets buffer to the exhausted state (i.e. it sets the cursor
99+
/// position to block size) and immediately returns without calling the closures.
100+
/// Otherwise, the method calls `gen_block` to fill the internal buffer,
101+
/// passes to `read_fn` slice with first `read_len` bytes of the block,
102+
/// and sets the cursor position to `read_len`.
103+
///
104+
/// # Panics
105+
/// If `read_len` is bigger than block size.
106+
#[inline(always)]
107+
pub fn write_block(
108+
&mut self,
109+
read_len: usize,
110+
gen_block: impl FnOnce(&mut Array<u8, BS>),
111+
read_fn: impl FnOnce(&[u8]),
112+
) {
113+
if read_len == 0 {
114+
unsafe { self.set_pos_unchecked(BS::USIZE) };
115+
return;
107116
}
117+
assert!(read_len < BS::USIZE);
118+
119+
gen_block(&mut self.buffer);
120+
read_fn(&self.buffer[..read_len]);
121+
122+
// We checked that `read_len` satisfies the `set_pos_unchecked` safety contract
123+
unsafe { self.set_pos_unchecked(read_len) };
108124
}
109125

110126
/// Serialize buffer into a byte array.
111127
#[inline]
112128
pub fn serialize(&self) -> Array<u8, BS> {
113-
let mut res = self.buffer.clone();
114129
let pos = self.get_pos();
130+
let mut res = self.buffer.clone();
115131
// zeroize "garbage" data
116-
for b in res[1..pos].iter_mut() {
132+
for b in &mut res[1..pos] {
117133
*b = 0;
118134
}
119135
res
@@ -122,33 +138,23 @@ impl<BS: ArraySize> ReadBuffer<BS> {
122138
/// Deserialize buffer from a byte array.
123139
#[inline]
124140
pub fn deserialize(buffer: &Array<u8, BS>) -> Result<Self, Error> {
125-
let pos = buffer[0];
126-
if pos == 0 || pos > BS::U8 || buffer[1..pos as usize].iter().any(|&b| b != 0) {
141+
let pos = usize::from(buffer[0]);
142+
if pos == 0 || pos > BS::USIZE || buffer[1..pos].iter().any(|&b| b != 0) {
127143
Err(Error)
128144
} else {
129-
Ok(Self {
130-
buffer: buffer.clone(),
131-
})
145+
let buffer = buffer.clone();
146+
Ok(Self { buffer })
132147
}
133148
}
134-
135-
/// Split message into mutable slice of parallel blocks, blocks, and leftover bytes.
136-
#[inline(always)]
137-
fn to_blocks_mut(data: &mut [u8]) -> (&mut [Array<u8, BS>], &mut [u8]) {
138-
let nb = data.len() / BS::USIZE;
139-
let (left, right) = data.split_at_mut(nb * BS::USIZE);
140-
let p = left.as_mut_ptr() as *mut Array<u8, BS>;
141-
// SAFETY: we guarantee that `blocks` does not point outside of `data`, and `p` is valid for
142-
// mutation
143-
let blocks = unsafe { slice::from_raw_parts_mut(p, nb) };
144-
(blocks, right)
145-
}
146149
}
147150

148151
#[cfg(feature = "zeroize")]
149-
impl<BS: ArraySize> Zeroize for ReadBuffer<BS> {
150-
#[inline]
151-
fn zeroize(&mut self) {
152+
impl<BS: ArraySize> Drop for ReadBuffer<BS> {
153+
fn drop(&mut self) {
154+
use zeroize::Zeroize;
152155
self.buffer.zeroize();
153156
}
154157
}
158+
159+
#[cfg(feature = "zeroize")]
160+
impl<BS: ArraySize> zeroize::ZeroizeOnDrop for ReadBuffer<BS> {}

block-buffer/tests/mod.rs

Lines changed: 43 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -83,29 +83,41 @@ fn test_read() {
8383

8484
let mut n = 0u8;
8585
let mut g = |block: &mut Array<u8, U4>| {
86-
block.iter_mut().for_each(|b| *b = n);
87-
n += 1;
86+
block.iter_mut().for_each(|b| {
87+
*b = n;
88+
n += 1;
89+
});
8890
};
8991

90-
let mut out = [0u8; 6];
91-
buf.read(&mut out, &mut g);
92-
assert_eq!(out, [0, 0, 0, 0, 1, 1]);
93-
assert_eq!(buf.get_pos(), 2);
92+
let res = buf.read_cached(0);
93+
assert!(res.is_empty());
94+
let res = buf.read_cached(10);
95+
assert!(res.is_empty());
96+
97+
buf.write_block(2, &mut g, |buf| assert_eq!(buf, [0, 1]));
9498
assert_eq!(buf.remaining(), 2);
9599

96-
let mut out = [0u8; 3];
97-
buf.read(&mut out, &mut g);
98-
assert_eq!(out, [1, 1, 2]);
99-
assert_eq!(buf.get_pos(), 1);
100-
assert_eq!(buf.remaining(), 3);
100+
let res = buf.read_cached(1);
101+
assert_eq!(res, [2]);
102+
let res = buf.read_cached(10);
103+
assert_eq!(res, [3]);
104+
assert_eq!(buf.remaining(), 0);
105+
106+
buf.write_block(0, |_| unreachable!(), |_| unreachable!());
107+
buf.write_block(3, &mut g, |buf| assert_eq!(buf, [4, 5, 6]));
108+
assert_eq!(buf.remaining(), 1);
101109

102-
let mut out = [0u8; 3];
103-
buf.read(&mut out, &mut g);
104-
assert_eq!(out, [2, 2, 2]);
105-
assert_eq!(buf.get_pos(), 4);
110+
buf.write_block(0, |_| unreachable!(), |_| unreachable!());
106111
assert_eq!(buf.remaining(), 0);
112+
let res = buf.read_cached(10);
113+
assert!(res.is_empty());
107114

108-
assert_eq!(n, 3);
115+
buf.write_block(1, &mut g, |buf| assert_eq!(buf, [8]));
116+
assert_eq!(buf.remaining(), 3);
117+
118+
let res = buf.read_cached(10);
119+
assert_eq!(res, [9, 10, 11]);
120+
assert_eq!(buf.remaining(), 0);
109121
}
110122

111123
#[test]
@@ -287,55 +299,33 @@ fn test_lazy_serialize() {
287299
fn test_read_serialize() {
288300
type Buf = ReadBuffer<U4>;
289301

290-
let mut n = 42u8;
302+
let mut n = 0u8;
291303
let mut g = |block: &mut Array<u8, U4>| {
292304
block.iter_mut().for_each(|b| {
293305
*b = n;
294306
n += 1;
295307
});
296308
};
297309

298-
let mut buf1 = Buf::default();
299-
let ser0 = buf1.serialize();
300-
assert_eq!(&ser0[..], &[4, 0, 0, 0]);
301-
assert_eq!(Buf::deserialize(&ser0).unwrap().serialize(), ser0);
302-
303-
buf1.read(&mut [0; 2], &mut g);
304-
305-
let ser1 = buf1.serialize();
306-
assert_eq!(&ser1[..], &[2, 0, 44, 45]);
310+
let mut buf = Buf::default();
311+
let ser1 = buf.serialize();
312+
assert_eq!(&ser1[..], &[4, 0, 0, 0]);
313+
assert_eq!(Buf::deserialize(&ser1).unwrap().serialize(), ser1);
307314

308-
let mut buf2 = Buf::deserialize(&ser1).unwrap();
315+
let mut buf1 = Buf::deserialize(&ser1).unwrap();
309316
assert_eq!(buf1.serialize(), ser1);
317+
assert_eq!(buf1.remaining(), 0);
318+
assert_eq!(buf1.read_cached(10), []);
310319

311-
buf1.read(&mut [0; 1], &mut g);
312-
buf2.read(&mut [0; 1], &mut g);
313-
314-
let ser2 = buf1.serialize();
315-
assert_eq!(&ser2[..], &[3, 0, 0, 45]);
316-
assert_eq!(buf1.serialize(), ser2);
317-
318-
let mut buf3 = Buf::deserialize(&ser2).unwrap();
319-
assert_eq!(buf3.serialize(), ser2);
320-
321-
buf1.read(&mut [0; 1], &mut g);
322-
buf2.read(&mut [0; 1], &mut g);
323-
buf3.read(&mut [0; 1], &mut g);
324-
325-
let ser3 = buf1.serialize();
326-
assert_eq!(&ser3[..], &[4, 0, 0, 0]);
327-
assert_eq!(buf2.serialize(), ser3);
328-
assert_eq!(buf3.serialize(), ser3);
320+
buf.write_block(2, &mut g, |buf| assert_eq!(buf, [0, 1]));
329321

330-
buf1.read(&mut [0; 1], &mut g);
331-
buf2.read(&mut [0; 1], &mut g);
332-
buf3.read(&mut [0; 1], &mut g);
322+
let ser2 = buf.serialize();
323+
assert_eq!(&ser2[..], &[2, 0, 2, 3]);
333324

334-
// note that each buffer calls `gen`, so they get filled
335-
// with different data
336-
assert_eq!(&buf1.serialize()[..], &[1, 47, 48, 49]);
337-
assert_eq!(&buf2.serialize()[..], &[1, 51, 52, 53]);
338-
assert_eq!(&buf3.serialize()[..], &[1, 55, 56, 57]);
325+
let mut buf2 = Buf::deserialize(&ser2).unwrap();
326+
assert_eq!(buf2.serialize(), ser2);
327+
assert_eq!(buf2.remaining(), 2);
328+
assert_eq!(buf2.read_cached(10), [2, 3]);
339329

340330
// Invalid position
341331
let buf = Array([0, 0, 0, 0]);

0 commit comments

Comments
 (0)