diff --git a/der/src/asn1/bit_string.rs b/der/src/asn1/bit_string.rs index f171ba681..8fd9b63d3 100644 --- a/der/src/asn1/bit_string.rs +++ b/der/src/asn1/bit_string.rs @@ -7,6 +7,7 @@ use crate::{ Result, Tag, ValueOrd, Writer, }; use core::{cmp::Ordering, iter::FusedIterator}; +use unused_bits::UnusedBits; #[cfg(feature = "flagset")] use core::mem::size_of_val; @@ -20,40 +21,32 @@ use core::mem::size_of_val; #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] pub struct BitStringRef<'a> { /// Number of unused bits in the final octet. - unused_bits: u8, - - /// Length of this `BIT STRING` in bits. - bit_length: usize, + unused_bits: UnusedBits, /// Bitstring represented as a slice of bytes. inner: &'a BytesRef, } impl<'a> BitStringRef<'a> { - /// Maximum number of unused bits allowed. - pub const MAX_UNUSED_BITS: u8 = 7; - /// Create a new ASN.1 `BIT STRING` from a byte slice. /// /// Accepts an optional number of "unused bits" (0-7) which are omitted /// from the final octet. This number is 0 if the value is octet-aligned. pub fn new(unused_bits: u8, bytes: &'a [u8]) -> Result { - if (unused_bits > Self::MAX_UNUSED_BITS) || (unused_bits != 0 && bytes.is_empty()) { - return Err(Self::TAG.value_error().into()); - } - + let unused_bits = UnusedBits::new(unused_bits, bytes)?; let inner = BytesRef::new(bytes).map_err(|_| Self::TAG.length_error())?; + let value = Self::new_unchecked(unused_bits, inner); + value + .bit_len_checked() + .ok_or_else(|| Error::from(ErrorKind::Overflow))?; + Ok(value) + } - let bit_length = usize::try_from(inner.len())? - .checked_mul(8) - .and_then(|n| n.checked_sub(usize::from(unused_bits))) - .ok_or(ErrorKind::Overflow)?; - - Ok(Self { - unused_bits, - bit_length, - inner, - }) + /// Internal function. Assumptions: + /// - [`UnusedBits`] was checked for given [`BytesRef`], + /// - [`BitStringRef::bit_len_checked`] was called and returned `Ok`. + pub(crate) fn new_unchecked(unused_bits: UnusedBits, inner: &'a BytesRef) -> Self { + Self { unused_bits, inner } } /// Create a new ASN.1 `BIT STRING` from the given bytes. @@ -65,17 +58,31 @@ impl<'a> BitStringRef<'a> { /// Get the number of unused bits in this byte slice. pub fn unused_bits(&self) -> u8 { - self.unused_bits + *self.unused_bits } /// Is the number of unused bits a value other than 0? pub fn has_unused_bits(&self) -> bool { - self.unused_bits != 0 + *self.unused_bits != 0 + } + + /// Get the length of this `BIT STRING` in bits, or `None` if the value overflows. + /// + /// Ensured to be valid in the constructor. + fn bit_len_checked(&self) -> Option { + usize::try_from(self.inner.len()) + .ok() + .and_then(|n| n.checked_mul(8)) + .and_then(|n| n.checked_sub(usize::from(*self.unused_bits))) } /// Get the length of this `BIT STRING` in bits. pub fn bit_len(&self) -> usize { - self.bit_length + let bit_len = self.bit_len_checked(); + debug_assert!(bit_len.is_some()); + + // Ensured to be valid in the constructor. + bit_len.unwrap_or(0) } /// Get the number of bytes/octets needed to represent this `BIT STRING` @@ -156,7 +163,7 @@ impl EncodeValue for BitStringRef<'_> { } fn encode_value(&self, writer: &mut impl Writer) -> Result<()> { - writer.write_byte(self.unused_bits)?; + writer.write_byte(*self.unused_bits)?; writer.write(self.raw_bytes()) } } @@ -226,13 +233,48 @@ impl FixedTag for BitStringRef<'_> { const TAG: Tag = Tag::BitString; } +/// Sealed, so that `UnusedBits` newtype can't be created directly +mod unused_bits { + use core::ops::Deref; + + use crate::{Result, Tag}; + + /// Value in range `0..=7` + /// + /// Must be zero for empty `BIT STRING`. + #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] + pub(crate) struct UnusedBits(u8); + + impl UnusedBits { + /// Maximum number of unused bits allowed. + pub const MAX_UNUSED_BITS: u8 = 7; + + /// Represents number of "unused bits" (0-7) in `BIT STRING` which are omitted + /// from the final octet. This number is 0 if the value is octet-aligned. + pub fn new(unused_bits: u8, bytes: &[u8]) -> Result { + if (unused_bits > Self::MAX_UNUSED_BITS) || (unused_bits != 0 && bytes.is_empty()) { + Err(Tag::BitString.value_error().into()) + } else { + Ok(Self(unused_bits)) + } + } + } + impl Deref for UnusedBits { + type Target = u8; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } +} + // Implement by hand because the derive would create invalid values. // Use the constructor to create a valid value. #[cfg(feature = "arbitrary")] impl<'a> arbitrary::Arbitrary<'a> for BitStringRef<'a> { fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { Self::new( - u.int_in_range(0..=Self::MAX_UNUSED_BITS)?, + u.int_in_range(0..=UnusedBits::MAX_UNUSED_BITS)?, <&'a BytesRef>::arbitrary(u)?.as_slice(), ) .map_err(|_| arbitrary::Error::IncorrectFormat) @@ -259,10 +301,7 @@ mod allocating { #[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] pub struct BitString { /// Number of unused bits in the final octet. - unused_bits: u8, - - /// Length of this `BIT STRING` in bits. - bit_length: usize, + unused_bits: UnusedBits, /// Bitstring represented as a slice of bytes. inner: Vec, @@ -280,11 +319,10 @@ mod allocating { let inner = bytes.into(); // Ensure parameters parse successfully as a `BitStringRef`. - let bit_length = BitStringRef::new(unused_bits, &inner)?.bit_length; + let ref_value = BitStringRef::new(unused_bits, &inner)?; Ok(BitString { - unused_bits, - bit_length, + unused_bits: ref_value.unused_bits, inner, }) } @@ -299,17 +337,34 @@ mod allocating { /// Get the number of unused bits in the octet serialization of this /// `BIT STRING`. pub fn unused_bits(&self) -> u8 { - self.unused_bits + *self.unused_bits } /// Is the number of unused bits a value other than 0? pub fn has_unused_bits(&self) -> bool { - self.unused_bits != 0 + *self.unused_bits != 0 + } + + /// Returns inner [`BytesRef`] slice. + pub(crate) fn bytes_ref(&self) -> &BytesRef { + // Ensured to parse successfully in constructor + BytesRef::new_unchecked(&self.inner) + } + + /// Get the length of this `BIT STRING` in bits, or `None` if the value overflows. + /// + /// Ensured to be valid in the constructor. + fn bit_len_checked(&self) -> Option { + BitStringRef::new_unchecked(self.unused_bits, self.bytes_ref()).bit_len_checked() } /// Get the length of this `BIT STRING` in bits. pub fn bit_len(&self) -> usize { - self.bit_length + let bit_len = self.bit_len_checked(); + debug_assert!(bit_len.is_some()); + + // Ensured to be valid in the constructor. + bit_len.unwrap_or(0) } /// Is the inner byte slice empty? @@ -367,7 +422,7 @@ mod allocating { } fn encode_value(&self, writer: &mut impl Writer) -> Result<()> { - writer.write_byte(self.unused_bits)?; + writer.write_byte(*self.unused_bits)?; writer.write(&self.inner) } } @@ -379,8 +434,7 @@ mod allocating { impl<'a> From<&'a BitString> for BitStringRef<'a> { fn from(bit_string: &'a BitString) -> BitStringRef<'a> { // Ensured to parse successfully in constructor - BitStringRef::new(bit_string.unused_bits, &bit_string.inner) - .expect("invalid BIT STRING") + BitStringRef::new_unchecked(bit_string.unused_bits, bit_string.bytes_ref()) } } @@ -436,7 +490,6 @@ mod allocating { fn ref_to_owned(&self) -> Self::Owned { BitString { unused_bits: self.unused_bits, - bit_length: self.bit_length, inner: Vec::from(self.inner.as_slice()), } }