Skip to content

Commit 6e58cc2

Browse files
authored
General improvements (#36)
1 parent c6ada58 commit 6e58cc2

File tree

14 files changed

+717
-158
lines changed

14 files changed

+717
-158
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ thiserror = "1.0.30"
2323
[dev-dependencies]
2424
fenwick-model = { path = "./fenwick-model" }
2525
criterion = "0.3.5"
26+
test-case = "2.0.2"
2627

2728
[[bench]]
2829
name = "sherlock"

arithmetic-coding-core/src/lib.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,4 @@ mod bitstore;
1414
pub use bitstore::BitStore;
1515

1616
mod model;
17-
pub use model::Model;
18-
19-
pub mod fixed_length;
17+
pub use model::{fixed_length, max_length, one_shot, Model};

arithmetic-coding-core/src/model.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ use std::{error::Error, ops::Range};
22

33
use crate::BitStore;
44

5+
pub mod fixed_length;
6+
pub mod max_length;
7+
pub mod one_shot;
8+
59
/// A [`Model`] is used to calculate the probability of a given symbol occuring
610
/// in a sequence. The [`Model`] is used both for encoding and decoding.
711
///
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
//! Helper trait for creating fixed-length Models
2+
3+
use std::ops::Range;
4+
5+
use crate::BitStore;
6+
7+
/// A [`Model`] is used to calculate the probability of a given symbol occuring
8+
/// in a sequence. The [`Model`] is used both for encoding and decoding. A
9+
/// 'max-length' model has a maximum length. The compressed size of a message
10+
/// equal to the maximum length is larger than with a
11+
/// [`fixed_length::Model`](crate::fixed_length::Model), but smaller than with a
12+
/// [`Model`](crate::Model).
13+
///
14+
/// A max-length model can be converted into a regular model using the
15+
/// convenience [`Wrapper`] type.
16+
///
17+
/// The more accurately a [`Model`] is able to predict the next symbol, the
18+
/// greater the compression ratio will be.
19+
///
20+
/// # Example
21+
///
22+
/// ```
23+
/// #![feature(exclusive_range_pattern)]
24+
/// #![feature(never_type)]
25+
/// # use std::ops::Range;
26+
/// #
27+
/// # use arithmetic_coding_core::max_length;
28+
///
29+
/// pub enum Symbol {
30+
/// A,
31+
/// B,
32+
/// C,
33+
/// }
34+
///
35+
/// pub struct MyModel;
36+
///
37+
/// impl max_length::Model for MyModel {
38+
/// type Symbol = Symbol;
39+
/// type ValueError = !;
40+
///
41+
/// fn probability(&self, symbol: Option<&Self::Symbol>) -> Result<Range<u32>, !> {
42+
/// Ok(match symbol {
43+
/// Some(Symbol::A) => 0..1,
44+
/// Some(Symbol::B) => 1..2,
45+
/// Some(Symbol::C) => 2..3,
46+
/// None => 3..4,
47+
/// })
48+
/// }
49+
///
50+
/// fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
51+
/// match value {
52+
/// 0..1 => Some(Symbol::A),
53+
/// 1..2 => Some(Symbol::B),
54+
/// 2..3 => Some(Symbol::C),
55+
/// 3..4 => None,
56+
/// _ => unreachable!(),
57+
/// }
58+
/// }
59+
///
60+
/// fn max_denominator(&self) -> u32 {
61+
/// 4
62+
/// }
63+
///
64+
/// fn max_length(&self) -> usize {
65+
/// 3
66+
/// }
67+
/// }
68+
/// ```
69+
pub trait Model {
70+
/// The type of symbol this [`Model`] describes
71+
type Symbol;
72+
73+
/// Invalid symbol error
74+
type ValueError: std::error::Error;
75+
76+
/// The internal representation to use for storing integers
77+
type B: BitStore = u32;
78+
79+
/// Given a symbol, return an interval representing the probability of that
80+
/// symbol occurring.
81+
///
82+
/// This is given as a range, over the denominator given by
83+
/// [`Model::denominator`]. This range should in general include `EOF`,
84+
/// which is denoted by `None`.
85+
///
86+
/// For example, from the set {heads, tails}, the interval representing
87+
/// heads could be `0..1`, and tails would be `1..2`, and `EOF` could be
88+
/// `2..3` (with a denominator of `3`).
89+
///
90+
/// This is the inverse of the [`Model::symbol`] method
91+
///
92+
/// # Errors
93+
///
94+
/// This returns a custom error if the given symbol is not valid
95+
fn probability(
96+
&self,
97+
symbol: Option<&Self::Symbol>,
98+
) -> Result<Range<Self::B>, Self::ValueError>;
99+
100+
/// The denominator for probability ranges. See [`Model::probability`].
101+
///
102+
/// By default this method simply returns the [`Model::max_denominator`],
103+
/// which is suitable for non-adaptive models.
104+
///
105+
/// In adaptive models this value may change, however it should never exceed
106+
/// [`Model::max_denominator`], or it becomes possible for the
107+
/// [`Encoder`](crate::Encoder) and [`Decoder`](crate::Decoder) to panic due
108+
/// to overflow or underflow.
109+
fn denominator(&self) -> Self::B {
110+
self.max_denominator()
111+
}
112+
113+
/// The maximum denominator used for probability ranges. See
114+
/// [`Model::probability`].
115+
///
116+
/// This value is used to calculate an appropriate precision for the
117+
/// encoding, therefore this value must not change, and
118+
/// [`Model::denominator`] must never exceed it.
119+
fn max_denominator(&self) -> Self::B;
120+
121+
/// Given a value, return the symbol whose probability range it falls in.
122+
///
123+
/// `None` indicates `EOF`
124+
///
125+
/// This is the inverse of the [`Model::probability`] method
126+
fn symbol(&self, value: Self::B) -> Option<Self::Symbol>;
127+
128+
/// Update the current state of the model with the latest symbol.
129+
///
130+
/// This method only needs to be implemented for 'adaptive' models. It's a
131+
/// no-op by default.
132+
fn update(&mut self, _symbol: &Self::Symbol) {}
133+
134+
/// The maximum number of symbols to encode
135+
fn max_length(&self) -> usize;
136+
}
137+
138+
/// A wrapper which converts a [`fixed_length::Model`](Model) to a
139+
/// [`crate::Model`].
140+
#[derive(Debug, Clone)]
141+
pub struct Wrapper<M>
142+
where
143+
M: Model,
144+
{
145+
model: M,
146+
remaining: usize,
147+
}
148+
149+
impl<M> Wrapper<M>
150+
where
151+
M: Model,
152+
{
153+
/// Construct a new wrapper from a [`fixed_length::Model`](Model)
154+
pub fn new(model: M) -> Self {
155+
let remaining = model.max_length();
156+
Self { model, remaining }
157+
}
158+
}
159+
160+
impl<M> crate::Model for Wrapper<M>
161+
where
162+
M: Model,
163+
{
164+
type B = M::B;
165+
type Symbol = M::Symbol;
166+
type ValueError = Error<M::ValueError>;
167+
168+
fn probability(
169+
&self,
170+
symbol: Option<&Self::Symbol>,
171+
) -> Result<Range<Self::B>, Self::ValueError> {
172+
if self.remaining == 0 {
173+
if symbol.is_some() {
174+
Err(Error::UnexpectedSymbol)
175+
} else {
176+
// got an EOF when we expected it, return a 100% probability
177+
Ok(Self::B::ZERO..self.denominator())
178+
}
179+
} else {
180+
self.model
181+
.probability(symbol)
182+
.map_err(Self::ValueError::Value)
183+
}
184+
}
185+
186+
fn max_denominator(&self) -> Self::B {
187+
self.model.max_denominator()
188+
}
189+
190+
fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
191+
if self.remaining > 0 {
192+
self.model.symbol(value)
193+
} else {
194+
None
195+
}
196+
}
197+
198+
fn denominator(&self) -> Self::B {
199+
self.model.denominator()
200+
}
201+
202+
fn update(&mut self, symbol: Option<&Self::Symbol>) {
203+
if let Some(s) = symbol {
204+
self.model.update(s);
205+
self.remaining -= 1;
206+
}
207+
}
208+
}
209+
210+
/// Fixed-length encoding/decoding errors
211+
#[derive(Debug, thiserror::Error)]
212+
pub enum Error<E>
213+
where
214+
E: std::error::Error,
215+
{
216+
/// Model received a symbol when it expected an EOF
217+
#[error("Unexpected Symbol")]
218+
UnexpectedSymbol,
219+
220+
/// The model received an invalid symbol
221+
#[error(transparent)]
222+
Value(E),
223+
}
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
//! Helper trait for creating Models which only accept a single symbol
2+
3+
use std::ops::Range;
4+
5+
pub use crate::fixed_length::Wrapper;
6+
use crate::{fixed_length, BitStore};
7+
8+
/// A [`Model`] is used to calculate the probability of a given symbol occuring
9+
/// in a sequence. The [`Model`] is used both for encoding and decoding. A
10+
/// 'fixed-length' model always expects an exact number of symbols, and so does
11+
/// not need to encode an EOF symbol.
12+
///
13+
/// A fixed length model can be converted into a regular model using the
14+
/// convenience [`Wrapper`] type.
15+
///
16+
/// The more accurately a [`Model`] is able to predict the next symbol, the
17+
/// greater the compression ratio will be.
18+
///
19+
/// # Example
20+
///
21+
/// ```
22+
/// #![feature(exclusive_range_pattern)]
23+
/// #![feature(never_type)]
24+
/// # use std::ops::Range;
25+
/// #
26+
/// # use arithmetic_coding_core::one_shot;
27+
///
28+
/// pub enum Symbol {
29+
/// A,
30+
/// B,
31+
/// C,
32+
/// }
33+
///
34+
/// pub struct MyModel;
35+
///
36+
/// impl one_shot::Model for MyModel {
37+
/// type Symbol = Symbol;
38+
/// type ValueError = !;
39+
///
40+
/// fn probability(&self, symbol: &Self::Symbol) -> Result<Range<u32>, !> {
41+
/// Ok(match symbol {
42+
/// Symbol::A => 0..1,
43+
/// Symbol::B => 1..2,
44+
/// Symbol::C => 2..3,
45+
/// })
46+
/// }
47+
///
48+
/// fn symbol(&self, value: Self::B) -> Self::Symbol {
49+
/// match value {
50+
/// 0..1 => Symbol::A,
51+
/// 1..2 => Symbol::B,
52+
/// 2..3 => Symbol::C,
53+
/// _ => unreachable!(),
54+
/// }
55+
/// }
56+
///
57+
/// fn max_denominator(&self) -> u32 {
58+
/// 3
59+
/// }
60+
/// }
61+
/// ```
62+
pub trait Model {
63+
/// The type of symbol this [`Model`] describes
64+
type Symbol;
65+
66+
/// Invalid symbol error
67+
type ValueError: std::error::Error;
68+
69+
/// The internal representation to use for storing integers
70+
type B: BitStore = u32;
71+
72+
/// Given a symbol, return an interval representing the probability of that
73+
/// symbol occurring.
74+
///
75+
/// This is given as a range, over the denominator given by
76+
/// [`Model::denominator`]. This range should in general include `EOF`,
77+
/// which is denoted by `None`.
78+
///
79+
/// For example, from the set {heads, tails}, the interval representing
80+
/// heads could be `0..1`, and tails would be `1..2`, and `EOF` could be
81+
/// `2..3` (with a denominator of `3`).
82+
///
83+
/// This is the inverse of the [`Model::symbol`] method
84+
///
85+
/// # Errors
86+
///
87+
/// This returns a custom error if the given symbol is not valid
88+
fn probability(&self, symbol: &Self::Symbol) -> Result<Range<Self::B>, Self::ValueError>;
89+
90+
/// The maximum denominator used for probability ranges. See
91+
/// [`Model::probability`].
92+
///
93+
/// This value is used to calculate an appropriate precision for the
94+
/// encoding, therefore this value must not change, and
95+
/// [`Model::denominator`] must never exceed it.
96+
fn max_denominator(&self) -> Self::B;
97+
98+
/// Given a value, return the symbol whose probability range it falls in.
99+
///
100+
/// `None` indicates `EOF`
101+
///
102+
/// This is the inverse of the [`Model::probability`] method
103+
fn symbol(&self, value: Self::B) -> Self::Symbol;
104+
}
105+
106+
impl<T> fixed_length::Model for T
107+
where
108+
T: Model,
109+
{
110+
type B = T::B;
111+
type Symbol = T::Symbol;
112+
type ValueError = T::ValueError;
113+
114+
fn probability(&self, symbol: &Self::Symbol) -> Result<Range<Self::B>, Self::ValueError> {
115+
Model::probability(self, symbol)
116+
}
117+
118+
fn max_denominator(&self) -> Self::B {
119+
self.max_denominator()
120+
}
121+
122+
fn symbol(&self, value: Self::B) -> Self::Symbol {
123+
Model::symbol(self, value)
124+
}
125+
126+
fn length(&self) -> usize {
127+
1
128+
}
129+
130+
fn denominator(&self) -> Self::B {
131+
self.max_denominator()
132+
}
133+
}

0 commit comments

Comments
 (0)