diff --git a/der/src/encode_cached.rs b/der/src/encode_cached.rs new file mode 100644 index 000000000..d87666b8b --- /dev/null +++ b/der/src/encode_cached.rs @@ -0,0 +1,189 @@ +use crate::{DecodeValue, EncodeValue, FixedTag, Header, Length, Result, Tag, ValueOrd, Writer}; +use core::{cell::Cell, cmp::Ordering, fmt}; + +/// Caches once-computed length of the object, when encoding big data structures. +/// +/// For example `Vec>>` won't need to calculate inner `Vec`'s length twice. +/// +/// Warning: users of this type should call [`EncodeValueLenCached::clear_len_cache`] on all objects before using encoding. +/// Otherwise, any errors during [`EncodeValue::encode_value`] will make the cache invalid! +/// +/// ```rust +/// use der::{asn1::SequenceOf, Encode, EncodeValueLenCached}; +/// let mut big_vec = SequenceOf::>, 1>::new(); +/// let mut inner_vec = SequenceOf::new(); +/// for _ in 0..128 { +/// inner_vec.add(()); +/// } +/// big_vec.add(inner_vec.into()); +/// +/// let mut buf = [0u8; 300]; +/// +/// // Ensure, that length cache is clear. +/// for cached in big_vec.iter() { +/// cached.clear_len_cache(); +/// } +/// // Here, inner SequenceOf calculates it's value length once +/// big_vec.encode_to_slice(&mut buf).unwrap(); +/// ``` +pub struct EncodeValueLenCached { + cached_len: Cell>, + + /// Object, that might implement [`EncodeValue`], [`DecodeValue`] or both. + pub value: T, +} + +impl EncodeValueLenCached { + /// Clears cache, in cases when [`EncodeValue::value_len`] was called by accident, + /// + /// without subsequent [`EncodeValue::encode_value`]. + pub fn clear_len_cache(&self) { + self.cached_len.set(None) + } +} + +impl Clone for EncodeValueLenCached { + fn clone(&self) -> Self { + Self { + cached_len: Cell::new(None), + value: self.value.clone(), + } + } +} + +impl Default for EncodeValueLenCached { + fn default() -> Self { + Self { + cached_len: Cell::new(None), + value: Default::default(), + } + } +} + +impl fmt::Debug for EncodeValueLenCached { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.value.fmt(f) + } +} + +impl fmt::Display for EncodeValueLenCached { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.value.fmt(f) + } +} + +impl AsRef for EncodeValueLenCached { + fn as_ref(&self) -> &T { + &self.value + } +} + +impl EncodeValue for EncodeValueLenCached +where + T: EncodeValue, +{ + fn value_len(&self) -> Result { + // Prevent calculating the same length twice + if let Some(len) = self.cached_len.get() { + return Ok(len); + } + let len = self.value.value_len()?; + self.cached_len.set(Some(len)); + Ok(len) + } + + fn encode_value(&self, writer: &mut impl Writer) -> Result<()> { + // Cached length won't be needed in this encoding pass again, so clear it. + // Also, this prevents bugs where internal data changes but the length does not. + self.cached_len.set(None); + self.value.encode_value(writer) + } +} + +impl<'a, T> DecodeValue<'a> for EncodeValueLenCached +where + T: DecodeValue<'a>, +{ + type Error = T::Error; + + fn decode_value>( + reader: &mut R, + header: Header, + ) -> core::result::Result { + Ok(EncodeValueLenCached { + cached_len: Cell::new(None), + value: T::decode_value(reader, header)?, + }) + } +} + +impl ValueOrd for EncodeValueLenCached +where + T: ValueOrd, +{ + fn value_cmp(&self, other: &Self) -> Result { + self.value.value_cmp(&other.value) + } +} + +// FixedTag is more important than Tagged, because FixedTag is used by Choice macro +impl FixedTag for EncodeValueLenCached { + const TAG: Tag = T::TAG; +} + +impl From for EncodeValueLenCached { + fn from(value: T) -> Self { + Self { + cached_len: Cell::new(None), + value, + } + } +} + +#[cfg(test)] +#[cfg(feature = "std")] +mod test { + use core::cell::Cell; + use std::vec::Vec; + + use crate::{Encode, EncodeValue, EncodeValueLenCached, FixedTag, Length, Result, Tag, Writer}; + + #[derive(Clone, Default)] + struct SusString { + len_query_counter: Cell, + } + + impl EncodeValue for SusString { + #[allow(clippy::panic, clippy::panic_in_result_fn)] + fn value_len(&self) -> Result { + let counter = self.len_query_counter.get(); + if counter >= 2 { + panic!("value_len called more than twice"); + } + self.len_query_counter.set(counter + 1); + Ok(Length::new(1)) + } + + fn encode_value(&self, encoder: &mut impl Writer) -> Result<()> { + encoder.write_byte(0xAA)?; + Ok(()) + } + } + impl FixedTag for SusString { + const TAG: Tag = Tag::Utf8String; + } + + /// Inner `SusString` objects should calculate it's length only twice. + /// + /// Once when encoding outer SEQUENCE, second time for itself. + #[test] + fn value_len_called_2_times() { + let big_vec: Vec>> = + vec![vec![SusString::default(); 1000].into()]; + + let bigger_vec = vec![big_vec]; + bigger_vec.to_der().expect("to_der"); + + assert_eq!(2, bigger_vec[0][0].as_ref()[0].len_query_counter.get()); + } +} diff --git a/der/src/lib.rs b/der/src/lib.rs index d867ef928..120e69047 100644 --- a/der/src/lib.rs +++ b/der/src/lib.rs @@ -346,6 +346,7 @@ mod bytes; mod datetime; mod decode; mod encode; +mod encode_cached; mod encode_ref; mod encoding_rules; mod error; @@ -366,6 +367,7 @@ pub use crate::{ datetime::DateTime, decode::{Decode, DecodeOwned, DecodeValue}, encode::{Encode, EncodeValue}, + encode_cached::EncodeValueLenCached, encode_ref::{EncodeRef, EncodeValueRef}, encoding_rules::EncodingRules, error::{Error, ErrorKind, Result},