From 42cdf44ca6bb0f6ad4604495ec9d24b3c5236ef6 Mon Sep 17 00:00:00 2001 From: steschu77 Date: Fri, 14 Nov 2025 17:27:25 +0100 Subject: [PATCH] Improve coefficient sign decoding using the fact that range is always within 128..254. * Add sign decoding to arithmetic decoder. * Add narrower limits for asserting `range`. * Improve flag decoding by limiting `shift` to be 0 (for range=128) or 1 (for range in 64..127). --- src/vp8.rs | 2 +- src/vp8_arithmetic_decoder.rs | 88 +++++++++++++++++++++++++++++++---- 2 files changed, 80 insertions(+), 10 deletions(-) diff --git a/src/vp8.rs b/src/vp8.rs index a8d48f7..33392fa 100644 --- a/src/vp8.rs +++ b/src/vp8.rs @@ -852,7 +852,7 @@ impl Vp8Decoder { 2 }; - if decoder.read_flag().or_accumulate(&mut res) { + if decoder.read_sign().or_accumulate(&mut res) { abs_value = -abs_value; } diff --git a/src/vp8_arithmetic_decoder.rs b/src/vp8_arithmetic_decoder.rs index 7f8c4be..a25d3e9 100644 --- a/src/vp8_arithmetic_decoder.rs +++ b/src/vp8_arithmetic_decoder.rs @@ -175,6 +175,16 @@ impl ArithmeticDecoder { self.cold_read_flag() } + // Do not inline this because inlining seems to worsen performance. + #[inline(never)] + pub(crate) fn read_sign(&mut self) -> BitResult<bool> { + if let Some(b) = self.fast().read_sign() { + return BitResult::ok(b); + } + + self.cold_read_flag() + } + // Do not inline this because inlining seems to worsen performance. 
#[inline(never)] pub(crate) fn read_literal(&mut self, n: u8) -> BitResult { @@ -399,6 +409,11 @@ impl FastDecoder<'_> { self.commit_if_valid(value) } + fn read_sign(mut self) -> Option<bool> { + let value = self.fast_read_sign(); + self.commit_if_valid(value) + } + fn read_literal(mut self, n: u8) -> Option { let value = self.fast_read_literal(n); self.commit_if_valid(value) @@ -448,6 +463,7 @@ impl FastDecoder<'_> { } debug_assert!(bit_count >= 0); + debug_assert!((128..=255).contains(&range)); let probability = u32::from(probability); let split = 1 + (((range - 1) * probability) >> 8); let bigsplit = u64::from(split) << bit_count; @@ -460,18 +476,18 @@ impl FastDecoder<'_> { range = split; false }; - debug_assert!(range > 0); // Compute shift required to satisfy `range >= 128`. // Apply that shift to `range` and `self.bitcount`. // // Subtract 24 because we only care about leading zeros in the // lowest byte of `range` which is a `u32`. + debug_assert!((1..=254).contains(&range)); let shift = range.leading_zeros().saturating_sub(24); range <<= shift; bit_count -= shift as i32; - debug_assert!(range >= 128); + debug_assert!((128..=254).contains(&range)); self.uncommitted_state = State { chunk_index, value, @@ -504,6 +520,7 @@ impl FastDecoder<'_> { } debug_assert!(bit_count >= 0); + debug_assert!((128..=255).contains(&range)); let half_range = range / 2; let split = range - half_range; let bigsplit = u64::from(split) << bit_count; @@ -516,18 +533,71 @@ impl FastDecoder<'_> { range = split; false }; - debug_assert!(range > 0); // Compute shift required to satisfy `range >= 128`. + // A `range` of 64..127 requires a shift of 1. No shift if `range` is 128. // Apply that shift to `range` and `self.bitcount`. - // - // Subtract 24 because we only care about leading zeros in the - // lowest byte of `range` which is a `u32`. 
- let shift = range.leading_zeros().saturating_sub(24); + debug_assert!((64..=128).contains(&range)); + let shift = if range == 0x80 { 0 } else { 1 }; range <<= shift; - bit_count -= shift as i32; - debug_assert!(range >= 128); + bit_count -= shift; + + debug_assert!((128..=254).contains(&range)); + self.uncommitted_state = State { + chunk_index, + value, + range, + bit_count, + }; + retval + } + + fn fast_read_sign(&mut self) -> bool { + let State { + mut chunk_index, + mut value, + mut range, + mut bit_count, + } = self.uncommitted_state; + + if bit_count < 0 { + let chunk = self.chunks.get(chunk_index).copied(); + // We ignore invalid data inside the `fast_` functions, + // but we increase `chunk_index` below, so we can check + // whether we read invalid data in `commit_if_valid`. + let chunk = chunk.unwrap_or_default(); + + let v = u32::from_be_bytes(chunk); + chunk_index += 1; + value <<= 32; + value |= u64::from(v); + bit_count += 32; + } + + // Range is only 255 at the start of decoding. After reading any symbol, it is guaranteed + // to be at most 254. Sign bits are never the first symbol in a bit stream. + debug_assert!((128..=254).contains(&range)); + let half_range = range / 2; + let split = range - half_range; + let bigsplit = u64::from(split) << bit_count; + + let retval = if let Some(new_value) = value.checked_sub(bigsplit) { + range = half_range; + value = new_value; + true + } else { + range = split; + false + }; + + // Compute shift required to satisfy `range >= 128`. + // Since `range` lies in 64..127 it always requires a shift of 1. + // Apply that shift to `range` and `self.bitcount`. + debug_assert!((64..=127).contains(&range)); + range <<= 1; + bit_count -= 1; + debug_assert!((128..=254).contains(&range)); self.uncommitted_state = State { chunk_index, value,