|
| 1 | +//! Helper trait for creating max-length Models |
| 2 | +
|
| 3 | +use std::ops::Range; |
| 4 | + |
| 5 | +use crate::BitStore; |
| 6 | + |
| 7 | +/// A [`Model`] is used to calculate the probability of a given symbol occuring |
| 8 | +/// in a sequence. The [`Model`] is used both for encoding and decoding. A |
| 9 | +/// 'max-length' model has a maximum length. The compressed size of a message |
| 10 | +/// equal to the maximum length is larger than with a |
| 11 | +/// [`fixed_length::Model`](crate::fixed_length::Model), but smaller than with a |
| 12 | +/// [`Model`](crate::Model). |
| 13 | +/// |
| 14 | +/// A max-length model can be converted into a regular model using the |
| 15 | +/// convenience [`Wrapper`] type. |
| 16 | +/// |
| 17 | +/// The more accurately a [`Model`] is able to predict the next symbol, the |
| 18 | +/// greater the compression ratio will be. |
| 19 | +/// |
| 20 | +/// # Example |
| 21 | +/// |
| 22 | +/// ``` |
| 23 | +/// #![feature(exclusive_range_pattern)] |
| 24 | +/// #![feature(never_type)] |
| 25 | +/// # use std::ops::Range; |
| 26 | +/// # |
| 27 | +/// # use arithmetic_coding_core::max_length; |
| 28 | +/// |
| 29 | +/// pub enum Symbol { |
| 30 | +/// A, |
| 31 | +/// B, |
| 32 | +/// C, |
| 33 | +/// } |
| 34 | +/// |
| 35 | +/// pub struct MyModel; |
| 36 | +/// |
| 37 | +/// impl max_length::Model for MyModel { |
| 38 | +/// type Symbol = Symbol; |
| 39 | +/// type ValueError = !; |
| 40 | +/// |
| 41 | +/// fn probability(&self, symbol: Option<&Self::Symbol>) -> Result<Range<u32>, !> { |
| 42 | +/// Ok(match symbol { |
| 43 | +/// Some(Symbol::A) => 0..1, |
| 44 | +/// Some(Symbol::B) => 1..2, |
| 45 | +/// Some(Symbol::C) => 2..3, |
| 46 | +/// None => 3..4, |
| 47 | +/// }) |
| 48 | +/// } |
| 49 | +/// |
| 50 | +/// fn symbol(&self, value: Self::B) -> Option<Self::Symbol> { |
| 51 | +/// match value { |
| 52 | +/// 0..1 => Some(Symbol::A), |
| 53 | +/// 1..2 => Some(Symbol::B), |
| 54 | +/// 2..3 => Some(Symbol::C), |
| 55 | +/// 3..4 => None, |
| 56 | +/// _ => unreachable!(), |
| 57 | +/// } |
| 58 | +/// } |
| 59 | +/// |
| 60 | +/// fn max_denominator(&self) -> u32 { |
| 61 | +/// 4 |
| 62 | +/// } |
| 63 | +/// |
| 64 | +/// fn max_length(&self) -> usize { |
| 65 | +/// 3 |
| 66 | +/// } |
| 67 | +/// } |
| 68 | +/// ``` |
| 69 | +pub trait Model { |
| 70 | + /// The type of symbol this [`Model`] describes |
| 71 | + type Symbol; |
| 72 | + |
| 73 | + /// Invalid symbol error |
| 74 | + type ValueError: std::error::Error; |
| 75 | + |
| 76 | + /// The internal representation to use for storing integers |
| 77 | + type B: BitStore = u32; |
| 78 | + |
| 79 | + /// Given a symbol, return an interval representing the probability of that |
| 80 | + /// symbol occurring. |
| 81 | + /// |
| 82 | + /// This is given as a range, over the denominator given by |
| 83 | + /// [`Model::denominator`]. This range should in general include `EOF`, |
| 84 | + /// which is denoted by `None`. |
| 85 | + /// |
| 86 | + /// For example, from the set {heads, tails}, the interval representing |
| 87 | + /// heads could be `0..1`, and tails would be `1..2`, and `EOF` could be |
| 88 | + /// `2..3` (with a denominator of `3`). |
| 89 | + /// |
| 90 | + /// This is the inverse of the [`Model::symbol`] method |
| 91 | + /// |
| 92 | + /// # Errors |
| 93 | + /// |
| 94 | + /// This returns a custom error if the given symbol is not valid |
| 95 | + fn probability( |
| 96 | + &self, |
| 97 | + symbol: Option<&Self::Symbol>, |
| 98 | + ) -> Result<Range<Self::B>, Self::ValueError>; |
| 99 | + |
| 100 | + /// The denominator for probability ranges. See [`Model::probability`]. |
| 101 | + /// |
| 102 | + /// By default this method simply returns the [`Model::max_denominator`], |
| 103 | + /// which is suitable for non-adaptive models. |
| 104 | + /// |
| 105 | + /// In adaptive models this value may change, however it should never exceed |
| 106 | + /// [`Model::max_denominator`], or it becomes possible for the |
| 107 | + /// [`Encoder`](crate::Encoder) and [`Decoder`](crate::Decoder) to panic due |
| 108 | + /// to overflow or underflow. |
| 109 | + fn denominator(&self) -> Self::B { |
| 110 | + self.max_denominator() |
| 111 | + } |
| 112 | + |
| 113 | + /// The maximum denominator used for probability ranges. See |
| 114 | + /// [`Model::probability`]. |
| 115 | + /// |
| 116 | + /// This value is used to calculate an appropriate precision for the |
| 117 | + /// encoding, therefore this value must not change, and |
| 118 | + /// [`Model::denominator`] must never exceed it. |
| 119 | + fn max_denominator(&self) -> Self::B; |
| 120 | + |
| 121 | + /// Given a value, return the symbol whose probability range it falls in. |
| 122 | + /// |
| 123 | + /// `None` indicates `EOF` |
| 124 | + /// |
| 125 | + /// This is the inverse of the [`Model::probability`] method |
| 126 | + fn symbol(&self, value: Self::B) -> Option<Self::Symbol>; |
| 127 | + |
| 128 | + /// Update the current state of the model with the latest symbol. |
| 129 | + /// |
| 130 | + /// This method only needs to be implemented for 'adaptive' models. It's a |
| 131 | + /// no-op by default. |
| 132 | + fn update(&mut self, _symbol: &Self::Symbol) {} |
| 133 | + |
| 134 | + /// The maximum number of symbols to encode |
| 135 | + fn max_length(&self) -> usize; |
| 136 | +} |
| 137 | + |
| 138 | +/// A wrapper which converts a [`fixed_length::Model`](Model) to a |
| 139 | +/// [`crate::Model`]. |
| 140 | +#[derive(Debug, Clone)] |
| 141 | +pub struct Wrapper<M> |
| 142 | +where |
| 143 | + M: Model, |
| 144 | +{ |
| 145 | + model: M, |
| 146 | + remaining: usize, |
| 147 | +} |
| 148 | + |
| 149 | +impl<M> Wrapper<M> |
| 150 | +where |
| 151 | + M: Model, |
| 152 | +{ |
| 153 | + /// Construct a new wrapper from a [`fixed_length::Model`](Model) |
| 154 | + pub fn new(model: M) -> Self { |
| 155 | + let remaining = model.max_length(); |
| 156 | + Self { model, remaining } |
| 157 | + } |
| 158 | +} |
| 159 | + |
| 160 | +impl<M> crate::Model for Wrapper<M> |
| 161 | +where |
| 162 | + M: Model, |
| 163 | +{ |
| 164 | + type B = M::B; |
| 165 | + type Symbol = M::Symbol; |
| 166 | + type ValueError = Error<M::ValueError>; |
| 167 | + |
| 168 | + fn probability( |
| 169 | + &self, |
| 170 | + symbol: Option<&Self::Symbol>, |
| 171 | + ) -> Result<Range<Self::B>, Self::ValueError> { |
| 172 | + if self.remaining == 0 { |
| 173 | + if symbol.is_some() { |
| 174 | + Err(Error::UnexpectedSymbol) |
| 175 | + } else { |
| 176 | + // got an EOF when we expected it, return a 100% probability |
| 177 | + Ok(Self::B::ZERO..self.denominator()) |
| 178 | + } |
| 179 | + } else { |
| 180 | + self.model |
| 181 | + .probability(symbol) |
| 182 | + .map_err(Self::ValueError::Value) |
| 183 | + } |
| 184 | + } |
| 185 | + |
| 186 | + fn max_denominator(&self) -> Self::B { |
| 187 | + self.model.max_denominator() |
| 188 | + } |
| 189 | + |
| 190 | + fn symbol(&self, value: Self::B) -> Option<Self::Symbol> { |
| 191 | + if self.remaining > 0 { |
| 192 | + self.model.symbol(value) |
| 193 | + } else { |
| 194 | + None |
| 195 | + } |
| 196 | + } |
| 197 | + |
| 198 | + fn denominator(&self) -> Self::B { |
| 199 | + self.model.denominator() |
| 200 | + } |
| 201 | + |
| 202 | + fn update(&mut self, symbol: Option<&Self::Symbol>) { |
| 203 | + if let Some(s) = symbol { |
| 204 | + self.model.update(s); |
| 205 | + self.remaining -= 1; |
| 206 | + } |
| 207 | + } |
| 208 | +} |
| 209 | + |
| 210 | +/// Fixed-length encoding/decoding errors |
| 211 | +#[derive(Debug, thiserror::Error)] |
| 212 | +pub enum Error<E> |
| 213 | +where |
| 214 | + E: std::error::Error, |
| 215 | +{ |
| 216 | + /// Model received a symbol when it expected an EOF |
| 217 | + #[error("Unexpected Symbol")] |
| 218 | + UnexpectedSymbol, |
| 219 | + |
| 220 | + /// The model received an invalid symbol |
| 221 | + #[error(transparent)] |
| 222 | + Value(E), |
| 223 | +} |
0 commit comments