Skip to content

Commit dc961cd

Browse files
authored
add adaptive model examples using fenwick trees (#16)
1 parent 38a80e4 commit dc961cd

File tree

9 files changed

+433
-6
lines changed

9 files changed

+433
-6
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ bitstream-io = "1.2.0"
1313
thiserror = "1.0.30"
1414

1515
[dev-dependencies]
16+
fenwick = { version = "1.0.0" }
1617
test-case = "2.0.0"

examples/common/mod.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,30 @@ where
4747
}
4848
output
4949
}
50+
51+
#[allow(unused)]
52+
pub fn round_trip_string<M>(model: M, input: String)
53+
where
54+
M: Model<Symbol = char> + Clone,
55+
{
56+
let input_bytes = input.bytes().len();
57+
58+
let buffer = encode(model.clone(), input.chars());
59+
60+
let output_bytes = buffer.len();
61+
62+
println!("input bytes: {}", input_bytes);
63+
println!("output bytes: {}", output_bytes);
64+
65+
println!(
66+
"compression ratio: {}",
67+
input_bytes as f32 / output_bytes as f32
68+
);
69+
70+
let output = decode(model, &buffer);
71+
72+
let mut prefix: String = output.into_iter().take(299).collect();
73+
prefix.push_str("...");
74+
75+
println!("{}", prefix);
76+
}

examples/fenwick/context_switching.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#![allow(missing_docs, unused)]
2+
//! Fenwick tree based context-switching model
3+
4+
use arithmetic_coding::Model;
5+
6+
use super::Weights;
7+
8+
#[derive(Debug, Clone)]
9+
pub struct FenwickModel {
10+
contexts: Vec<Weights>,
11+
previous_context: usize,
12+
current_context: usize,
13+
denominator: u64,
14+
max_denominator: u64,
15+
}
16+
17+
impl FenwickModel {
18+
#[must_use]
19+
pub fn with_symbols(symbols: usize) -> Self {
20+
let mut contexts = Vec::with_capacity(symbols + 1);
21+
let mut denominator = 0;
22+
let max_denominator = 1 << 17;
23+
24+
for _ in 0..=symbols {
25+
let weight = Weights::new(symbols);
26+
denominator = denominator.max(weight.total());
27+
contexts.push(Weights::new(symbols));
28+
}
29+
30+
Self {
31+
contexts,
32+
previous_context: 1,
33+
current_context: 1,
34+
denominator,
35+
max_denominator,
36+
}
37+
}
38+
39+
fn context(&self) -> &Weights {
40+
&self.contexts[self.current_context]
41+
}
42+
43+
fn context_mut(&mut self) -> &mut Weights {
44+
&mut self.contexts[self.current_context]
45+
}
46+
}
47+
48+
#[derive(Debug, thiserror::Error)]
49+
#[error("invalid symbol received: {0}")]
50+
pub struct ValueError(usize);
51+
52+
impl Model for FenwickModel {
53+
type B = u64;
54+
type Symbol = usize;
55+
type ValueError = ValueError;
56+
57+
fn probability(
58+
&self,
59+
symbol: Option<&Self::Symbol>,
60+
) -> Result<std::ops::Range<Self::B>, Self::ValueError> {
61+
Ok(self.context().range(symbol.copied()))
62+
}
63+
64+
fn max_denominator(&self) -> Self::B {
65+
self.max_denominator
66+
}
67+
68+
fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
69+
self.context().symbol(value)
70+
}
71+
72+
fn denominator(&self) -> Self::B {
73+
self.context().total
74+
}
75+
76+
fn update(&mut self, symbol: Option<&Self::Symbol>) {
77+
debug_assert!(
78+
self.denominator() < self.max_denominator,
79+
"hit max denominator!"
80+
);
81+
if self.denominator() < self.max_denominator {
82+
self.context_mut().update(symbol.copied(), 1);
83+
self.denominator = self.denominator.max(self.context().total());
84+
}
85+
self.current_context = symbol.map(|x| x + 1).unwrap_or_default();
86+
}
87+
}

examples/fenwick/mod.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//! [`Models`](crate::Model) implemented using Fenwick trees
2+
3+
use std::ops::Range;
4+
5+
pub mod context_switching;
6+
pub mod simple;
7+
8+
/// A wrapper around a vector of fenwick counts, with one additional weight for
9+
/// EOF.
10+
#[derive(Debug, Clone)]
11+
struct Weights {
12+
fenwick_counts: Vec<u64>,
13+
total: u64,
14+
}
15+
16+
impl Weights {
17+
fn new(n: usize) -> Self {
18+
// we add one extra value here to account for the EOF
19+
let mut fenwick_counts = vec![0; n + 1];
20+
21+
for i in 0..fenwick_counts.len() {
22+
fenwick::array::update(&mut fenwick_counts, i, 1);
23+
}
24+
25+
let total = fenwick_counts.len() as u64;
26+
Self {
27+
fenwick_counts,
28+
total,
29+
}
30+
}
31+
32+
fn update(&mut self, i: Option<usize>, delta: u64) {
33+
let index = i.map(|i| i + 1).unwrap_or_default();
34+
fenwick::array::update(&mut self.fenwick_counts, index, delta);
35+
self.total += delta;
36+
}
37+
38+
fn prefix_sum(&self, i: Option<usize>) -> u64 {
39+
let index = i.map(|i| i + 1).unwrap_or_default();
40+
fenwick::array::prefix_sum(&self.fenwick_counts, index)
41+
}
42+
43+
fn range(&self, i: Option<usize>) -> Range<u64> {
44+
let index = i.map(|i| i + 1).unwrap_or_default();
45+
46+
let upper = fenwick::array::prefix_sum(&self.fenwick_counts, index);
47+
48+
let lower = if index == 0 {
49+
0
50+
} else {
51+
fenwick::array::prefix_sum(&self.fenwick_counts, index - 1)
52+
};
53+
lower..upper
54+
}
55+
56+
fn len(&self) -> usize {
57+
self.fenwick_counts.len() - 1
58+
}
59+
60+
fn symbol(&self, prefix_sum: u64) -> Option<usize> {
61+
if prefix_sum < self.prefix_sum(None) {
62+
return None;
63+
}
64+
65+
for i in 0..self.len() {
66+
if prefix_sum < self.prefix_sum(Some(i)) {
67+
return Some(i);
68+
}
69+
}
70+
71+
unreachable!()
72+
}
73+
74+
fn total(&self) -> u64 {
75+
self.total
76+
}
77+
}

examples/fenwick/simple.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#![allow(missing_docs, unused)]
2+
//! Simple adaptive model using a Fenwick tree
3+
4+
use arithmetic_coding::Model;
5+
6+
use super::Weights;
7+
8+
#[derive(Debug, Clone)]
9+
pub struct FenwickModel {
10+
weights: Weights,
11+
max_denominator: u64,
12+
}
13+
14+
impl FenwickModel {
15+
#[must_use]
16+
pub fn with_symbols(symbols: usize) -> Self {
17+
let weights = Weights::new(symbols);
18+
19+
Self {
20+
weights,
21+
max_denominator: 1 << 17,
22+
}
23+
}
24+
}
25+
26+
#[derive(Debug, thiserror::Error)]
27+
#[error("invalid symbol received: {0}")]
28+
pub struct ValueError(pub usize);
29+
30+
impl Model for FenwickModel {
31+
type B = u64;
32+
type Symbol = usize;
33+
type ValueError = ValueError;
34+
35+
fn probability(
36+
&self,
37+
symbol: Option<&Self::Symbol>,
38+
) -> Result<std::ops::Range<Self::B>, Self::ValueError> {
39+
if let Some(s) = symbol.copied() {
40+
if s >= self.weights.len() {
41+
Err(ValueError(s))
42+
} else {
43+
Ok(self.weights.range(Some(s)))
44+
}
45+
} else {
46+
Ok(self.weights.range(None))
47+
}
48+
}
49+
50+
fn max_denominator(&self) -> Self::B {
51+
self.max_denominator
52+
}
53+
54+
fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
55+
self.weights.symbol(value)
56+
}
57+
58+
fn denominator(&self) -> Self::B {
59+
self.weights.total()
60+
}
61+
62+
fn update(&mut self, symbol: Option<&Self::Symbol>) {
63+
debug_assert!(
64+
self.denominator() < self.max_denominator,
65+
"hit max denominator!"
66+
);
67+
if self.denominator() < self.max_denominator {
68+
self.weights.update(symbol.copied(), 1);
69+
}
70+
}
71+
}

examples/fenwick_adaptive.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
use std::{fs::File, io::Read, ops::Range};
2+
3+
use arithmetic_coding::Model;
4+
5+
mod common;
6+
mod fenwick;
7+
8+
use self::fenwick::simple::{FenwickModel, ValueError};
9+
10+
const ALPHABET: &str =
11+
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,\n-':()[]#*;\"!?*&é/àâè%@$";
12+
13+
#[derive(Debug, Clone)]
14+
pub struct StringModel {
15+
alphabet: Vec<char>,
16+
fenwick_model: FenwickModel,
17+
}
18+
19+
impl StringModel {
20+
#[must_use]
21+
pub fn new(alphabet: Vec<char>) -> Self {
22+
let fenwick_model = FenwickModel::with_symbols(alphabet.len());
23+
Self {
24+
alphabet,
25+
fenwick_model,
26+
}
27+
}
28+
}
29+
30+
#[derive(Debug, thiserror::Error)]
31+
#[error("invalid character: {0}")]
32+
pub struct Error(char);
33+
34+
impl Model for StringModel {
35+
type B = u64;
36+
type Symbol = char;
37+
type ValueError = ValueError;
38+
39+
fn probability(
40+
&self,
41+
symbol: Option<&Self::Symbol>,
42+
) -> Result<Range<Self::B>, Self::ValueError> {
43+
let fenwick_symbol = symbol.map(|c| self.alphabet.iter().position(|x| x == c).unwrap());
44+
self.fenwick_model.probability(fenwick_symbol.as_ref())
45+
}
46+
47+
fn symbol(&self, value: Self::B) -> Option<Self::Symbol> {
48+
let index = self.fenwick_model.symbol(value)?;
49+
self.alphabet.get(index).copied()
50+
}
51+
52+
fn max_denominator(&self) -> Self::B {
53+
self.fenwick_model.max_denominator()
54+
}
55+
56+
fn denominator(&self) -> Self::B {
57+
self.fenwick_model.denominator()
58+
}
59+
60+
fn update(&mut self, symbol: Option<&Self::Symbol>) {
61+
let fenwick_symbol = symbol.map(|c| self.alphabet.iter().position(|x| x == c).unwrap());
62+
self.fenwick_model.update(fenwick_symbol.as_ref())
63+
}
64+
}
65+
66+
fn main() {
67+
let model = StringModel::new(ALPHABET.chars().collect());
68+
69+
let mut input = String::new();
70+
File::open("./resources/sherlock.txt")
71+
.unwrap()
72+
.read_to_string(&mut input)
73+
.unwrap();
74+
75+
common::round_trip_string(model, input);
76+
}

0 commit comments

Comments
 (0)