Skip to content

Commit 49026f5

Browse files
authored
improve symmetry between encoder and decoder (#22)
* add 'decode_all' method * make decoder constructor infallible
1 parent a11daf9 commit 49026f5

File tree

11 files changed

+105
-110
lines changed

11 files changed

+105
-110
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
name = "arithmetic-coding"
33
description = "fast and flexible arithmetic coding library"
4-
version = "0.1.1"
4+
version = "0.2.0"
55
edition = "2021"
66
license = "MIT"
77
keywords = ["compression", "encoding", "arithmetic-coding", "lossless"]

arithmetic-coding-core/src/bitstore.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ pub trait BitStore:
1414
+ AddAssign
1515
+ PartialOrd
1616
+ Copy
17+
+ std::fmt::Debug
1718
{
1819
/// the number of bits needed to represent this type
1920
const BITS: u32;

benches/common/mod.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ where
2222
I: IntoIterator<Item = M::Symbol>,
2323
{
2424
let mut bitwriter = BitWriter::endian(Vec::new(), BigEndian);
25-
let mut encoder = Encoder::<M>::new(model);
25+
let mut encoder = Encoder::new(model, &mut bitwriter);
2626

27-
encoder.encode_all(input, &mut bitwriter).unwrap();
27+
encoder.encode_all(input).unwrap();
2828
bitwriter.byte_align().unwrap();
2929

3030
bitwriter.into_writer()
@@ -35,11 +35,6 @@ where
3535
M: Model,
3636
{
3737
let bitreader = BitReader::endian(buffer, BigEndian);
38-
let mut decoder = Decoder::new(model, bitreader).unwrap();
39-
let mut output = Vec::new();
40-
41-
while let Some(symbol) = decoder.decode_symbol().unwrap() {
42-
output.push(symbol);
43-
}
44-
output
38+
let mut decoder = Decoder::new(model, bitreader);
39+
decoder.decode_all().map(Result::unwrap).collect()
4540
}

examples/common/mod.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ where
2626
I: IntoIterator<Item = M::Symbol>,
2727
{
2828
let mut bitwriter = BitWriter::endian(Vec::new(), BigEndian);
29-
let mut encoder = Encoder::<M>::new(model);
29+
let mut encoder = Encoder::new(model, &mut bitwriter);
3030

31-
encoder.encode_all(input, &mut bitwriter).unwrap();
31+
encoder.encode_all(input).unwrap();
3232
bitwriter.byte_align().unwrap();
3333

3434
bitwriter.into_writer()
@@ -39,13 +39,8 @@ where
3939
M: Model,
4040
{
4141
let bitreader = BitReader::endian(buffer, BigEndian);
42-
let mut decoder = Decoder::new(model, bitreader).unwrap();
43-
let mut output = Vec::new();
44-
45-
while let Some(symbol) = decoder.decode_symbol().unwrap() {
46-
output.push(symbol);
47-
}
48-
output
42+
let mut decoder = Decoder::new(model, bitreader);
43+
decoder.decode_all().map(Result::unwrap).collect()
4944
}
5045

5146
#[allow(unused)]

examples/concatenated.rs

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -121,29 +121,29 @@ where
121121
{
122122
let mut bitwriter = BitWriter::endian(Vec::default(), BigEndian);
123123

124-
let mut encoder1 = Encoder::with_precision(model1, PRECISION);
125-
encode(&mut encoder1, input1, &mut bitwriter);
124+
let mut encoder1 = Encoder::with_precision(model1, &mut bitwriter, PRECISION);
125+
encode(&mut encoder1, input1);
126126

127127
let mut encoder2 = encoder1.chain(model2);
128-
encode(&mut encoder2, input2, &mut bitwriter);
128+
encode(&mut encoder2, input2);
129129

130-
encoder2.flush(&mut bitwriter).unwrap();
130+
encoder2.flush().unwrap();
131131

132132
bitwriter.byte_align().unwrap();
133133
bitwriter.into_writer()
134134
}
135135

136136
/// Encode all symbols, followed by EOF. Doesn't flush the encoder (allowing
137137
/// more bits to be concatenated)
138-
fn encode<M, W>(encoder: &mut Encoder<M>, input: &[M::Symbol], bitwriter: &mut W)
138+
fn encode<M, W>(encoder: &mut Encoder<M, W>, input: &[M::Symbol])
139139
where
140140
M: Model,
141141
W: BitWrite,
142142
{
143143
for symbol in input {
144-
encoder.encode(Some(symbol), bitwriter).unwrap();
144+
encoder.encode(Some(symbol)).unwrap();
145145
}
146-
encoder.encode(None, bitwriter).unwrap();
146+
encoder.encode(None).unwrap();
147147
}
148148

149149
/// Decode two sets of symbols, in sequence
@@ -154,12 +154,10 @@ where
154154
{
155155
let bitreader = BitReader::endian(buffer, BigEndian);
156156

157-
let mut decoder1 = Decoder::with_precision(model1, bitreader, PRECISION).unwrap();
158-
157+
let mut decoder1 = Decoder::with_precision(model1, bitreader, PRECISION);
159158
let output1 = decode(&mut decoder1);
160159

161160
let mut decoder2 = decoder1.chain(model2);
162-
163161
let output2 = decode(&mut decoder2);
164162

165163
(output1, output2)
@@ -171,11 +169,5 @@ where
171169
M: Model,
172170
R: BitRead,
173171
{
174-
let mut output = Vec::default();
175-
176-
while let Some(symbol) = decoder.decode_symbol().unwrap() {
177-
output.push(symbol);
178-
}
179-
180-
output
172+
decoder.decode_all().map(Result::unwrap).collect()
181173
}

fuzz/fuzz_targets/fuzz_target_1.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use libfuzzer_sys::fuzz_target;
55
mod round_trip;
66

77
fuzz_target!(|data: &[u8]| {
8-
let model = FenwickModel::with_symbols(256, 1 << 20);
8+
let model = FenwickModel::builder(256, 1 << 20).build();
99
let input: Vec<usize> = data.into_iter().copied().map(usize::from).collect();
1010

1111
round_trip::round_trip(model, input);

fuzz/fuzz_targets/round_trip.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,7 @@ where
3030
M: Model,
3131
{
3232
let bitreader = BitReader::endian(buffer, BigEndian);
33-
let mut decoder = Decoder::new(model, bitreader).expect("failed to initialise decoder");
34-
let mut output = Vec::new();
33+
let mut decoder = Decoder::new(model, bitreader).unwrap();
3534

36-
while let Some(symbol) = decoder.decode_symbol().expect("failed to encode symbol!") {
37-
output.push(symbol);
38-
}
39-
output
35+
decoder.decode_all().map(Result::unwrap).collect()
4036
}

src/decoder.rs

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::io;
22

33
use bitstream_io::BitRead;
44

5-
use crate::{BitStore, Error, Model};
5+
use crate::{BitStore, Model};
66

77
// this algorithm is derived from this article - https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
88

@@ -22,6 +22,7 @@ where
2222
high: M::B,
2323
input: R,
2424
x: M::B,
25+
uninitialised: bool,
2526
}
2627

2728
trait BitReadExt {
@@ -49,10 +50,6 @@ where
4950
/// needed to represent the [`Model::denominator`]. 'precision' bits is
5051
/// equal to [`u32::BITS`] - [`Model::denominator`] bits.
5152
///
52-
/// # Errors
53-
///
54-
/// This method can fail if the underlying [`BitRead`] cannot be read from.
55-
///
5653
/// # Panics
5754
///
5855
/// The calculation of the number of bits used for 'precision' is subject to
@@ -64,7 +61,7 @@ where
6461
///
6562
/// If these constraints cannot be satisfied this method will panic in debug
6663
/// builds
67-
pub fn new(model: M, input: R) -> io::Result<Self> {
64+
pub fn new(model: M, input: R) -> Self {
6865
let frequency_bits = model.max_denominator().log2() + 1;
6966
let precision = M::B::BITS - frequency_bits;
7067

@@ -73,10 +70,6 @@ where
7370

7471
/// Construct a new [`Decoder`] with a custom precision
7572
///
76-
/// # Errors
77-
///
78-
/// This method can fail if the underlying [`BitRead`] cannot be read from.
79-
///
8073
/// # Panics
8174
///
8275
/// The calculation of the number of bits used for 'precision' is subject to
@@ -88,7 +81,7 @@ where
8881
///
8982
/// If these constraints cannot be satisfied this method will panic in debug
9083
/// builds
91-
pub fn with_precision(model: M, input: R, precision: u32) -> io::Result<Self> {
84+
pub fn with_precision(model: M, input: R, precision: u32) -> Self {
9285
let frequency_bits = model.max_denominator().log2() + 1;
9386
debug_assert!(
9487
(precision >= (frequency_bits + 2)),
@@ -103,17 +96,15 @@ where
10396
let high = M::B::ONE << precision;
10497
let x = M::B::ZERO;
10598

106-
let mut encoder = Self {
99+
Self {
107100
model,
108101
precision,
109102
low,
110103
high,
111104
input,
112105
x,
113-
};
114-
115-
encoder.fill()?;
116-
Ok(encoder)
106+
uninitialised: true,
107+
}
117108
}
118109

119110
fn fill(&mut self) -> io::Result<()> {
@@ -141,14 +132,26 @@ where
141132
self.half() + self.quarter()
142133
}
143134

135+
/// Return an iterator over the decoded symbols.
136+
///
137+
/// The iterator will continue returning symbols until EOF is reached
138+
pub fn decode_all(&mut self) -> DecodeIter<M, R> {
139+
DecodeIter { decoder: self }
140+
}
141+
144142
/// Read the next symbol from the stream of bits
145143
///
146144
/// This method will return `Ok(None)` when EOF is reached.
147145
///
148146
/// # Errors
149147
///
150148
/// This method can fail if the underlying [`BitRead`] cannot be read from.
151-
pub fn decode_symbol(&mut self) -> Result<Option<M::Symbol>, Error<M::ValueError>> {
149+
pub fn decode(&mut self) -> io::Result<Option<M::Symbol>> {
150+
if self.uninitialised {
151+
self.fill()?;
152+
self.uninitialised = false;
153+
}
154+
152155
let range = self.high - self.low + M::B::ONE;
153156
let denominator = self.model.denominator();
154157
debug_assert!(
@@ -161,7 +164,7 @@ where
161164
let p = self
162165
.model
163166
.probability(symbol.as_ref())
164-
.map_err(Error::ValueError)?;
167+
.expect("this should not be able to fail. Check the implementation of the model.");
165168

166169
self.high = self.low + (range * p.end) / denominator - M::B::ONE;
167170
self.low += (range * p.start) / denominator;
@@ -224,6 +227,29 @@ where
224227
high: self.high,
225228
input: self.input,
226229
x: self.x,
230+
uninitialised: self.uninitialised,
227231
}
228232
}
229233
}
234+
235+
/// The iterator returned by the [`Decoder::decode_all`] method
236+
#[derive(Debug)]
237+
pub struct DecodeIter<'a, M, R>
238+
where
239+
M: Model,
240+
R: BitRead,
241+
{
242+
decoder: &'a mut Decoder<M, R>,
243+
}
244+
245+
impl<'a, M, R> Iterator for DecodeIter<'a, M, R>
246+
where
247+
M: Model,
248+
R: BitRead,
249+
{
250+
type Item = io::Result<M::Symbol>;
251+
252+
fn next(&mut self) -> Option<Self::Item> {
253+
self.decoder.decode().transpose()
254+
}
255+
}

0 commit comments

Comments
 (0)