Skip to content

Commit ad0ccb0

Browse files
danieleadesDCNick3
andauthored
feat: assert the chosen precision is valid when calling 'chain' in debug builds (#80)
Co-authored-by: ⭐️NINIKA⭐️ <[email protected]>
1 parent dd63e36 commit ad0ccb0

File tree

4 files changed

+166
-38
lines changed

4 files changed

+166
-38
lines changed

src/common.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ use std::ops::Range;
22

33
use arithmetic_coding_core::BitStore;
44

5+
use crate::Model;
6+
57
#[derive(Debug)]
68
pub struct State<B: BitStore> {
79
pub precision: u32,
@@ -43,3 +45,15 @@ where
4345
self.low += (range * p.start) / denominator;
4446
}
4547
}
48+
49+
pub fn assert_precision_sufficient<M: Model>(max_denominator: M::B, precision: u32) {
50+
let frequency_bits = max_denominator.log2() + 1;
51+
assert!(
52+
(precision >= (frequency_bits + 2)),
53+
"not enough bits of precision to prevent overflow/underflow",
54+
);
55+
assert!(
56+
(frequency_bits + precision) <= M::B::BITS,
57+
"not enough bits in BitStore to support the required precision",
58+
);
59+
}

src/decoder.rs

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ use std::{io, ops::Range};
44

55
use bitstream_io::BitRead;
66

7-
use crate::{common, BitStore, Model};
7+
use crate::{
8+
common::{self, assert_precision_sufficient},
9+
BitStore, Model,
10+
};
811

912
// this algorithm is derived from this article - https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
1013

@@ -79,23 +82,15 @@ where
7982
/// If these constraints cannot be satisfied this method will panic in debug
8083
/// builds
8184
pub fn with_precision(model: M, input: R, precision: u32) -> Self {
82-
let frequency_bits = model.max_denominator().log2() + 1;
83-
debug_assert!(
84-
(precision >= (frequency_bits + 2)),
85-
"not enough bits of precision to prevent overflow/underflow",
86-
);
87-
debug_assert!(
88-
(frequency_bits + precision) <= M::B::BITS,
89-
"not enough bits in BitStore to support the required precision",
90-
);
91-
9285
let state = State::new(precision, input);
93-
94-
Self { model, state }
86+
Self::with_state(state, model)
9587
}
9688

9789
/// todo
98-
pub const fn with_state(state: State<M::B, R>, model: M) -> Self {
90+
pub fn with_state(state: State<M::B, R>, model: M) -> Self {
91+
#[cfg(debug_assertions)]
92+
assert_precision_sufficient::<M>(model.max_denominator(), state.state.precision);
93+
9994
Self { model, state }
10095
}
10196

@@ -144,10 +139,7 @@ where
144139
where
145140
X: Model<B = M::B>,
146141
{
147-
Decoder {
148-
model,
149-
state: self.state,
150-
}
142+
Decoder::with_state(self.state, model)
151143
}
152144

153145
/// todo

src/encoder.rs

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ use std::{io, ops::Range};
44

55
use bitstream_io::BitWrite;
66

7-
use crate::{common, BitStore, Error, Model};
7+
use crate::{
8+
common::{self, assert_precision_sufficient},
9+
BitStore, Error, Model,
10+
};
811

912
// this algorithm is derived from this article - https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
1013

@@ -65,27 +68,17 @@ where
6568
/// If these constraints cannot be satisfied this method will panic in debug
6669
/// builds
6770
pub fn with_precision(model: M, bitwriter: &'a mut W, precision: u32) -> Self {
68-
let frequency_bits = model.max_denominator().log2() + 1;
69-
debug_assert!(
70-
(precision >= (frequency_bits + 2)),
71-
"not enough bits of precision to prevent overflow/underflow",
72-
);
73-
debug_assert!(
74-
(frequency_bits + precision) <= M::B::BITS,
75-
"not enough bits in BitStore to support the required precision",
76-
);
77-
78-
Self {
79-
model,
80-
state: State::new(precision, bitwriter),
81-
}
71+
let state = State::new(precision, bitwriter);
72+
Self::with_state(state, model)
8273
}
8374

8475
/// Create an encoder from an existing [`State`].
8576
///
8677
/// This is useful for manually chaining a shared buffer through multiple
8778
/// encoders.
88-
pub const fn with_state(state: State<'a, M::B, W>, model: M) -> Self {
79+
pub fn with_state(state: State<'a, M::B, W>, model: M) -> Self {
80+
#[cfg(debug_assertions)]
81+
assert_precision_sufficient::<M>(model.max_denominator(), state.state.precision);
8982
Self { model, state }
9083
}
9184

@@ -162,10 +155,7 @@ where
162155
where
163156
X: Model<B = M::B>,
164157
{
165-
Encoder {
166-
model,
167-
state: self.state,
168-
}
158+
Encoder::with_state(self.state, model)
169159
}
170160
}
171161

tests/precision_checking.rs

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// these tests check the asserts that are only present in debug configurations
2+
// so they won't pass in release mode
3+
#![cfg(debug_assertions)]
4+
5+
use std::{convert::Infallible, io::Cursor, ops::Range};
6+
7+
use arithmetic_coding::{decoder, encoder, Decoder, Encoder};
8+
use arithmetic_coding_core::one_shot;
9+
use bitstream_io::{BigEndian, BitReader, BitWriter};
10+
11+
#[derive(Copy, Clone)]
12+
struct SmallModel;
13+
impl one_shot::Model for SmallModel {
14+
type B = u64;
15+
type Symbol = u64;
16+
type ValueError = Infallible;
17+
18+
fn probability(&self, &value: &Self::Symbol) -> Result<Range<Self::B>, Self::ValueError> {
19+
#[allow(clippy::range_plus_one)]
20+
Ok(value..value + 1)
21+
}
22+
23+
fn max_denominator(&self) -> Self::B {
24+
2
25+
}
26+
27+
fn symbol(&self, value: Self::B) -> Self::Symbol {
28+
value
29+
}
30+
}
31+
32+
#[derive(Copy, Clone)]
33+
struct BigModel;
34+
impl one_shot::Model for BigModel {
35+
type B = u64;
36+
type Symbol = u64;
37+
type ValueError = Infallible;
38+
39+
fn probability(&self, &value: &Self::Symbol) -> Result<Range<Self::B>, Self::ValueError> {
40+
#[allow(clippy::range_plus_one)]
41+
Ok(value..value + 1)
42+
}
43+
44+
fn max_denominator(&self) -> Self::B {
45+
u64::from(u32::MAX) / 2
46+
}
47+
48+
fn symbol(&self, value: Self::B) -> Self::Symbol {
49+
value
50+
}
51+
}
52+
53+
// this is one bit short of what it must be
54+
const PRECISION: u32 = 32;
55+
56+
// Encoder::new should select the correct precision automagically, so we don't
57+
// expect it to panic
58+
#[test]
59+
fn encoder_new_doesnt_panic() {
60+
Encoder::new(
61+
one_shot::Wrapper::new(BigModel),
62+
&mut BitWriter::endian(Vec::new(), BigEndian),
63+
);
64+
}
65+
66+
#[test]
67+
#[should_panic(expected = "not enough bits of precision to prevent overflow/underflow")]
68+
fn encoder_with_precision_panics() {
69+
Encoder::with_precision(
70+
one_shot::Wrapper::new(BigModel),
71+
&mut BitWriter::endian(Vec::new(), BigEndian),
72+
PRECISION,
73+
);
74+
}
75+
76+
#[test]
77+
#[should_panic(expected = "not enough bits of precision to prevent overflow/underflow")]
78+
fn encoder_with_state_panics() {
79+
Encoder::with_state(
80+
encoder::State::new(PRECISION, &mut BitWriter::endian(Vec::new(), BigEndian)),
81+
one_shot::Wrapper::new(BigModel),
82+
);
83+
}
84+
85+
#[test]
86+
#[should_panic(expected = "not enough bits of precision to prevent overflow/underflow")]
87+
fn encoder_chain_panics() {
88+
let mut writer = BitWriter::endian(Vec::new(), BigEndian);
89+
let encoder =
90+
Encoder::with_precision(one_shot::Wrapper::new(SmallModel), &mut writer, PRECISION);
91+
92+
encoder.chain(one_shot::Wrapper::new(BigModel));
93+
}
94+
95+
#[test]
96+
fn decoder_new_doesnt_panic() {
97+
Decoder::new(
98+
one_shot::Wrapper::new(BigModel),
99+
BitReader::endian(Cursor::new(&[]), BigEndian),
100+
);
101+
}
102+
103+
#[test]
104+
#[should_panic(expected = "not enough bits of precision to prevent overflow/underflow")]
105+
fn decoder_with_precision_panics() {
106+
Decoder::with_precision(
107+
one_shot::Wrapper::new(BigModel),
108+
BitReader::endian(Cursor::new(&[]), BigEndian),
109+
PRECISION,
110+
);
111+
}
112+
113+
#[test]
114+
#[should_panic(expected = "not enough bits of precision to prevent overflow/underflow")]
115+
fn decoder_with_state_panics() {
116+
Decoder::with_state(
117+
decoder::State::new(PRECISION, BitReader::endian(Cursor::new(&[]), BigEndian)),
118+
one_shot::Wrapper::new(BigModel),
119+
);
120+
}
121+
122+
#[test]
123+
#[should_panic(expected = "not enough bits of precision to prevent overflow/underflow")]
124+
fn decoder_chain_panics() {
125+
let decoder = Decoder::with_precision(
126+
one_shot::Wrapper::new(SmallModel),
127+
BitReader::endian(Cursor::new(&[]), BigEndian),
128+
PRECISION,
129+
);
130+
131+
decoder.chain(one_shot::Wrapper::new(BigModel));
132+
}

0 commit comments

Comments
 (0)