diff --git a/bounded-collections/src/bounded_btree_map.rs b/bounded-collections/src/bounded_btree_map.rs index 09a17dcb..4ff591db 100644 --- a/bounded-collections/src/bounded_btree_map.rs +++ b/bounded-collections/src/bounded_btree_map.rs @@ -430,48 +430,11 @@ where macro_rules! codec_impl { ($codec:ident) => { use super::*; + use crate::codec_utils::PrependCompactInput; use $codec::{ Compact, Decode, DecodeLength, DecodeWithMemTracking, Encode, EncodeLike, Error, Input, MaxEncodedLen, }; - // Struct which allows prepending the compact after reading from an input. - pub(crate) struct PrependCompactInput<'a, I> { - pub encoded_len: &'a [u8], - pub read: usize, - pub inner: &'a mut I, - } - - impl<'a, I: Input> Input for PrependCompactInput<'a, I> { - fn remaining_len(&mut self) -> Result, Error> { - let remaining_compact = self.encoded_len.len().saturating_sub(self.read); - Ok(self.inner.remaining_len()?.map(|len| len.saturating_add(remaining_compact))) - } - - fn read(&mut self, into: &mut [u8]) -> Result<(), Error> { - if into.is_empty() { - return Ok(()); - } - - let remaining_compact = self.encoded_len.len().saturating_sub(self.read); - if remaining_compact > 0 { - let to_read = into.len().min(remaining_compact); - into[..to_read].copy_from_slice(&self.encoded_len[self.read..][..to_read]); - self.read += to_read; - - if to_read < into.len() { - // Buffer not full, keep reading the inner. - self.inner.read(&mut into[to_read..]) - } else { - // Buffer was filled by the compact. - Ok(()) - } - } else { - // Prepended compact has been read, just read from inner. - self.inner.read(into) - } - } - } - impl Decode for BoundedBTreeMap where K: Decode + Ord, @@ -480,13 +443,13 @@ macro_rules! codec_impl { { fn decode(input: &mut I) -> Result { // Fail early if the len is too big. This is a compact u32 which we will later put back. - let compact = >::decode(input)?; - if compact.0 > S::get() { + let len = >::decode(input)?; + if len.0 > S::get() { return Err("BoundedBTreeMap exceeds its limit".into()); } // Reconstruct the original input by prepending the length we just read, then delegate the decoding to BTreeMap. let inner = BTreeMap::decode(&mut PrependCompactInput { - encoded_len: compact.encode().as_ref(), + encoded_len: len.encode().as_ref(), read: 0, inner: input, })?; @@ -549,9 +512,7 @@ mod test { use crate::ConstU32; use alloc::{vec, vec::Vec}; #[cfg(feature = "scale-codec")] - use scale_codec::{Compact, CompactLen, Decode, Encode, Input}; - #[cfg(feature = "scale-codec")] - use scale_codec_impl::PrependCompactInput; + use scale_codec::{Compact, CompactLen, Decode, Encode}; fn map_from_keys(keys: &[K]) -> BTreeMap where @@ -805,59 +766,6 @@ mod test { assert_eq!(Ok(b2), b1.try_map(|(_, v)| (v as u16).checked_mul(100_u16).ok_or("overflow"))); } - #[test] - #[cfg(feature = "scale-codec")] - fn prepend_compact_input_works() { - let encoded_len = Compact(3u32).encode(); - let inner = [2, 3, 4]; - let mut input = PrependCompactInput { encoded_len: encoded_len.as_ref(), read: 0, inner: &mut &inner[..] }; - assert_eq!(input.remaining_len(), Ok(Some(4))); - - // Passing an empty buffer should leave input unchanged. - let mut empty_buf = []; - assert_eq!(input.read(&mut empty_buf), Ok(())); - assert_eq!(input.remaining_len(), Ok(Some(4))); - assert_eq!(input.read, 0); - - // Passing a correctly-sized buffer will read correctly. - let mut buf = [0; 4]; - assert_eq!(input.read(&mut buf), Ok(())); - assert_eq!(buf[0], encoded_len[0]); - assert_eq!(buf[1..], inner[..]); - // And the bookkeeping agrees. - assert_eq!(input.remaining_len(), Ok(Some(0))); - assert_eq!(input.read, encoded_len.len()); - - // And we can't read more. - assert!(input.read(&mut buf).is_err()); - } - - #[test] - #[cfg(feature = "scale-codec")] - fn prepend_compact_input_incremental_read_works() { - let encoded_len = Compact(3u32).encode(); - let inner = [2, 3, 4]; - let mut input = PrependCompactInput { encoded_len: encoded_len.as_ref(), read: 0, inner: &mut &inner[..] }; - assert_eq!(input.remaining_len(), Ok(Some(4))); - - // Compact is first byte - ensure that it fills the buffer when it's more than one. - let mut buf = [0u8; 2]; - assert_eq!(input.read(&mut buf), Ok(())); - assert_eq!(buf[0], encoded_len[0]); - assert_eq!(buf[1], inner[0]); - assert_eq!(input.remaining_len(), Ok(Some(2))); - assert_eq!(input.read, encoded_len.len()); - - // Check the last two bytes are read correctly. - assert_eq!(input.read(&mut buf), Ok(())); - assert_eq!(buf[..], inner[1..]); - assert_eq!(input.remaining_len(), Ok(Some(0))); - assert_eq!(input.read, encoded_len.len()); - - // And we can't read more. - assert!(input.read(&mut buf).is_err()); - } - // Just a test that structs containing `BoundedBTreeMap` can derive `Hash`. (This was broken // when it was deriving `Hash`). #[test] diff --git a/bounded-collections/src/bounded_btree_set.rs b/bounded-collections/src/bounded_btree_set.rs index fc393325..4f8b6d6f 100644 --- a/bounded-collections/src/bounded_btree_set.rs +++ b/bounded-collections/src/bounded_btree_set.rs @@ -359,22 +359,28 @@ where macro_rules! codec_impl { ($codec:ident) => { use super::*; - use $codec::{Compact, Decode, DecodeLength, Encode, EncodeLike, Error, Input, MaxEncodedLen}; + use crate::codec_utils::PrependCompactInput; + use $codec::{ + Compact, Decode, DecodeLength, DecodeWithMemTracking, Encode, EncodeLike, Error, Input, MaxEncodedLen, + }; + impl Decode for BoundedBTreeSet where T: Decode + Ord, S: Get, { fn decode(input: &mut I) -> Result { - // Same as the underlying implementation for `Decode` on `BTreeSet`, except we fail early if - // the len is too big. - let len: u32 = >::decode(input)?.into(); - if len > S::get() { + // Fail early if the len is too big. This is a compact u32 which we will later put back. + let len = >::decode(input)?; + if len.0 > S::get() { return Err("BoundedBTreeSet exceeds its limit".into()); } - input.descend_ref()?; - let inner = Result::from_iter((0..len).map(|_| Decode::decode(input)))?; - input.ascend_ref(); + // Reconstruct the original input by prepending the length we just read, then delegate the decoding to BTreeMap. + let inner = BTreeSet::decode(&mut PrependCompactInput { + encoded_len: len.encode().as_ref(), + read: 0, + inner: input, + })?; Ok(Self(inner, PhantomData)) } @@ -405,6 +411,13 @@ macro_rules! codec_impl { } impl EncodeLike> for BoundedBTreeSet where BTreeSet: Encode {} + + impl DecodeWithMemTracking for BoundedBTreeSet + where + T: Decode + Ord, + S: Get, + { + } }; } diff --git a/bounded-collections/src/codec_utils.rs b/bounded-collections/src/codec_utils.rs new file mode 100644 index 00000000..6e84733b --- /dev/null +++ b/bounded-collections/src/codec_utils.rs @@ -0,0 +1,157 @@ +// This file is part of Substrate. + +// Copyright (C) 2023 Parity Technologies (UK) Ltd. +// SPDX-License-Identifier: Apache-2.0 + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Shared codec utilities for bounded collections. + +/// Struct which allows prepending the compact after reading from an input. +/// +/// This is used internally by bounded collections to reconstruct the original +/// input stream after reading the length prefix during decoding. + +#[cfg(any(feature = "scale-codec", feature = "jam-codec"))] +pub struct PrependCompactInput<'a, I> { + pub encoded_len: &'a [u8], + pub read: usize, + pub inner: &'a mut I, +} + +/// Macro to implement Input trait for PrependCompactInput for different codec crates +#[cfg(any(feature = "scale-codec", feature = "jam-codec"))] +macro_rules! impl_prepend_compact_input { + ($codec:ident) => { + use $codec::{Error, Input}; + + impl<'a, I: Input> Input for PrependCompactInput<'a, I> { + fn remaining_len(&mut self) -> Result, Error> { + let remaining_compact = self.encoded_len.len().saturating_sub(self.read); + Ok(self.inner.remaining_len()?.map(|len| len.saturating_add(remaining_compact))) + } + + fn read(&mut self, into: &mut [u8]) -> Result<(), Error> { + if into.is_empty() { + return Ok(()); + } + + let remaining_compact = self.encoded_len.len().saturating_sub(self.read); + if remaining_compact > 0 { + let to_read = into.len().min(remaining_compact); + into[..to_read].copy_from_slice(&self.encoded_len[self.read..][..to_read]); + self.read += to_read; + + if to_read < into.len() { + // Buffer not full, keep reading the inner. + self.inner.read(&mut into[to_read..]) + } else { + // Buffer was filled by the compact. + Ok(()) + } + } else { + // Prepended compact has been read, just read from inner. + self.inner.read(into) + } + } + } + }; +} + +// Generate implementations for each codec +#[cfg(feature = "scale-codec")] +pub mod scale_codec_impl { + use super::PrependCompactInput; + impl_prepend_compact_input!(scale_codec); +} + +#[cfg(feature = "jam-codec")] +pub mod jam_codec_impl { + use super::PrependCompactInput; + impl_prepend_compact_input!(jam_codec); +} + +#[cfg(test)] +#[cfg(any(feature = "scale-codec", feature = "jam-codec"))] +mod tests { + use super::PrependCompactInput; + + /// Macro to generate tests for different codec implementations + macro_rules! codec_tests { + ($codec:ident, $mod_name:ident) => { + mod $mod_name { + use super::PrependCompactInput; + use $codec::{Compact, Encode, Input}; + + #[test] + fn prepend_compact_input_works() { + let encoded_len = Compact(3u32).encode(); + let inner = [2, 3, 4]; + let mut input = + PrependCompactInput { encoded_len: encoded_len.as_ref(), read: 0, inner: &mut &inner[..] }; + assert_eq!(input.remaining_len(), Ok(Some(4))); + + // Passing an empty buffer should leave input unchanged. + let mut empty_buf = []; + assert_eq!(input.read(&mut empty_buf), Ok(())); + assert_eq!(input.remaining_len(), Ok(Some(4))); + assert_eq!(input.read, 0); + + // Passing a correctly-sized buffer will read correctly. + let mut buf = [0; 4]; + assert_eq!(input.read(&mut buf), Ok(())); + assert_eq!(buf[0], encoded_len[0]); + assert_eq!(buf[1..], inner[..]); + // And the bookkeeping agrees. + assert_eq!(input.remaining_len(), Ok(Some(0))); + assert_eq!(input.read, encoded_len.len()); + + // And we can't read more. + assert!(input.read(&mut buf).is_err()); + } + + #[test] + fn prepend_compact_input_incremental_read_works() { + let encoded_len = Compact(3u32).encode(); + let inner = [2, 3, 4]; + let mut input = + PrependCompactInput { encoded_len: encoded_len.as_ref(), read: 0, inner: &mut &inner[..] }; + assert_eq!(input.remaining_len(), Ok(Some(4))); + + // Compact is first byte - ensure that it fills the buffer when it's more than one. + let mut buf = [0u8; 2]; + assert_eq!(input.read(&mut buf), Ok(())); + assert_eq!(buf[0], encoded_len[0]); + assert_eq!(buf[1], inner[0]); + assert_eq!(input.remaining_len(), Ok(Some(2))); + assert_eq!(input.read, encoded_len.len()); + + // Check the last two bytes are read correctly. + assert_eq!(input.read(&mut buf), Ok(())); + assert_eq!(buf[..], inner[1..]); + assert_eq!(input.remaining_len(), Ok(Some(0))); + assert_eq!(input.read, encoded_len.len()); + + // And we can't read more. + assert!(input.read(&mut buf).is_err()); + } + } + }; + } + + // Generate tests for each available codec + #[cfg(feature = "scale-codec")] + codec_tests!(scale_codec, scale_codec_impl); + #[cfg(feature = "jam-codec")] + codec_tests!(jam_codec, jam_codec_impl); +} diff --git a/bounded-collections/src/lib.rs b/bounded-collections/src/lib.rs index 182749ad..e3b80555 100644 --- a/bounded-collections/src/lib.rs +++ b/bounded-collections/src/lib.rs @@ -16,6 +16,7 @@ pub extern crate alloc; pub mod bounded_btree_map; pub mod bounded_btree_set; pub mod bounded_vec; +pub(crate) mod codec_utils; pub mod const_int; pub mod weak_bounded_vec;