Skip to content

Commit ba16ad6

Browse files
authored
refactor: de-duplicate 'state' code (#57)
1 parent 507f64e commit ba16ad6

File tree

6 files changed

+101
-90
lines changed

6 files changed

+101
-90
lines changed

fenwick-model/src/lib.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,17 @@ impl Weights {
6262
return None;
6363
}
6464

65-
// invariant: low <= our answer < high
6665
// we seek the lowest number i such that prefix_sum(i) > prefix_sum
6766
let mut low = 0;
6867
let mut high = self.len();
69-
debug_assert!(low < high);
70-
debug_assert!(prefix_sum < self.prefix_sum(Some(high - 1)));
68+
// Ensure the search range is valid (low < high)
69+
debug_assert!(low < high, "Invalid search range");
70+
71+
// Verify that prefix_sum is within the valid range of cumulative weights
72+
debug_assert!(
73+
prefix_sum < self.prefix_sum(Some(high - 1)),
74+
"'prefix_sum' is out of bounds"
75+
);
7176
while low + 1 < high {
7277
let i = (low + high - 1) / 2;
7378
if self.prefix_sum(Some(i)) > prefix_sum {
@@ -110,7 +115,7 @@ mod tests {
110115
}
111116

112117
#[test]
113-
#[should_panic]
118+
#[should_panic(expected = "index out of bounds: the len is 4 but the index is 4")]
114119
fn range_out_of_bounds() {
115120
let weights = Weights::new(3);
116121
weights.range(Some(3));
@@ -126,7 +131,7 @@ mod tests {
126131
}
127132

128133
#[test]
129-
#[should_panic]
134+
#[should_panic(expected = "'prefix_sum' is out of bounds")]
130135
fn symbol_out_of_bounds() {
131136
let weights = Weights::new(3);
132137
weights.symbol(4);

fenwick-model/src/simple.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,10 @@ impl Model for FenwickModel {
5555
&self,
5656
symbol: Option<&Self::Symbol>,
5757
) -> Result<std::ops::Range<Self::B>, Self::ValueError> {
58-
if let Some(s) = symbol.copied() {
59-
if s >= self.weights.len() {
60-
Err(ValueError(s))
61-
} else {
62-
Ok(self.weights.range(Some(s)))
63-
}
64-
} else {
65-
Ok(self.weights.range(None))
58+
match symbol {
59+
None => Ok(self.weights.range(None)),
60+
Some(&s) if s < self.weights.len() => Ok(self.weights.range(Some(s))),
61+
Some(&s) => Err(ValueError(s)),
6662
}
6763
}
6864

src/common.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
use std::ops::Range;
2+
3+
use arithmetic_coding_core::BitStore;
4+
5+
#[derive(Debug)]
6+
pub struct State<B: BitStore> {
7+
pub precision: u32,
8+
pub low: B,
9+
pub high: B,
10+
}
11+
12+
impl<B> State<B>
13+
where
14+
B: BitStore,
15+
{
16+
pub fn new(precision: u32) -> Self {
17+
let low = B::ZERO;
18+
let high = B::ONE << precision;
19+
20+
Self {
21+
precision,
22+
low,
23+
high,
24+
}
25+
}
26+
27+
pub fn half(&self) -> B {
28+
B::ONE << (self.precision - 1)
29+
}
30+
31+
pub fn quarter(&self) -> B {
32+
B::ONE << (self.precision - 2)
33+
}
34+
35+
pub fn three_quarter(&self) -> B {
36+
self.half() + self.quarter()
37+
}
38+
39+
pub fn scale(&mut self, p: Range<B>, denominator: B) {
40+
let range = self.high - self.low + B::ONE;
41+
42+
self.high = self.low + (range * p.end) / denominator - B::ONE;
43+
self.low += (range * p.start) / denominator;
44+
}
45+
}

src/decoder.rs

Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::{io, ops::Range};
44

55
use bitstream_io::BitRead;
66

7-
use crate::{BitStore, Model};
7+
use crate::{common, BitStore, Model};
88

99
// this algorithm is derived from this article - https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
1010

@@ -185,9 +185,7 @@ where
185185
B: BitStore,
186186
R: BitRead,
187187
{
188-
precision: u32,
189-
low: B,
190-
high: B,
188+
state: common::State<B>,
191189
input: R,
192190
x: B,
193191
uninitialised: bool,
@@ -200,54 +198,41 @@ where
200198
{
201199
/// todo
202200
pub fn new(precision: u32, input: R) -> Self {
203-
let low = B::ZERO;
204-
let high = B::ONE << precision;
201+
let state = common::State::new(precision);
205202
let x = B::ZERO;
206203

207204
Self {
208-
precision,
209-
low,
210-
high,
205+
state,
211206
input,
212207
x,
213208
uninitialised: true,
214209
}
215210
}
216211

217-
fn half(&self) -> B {
218-
B::ONE << (self.precision - 1)
219-
}
220-
221-
fn quarter(&self) -> B {
222-
B::ONE << (self.precision - 2)
223-
}
224-
225-
fn three_quarter(&self) -> B {
226-
self.half() + self.quarter()
227-
}
228-
229212
fn normalise(&mut self) -> io::Result<()> {
230-
while self.high < self.half() || self.low >= self.half() {
231-
if self.high < self.half() {
232-
self.high <<= 1;
233-
self.low <<= 1;
213+
while self.state.high < self.state.half() || self.state.low >= self.state.half() {
214+
if self.state.high < self.state.half() {
215+
self.state.high <<= 1;
216+
self.state.low <<= 1;
234217
self.x <<= 1;
235218
} else {
236219
// self.low >= self.half()
237-
self.low = (self.low - self.half()) << 1;
238-
self.high = (self.high - self.half()) << 1;
239-
self.x = (self.x - self.half()) << 1;
220+
self.state.low = (self.state.low - self.state.half()) << 1;
221+
self.state.high = (self.state.high - self.state.half()) << 1;
222+
self.x = (self.x - self.state.half()) << 1;
240223
}
241224

242225
if self.input.next_bit()? == Some(true) {
243226
self.x += B::ONE;
244227
}
245228
}
246229

247-
while self.low >= self.quarter() && self.high < (self.three_quarter()) {
248-
self.low = (self.low - self.quarter()) << 1;
249-
self.high = (self.high - self.quarter()) << 1;
250-
self.x = (self.x - self.quarter()) << 1;
230+
while self.state.low >= self.state.quarter()
231+
&& self.state.high < (self.state.three_quarter())
232+
{
233+
self.state.low = (self.state.low - self.state.quarter()) << 1;
234+
self.state.high = (self.state.high - self.state.quarter()) << 1;
235+
self.x = (self.x - self.state.quarter()) << 1;
251236

252237
if self.input.next_bit()? == Some(true) {
253238
self.x += B::ONE;
@@ -258,21 +243,17 @@ where
258243
}
259244

260245
fn scale(&mut self, p: Range<B>, denominator: B) -> io::Result<()> {
261-
let range = self.high - self.low + B::ONE;
262-
263-
self.high = self.low + (range * p.end) / denominator - B::ONE;
264-
self.low += (range * p.start) / denominator;
265-
246+
self.state.scale(p, denominator);
266247
self.normalise()
267248
}
268249

269250
fn value(&self, denominator: B) -> B {
270-
let range = self.high - self.low + B::ONE;
271-
((self.x - self.low + B::ONE) * denominator - B::ONE) / range
251+
let range = self.state.high - self.state.low + B::ONE;
252+
((self.x - self.state.low + B::ONE) * denominator - B::ONE) / range
272253
}
273254

274255
fn fill(&mut self) -> io::Result<()> {
275-
for _ in 0..self.precision {
256+
for _ in 0..self.state.precision {
276257
self.x <<= 1;
277258
if self.input.next_bit()? == Some(true) {
278259
self.x += B::ONE;

src/encoder.rs

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::{io, ops::Range};
44

55
use bitstream_io::BitWrite;
66

7-
use crate::{BitStore, Error, Model};
7+
use crate::{common, BitStore, Error, Model};
88

99
// this algorithm is derived from this article - https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
1010

@@ -176,9 +176,7 @@ where
176176
B: BitStore,
177177
W: BitWrite,
178178
{
179-
precision: u32,
180-
low: B,
181-
high: B,
179+
state: common::State<B>,
182180
pending: u32,
183181
output: &'a mut W,
184182
}
@@ -193,57 +191,40 @@ where
193191
/// Normally this would be done automatically using the [`Encoder::new`]
194192
/// method.
195193
pub fn new(precision: u32, output: &'a mut W) -> Self {
196-
let low = B::ZERO;
197-
let high = B::ONE << precision;
194+
let state = common::State::new(precision);
198195
let pending = 0;
199196

200197
Self {
201-
precision,
202-
low,
203-
high,
198+
state,
204199
pending,
205200
output,
206201
}
207202
}
208203

209-
fn three_quarter(&self) -> B {
210-
self.half() + self.quarter()
211-
}
212-
213-
fn half(&self) -> B {
214-
B::ONE << (self.precision - 1)
215-
}
216-
217-
fn quarter(&self) -> B {
218-
B::ONE << (self.precision - 2)
219-
}
220-
221204
fn scale(&mut self, p: Range<B>, denominator: B) -> io::Result<()> {
222-
let range = self.high - self.low + B::ONE;
223-
224-
self.high = self.low + (range * p.end) / denominator - B::ONE;
225-
self.low += (range * p.start) / denominator;
226-
205+
self.state.scale(p, denominator);
227206
self.normalise()
228207
}
229208

230209
fn normalise(&mut self) -> io::Result<()> {
231-
while self.high < self.half() || self.low >= self.half() {
232-
if self.high < self.half() {
210+
while self.state.high < self.state.half() || self.state.low >= self.state.half() {
211+
if self.state.high < self.state.half() {
233212
self.emit(false)?;
234-
self.high <<= 1;
235-
self.low <<= 1;
213+
self.state.high <<= 1;
214+
self.state.low <<= 1;
236215
} else {
237216
self.emit(true)?;
238-
self.low = (self.low - self.half()) << 1;
239-
self.high = (self.high - self.half()) << 1;
217+
self.state.low = (self.state.low - self.state.half()) << 1;
218+
self.state.high = (self.state.high - self.state.half()) << 1;
240219
}
241220
}
242221

243-
while self.low >= self.quarter() && self.high < (self.three_quarter()) {
222+
while self.state.low >= self.state.quarter()
223+
&& self.state.high < (self.state.three_quarter())
224+
{
244225
self.pending += 1;
245-
self.low = (self.low - self.quarter()) << 1;
246-
self.high = (self.high - self.quarter()) << 1;
226+
self.state.low = (self.state.low - self.state.quarter()) << 1;
227+
self.state.high = (self.state.high - self.state.quarter()) << 1;
247228
}
248229

249230
Ok(())
@@ -258,14 +239,16 @@ where
258239
Ok(())
259240
}
260241

261-
/// Flush the internal buffer and write all remaining bits to the output
242+
/// Flush the internal buffer and write all remaining bits to the output.
243+
/// This method MUST be called when you finish writing symbols to ensure
244+
/// they are fully written to the output.
262245
///
263246
/// # Errors
264247
///
265248
/// This method can fail if the output cannot be written to
266249
pub fn flush(&mut self) -> io::Result<()> {
267250
self.pending += 1;
268-
if self.low <= self.quarter() {
251+
if self.state.low <= self.state.quarter() {
269252
self.emit(false)?;
270253
} else {
271254
self.emit(true)?;

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
pub use arithmetic_coding_core::{fixed_length, max_length, one_shot, BitStore, Model};
66

7+
mod common;
78
pub mod decoder;
89
pub mod encoder;
910

0 commit comments

Comments
 (0)