diff --git a/ciborium/src/de/mod.rs b/ciborium/src/de/mod.rs index 18742e4..8105668 100644 --- a/ciborium/src/de/mod.rs +++ b/ciborium/src/de/mod.rs @@ -8,12 +8,11 @@ pub use error::Error; use alloc::{string::String, vec::Vec}; +use crate::{simple_type::SimpleTypeAccess, tag::TagAccess}; use ciborium_io::Read; use ciborium_ll::*; use serde::de::{self, value::BytesDeserializer, Deserializer as _}; -use crate::tag::TagAccess; - trait Expected { fn expected(self, kind: &'static str) -> E; } @@ -213,8 +212,22 @@ where Header::Simple(simple::FALSE) => self.deserialize_bool(visitor), Header::Simple(simple::TRUE) => self.deserialize_bool(visitor), Header::Simple(simple::NULL) => self.deserialize_option(visitor), - Header::Simple(simple::UNDEFINED) => self.deserialize_option(visitor), - h @ Header::Simple(..) => Err(h.expected("known simple value")), + Header::Simple(v @ simple::UNDEFINED) => { + let _: Header = self.decoder.pull()?; + visitor.visit_enum(SimpleTypeAccess::new(self, v)) + } + // Those have to be registered via Standard Actions or are reserved so we should error whenever we + // encounter one. This crate should be updated once new entries in this range are added + // in the IANA registry + h @ Header::Simple(0..=31) => Err(h.expected("known simple value")), + // However we should support arbitrary simple types + Header::Simple(v) => { + let _: Header = self.decoder.pull()?; + self.recurse(|me| { + let access = SimpleTypeAccess::new(me, v); + visitor.visit_enum(access) + }) + } h @ Header::Break => Err(h.expected("non-break")), } @@ -604,6 +617,19 @@ where let access = TagAccess::new(me, tag); visitor.visit_enum(access) }); + } else if name == "@@ST@@" { + return match self.decoder.pull()? { + Header::Simple(v @ simple::UNDEFINED) => { + self.decoder.push(Header::Positive(v as u64)); + visitor.visit_enum(SimpleTypeAccess::new(self, v)) + } + h @ Header::Simple(0..=31) => Err(h.expected("known simple value")), + Header::Simple(v) => { + self.decoder.push(Header::Positive(v as u64)); + visitor.visit_enum(SimpleTypeAccess::new(self, v)) + } + h => Err(h.expected("known simple value")), + }; } loop { diff --git a/ciborium/src/lib.rs b/ciborium/src/lib.rs index cbbddf1..0b36157 100644 --- a/ciborium/src/lib.rs +++ b/ciborium/src/lib.rs @@ -94,6 +94,7 @@ extern crate alloc; pub mod de; pub mod ser; +pub mod simple_type; pub mod tag; pub mod value; diff --git a/ciborium/src/ser/mod.rs b/ciborium/src/ser/mod.rs index a5ad23d..90479f8 100644 --- a/ciborium/src/ser/mod.rs +++ b/ciborium/src/ser/mod.rs @@ -200,10 +200,21 @@ where #[inline] fn serialize_newtype_struct( self, - _name: &'static str, + name: &'static str, value: &U, ) -> Result<(), Self::Error> { - value.serialize(self) + if name == "@@SIMPLETYPE@@" { + use serde::ser::Error as _; + + let v = crate::Value::serialized(value).map_err(Error::custom)?; + let v = v + .as_integer() + .ok_or_else(|| Error::custom("Internal error handling simple types"))?; + let v = u8::try_from(v).map_err(Error::custom)?; + Ok(self.0.push(Header::Simple(v))?) + } else { + value.serialize(self) + } } #[inline] @@ -214,7 +225,16 @@ where variant: &'static str, value: &U, ) -> Result<(), Self::Error> { - if name != "@@TAG@@" || variant != "@@UNTAGGED@@" { + if name == "@@ST@@" && variant == "@@SIMPLETYPE@@" { + use serde::ser::Error as _; + + let v = crate::Value::serialized(value).map_err(Error::custom)?; + let v = v + .as_integer() + .ok_or_else(|| Error::custom("Internal error handling simple types"))?; + let v = u8::try_from(v).map_err(Error::custom)?; + return Ok(self.0.push(Header::Simple(v))?); + } else if name != "@@TAG@@" || variant != "@@UNTAGGED@@" { self.0.push(Header::Map(Some(1)))?; self.serialize_str(variant)?; } diff --git a/ciborium/src/simple_type.rs b/ciborium/src/simple_type.rs new file mode 100644 index 0000000..cbe0d42 --- /dev/null +++ b/ciborium/src/simple_type.rs @@ -0,0 +1,141 @@ +//! Contains helper types for dealing with CBOR simple types + +use serde::{de, de::Error as _, forward_to_deserialize_any, ser, Deserialize, Serialize}; + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename = "@@ST@@")] +enum Internal { + /// The integer can either be 23, or (32..=255) + #[serde(rename = "@@SIMPLETYPE@@")] + SimpleType(u8), +} + +/// A CBOR simple value +/// See https://datatracker.ietf.org/doc/html/rfc8949#section-3.3 +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct SimpleType(pub u8); + +impl<'de> Deserialize<'de> for SimpleType { + #[inline] + fn deserialize>(deserializer: D) -> Result { + match Internal::deserialize(deserializer)? { + Internal::SimpleType(t) => Ok(SimpleType(t)), + } + } +} + +impl Serialize for SimpleType { + #[inline] + fn serialize(&self, serializer: S) -> Result { + Internal::SimpleType(self.0).serialize(serializer) + } +} + +pub(crate) struct SimpleTypeAccess { + parent: Option, + state: usize, + typ: u8, +} + +impl SimpleTypeAccess { + pub fn new(parent: D, typ: u8) -> Self { + Self { + parent: Some(parent), + state: 0, + typ, + } + } +} + +impl<'de, D: de::Deserializer<'de>> de::Deserializer<'de> for &mut SimpleTypeAccess { + type Error = D::Error; + + #[inline] + fn deserialize_any>(self, visitor: V) -> Result { + self.state += 1; + match self.state { + 1 => visitor.visit_str("@@SIMPLETYPE@@"), + _ => visitor.visit_u8(self.typ), + } + } + + forward_to_deserialize_any! { + i8 i16 i32 i64 i128 + u8 u16 u32 u64 u128 + bool f32 f64 + char str string + bytes byte_buf + seq map + struct tuple tuple_struct + identifier ignored_any + option unit unit_struct newtype_struct enum + } +} + +impl<'de, D: de::Deserializer<'de>> de::EnumAccess<'de> for SimpleTypeAccess { + type Error = D::Error; + type Variant = Self; + + #[inline] + fn variant_seed>( + mut self, + seed: V, + ) -> Result<(V::Value, Self::Variant), Self::Error> { + let variant = seed.deserialize(&mut self)?; + Ok((variant, self)) + } +} + +impl<'de, D: de::Deserializer<'de>> de::VariantAccess<'de> for SimpleTypeAccess { + type Error = D::Error; + + #[inline] + fn unit_variant(self) -> Result<(), Self::Error> { + Err(Self::Error::custom("expected simple type")) + } + + #[inline] + fn newtype_variant_seed>( + mut self, + seed: U, + ) -> Result { + seed.deserialize(self.parent.take().unwrap()) + } + + #[inline] + fn tuple_variant>( + self, + _len: usize, + visitor: V, + ) -> Result { + visitor.visit_seq(self) + } + + #[inline] + fn struct_variant>( + self, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result { + Err(Self::Error::custom("expected simple_type")) + } +} + +impl<'de, D: de::Deserializer<'de>> de::SeqAccess<'de> for SimpleTypeAccess { + type Error = D::Error; + + #[inline] + fn next_element_seed>( + &mut self, + seed: T, + ) -> Result, Self::Error> { + if self.state < 2 { + return Ok(Some(seed.deserialize(self)?)); + } + + Ok(match self.parent.take() { + Some(x) => Some(seed.deserialize(x)?), + None => None, + }) + } +} diff --git a/ciborium/src/value/de.rs b/ciborium/src/value/de.rs index f58a017..7c82270 100644 --- a/ciborium/src/value/de.rs +++ b/ciborium/src/value/de.rs @@ -1,13 +1,13 @@ // SPDX-License-Identifier: Apache-2.0 -use crate::tag::TagAccess; +use crate::{simple_type::SimpleTypeAccess, tag::TagAccess}; use super::{Error, Integer, Value}; use alloc::{boxed::Box, string::String, vec::Vec}; use core::iter::Peekable; -use ciborium_ll::tag; +use ciborium_ll::{simple, tag}; use serde::de::{self, Deserializer as _}; impl<'a> From for de::Unexpected<'a> { @@ -36,6 +36,7 @@ impl<'a> From<&'a Value> for de::Unexpected<'a> { Value::Map(..) => Self::Map, Value::Null => Self::Other("null"), Value::Tag(..) => Self::Other("tag"), + Value::Simple(..) => Self::Other("simple"), } } } @@ -141,9 +142,9 @@ impl<'de> serde::de::Visitor<'de> for Visitor { fn visit_enum>(self, acc: A) -> Result { use serde::de::VariantAccess; - struct Inner; + struct TagInner; - impl<'de> serde::de::Visitor<'de> for Inner { + impl<'de> serde::de::Visitor<'de> for TagInner { type Value = Value; fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { @@ -162,9 +163,30 @@ impl<'de> serde::de::Visitor<'de> for Visitor { } } + struct SimpleTypeInner; + + impl<'de> serde::de::Visitor<'de> for SimpleTypeInner { + type Value = Value; + + fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(formatter, "a valid CBOR item") + } + + #[inline] + fn visit_seq>(self, mut acc: A) -> Result { + let st = acc + .next_element::()? + .ok_or_else(|| de::Error::custom("expected simple type"))?; + Ok(Value::Simple(st)) + } + } + let (name, data): (String, _) = acc.variant()?; - assert_eq!("@@TAGGED@@", name); - data.tuple_variant(2, Inner) + match name.as_str() { + "@@TAGGED@@" => data.tuple_variant(2, TagInner), + "@@SIMPLETYPE@@" => data.tuple_variant(1, SimpleTypeInner), + _ => panic!("Implementation error"), + } } } @@ -218,6 +240,7 @@ impl<'a> Deserializer<&'a Value> { .map(|x| x ^ !0) .map_err(|_| err()) .and_then(|x| x.try_into().map_err(|_| err()))?, + Value::Simple(x) => i128::from(*x).try_into().map_err(|_| err())?, _ => return Err(de::Error::invalid_type(self.0.into(), &"(big)int")), }) } @@ -228,6 +251,7 @@ impl<'a, 'de> de::Deserializer<'de> for Deserializer<&'a Value> { #[inline] fn deserialize_any>(self, visitor: V) -> Result { + use serde::ser::Error as _; match self.0 { Value::Bytes(x) => visitor.visit_bytes(x), Value::Text(x) => visitor.visit_str(x), @@ -235,6 +259,14 @@ impl<'a, 'de> de::Deserializer<'de> for Deserializer<&'a Value> { Value::Map(x) => visitor.visit_map(Deserializer(x.iter().peekable())), Value::Bool(x) => visitor.visit_bool(*x), Value::Null => visitor.visit_none(), + Value::Simple(v @ simple::UNDEFINED) => { + visitor.visit_enum(SimpleTypeAccess::new(self, *v)) + } + Value::Simple(0..=31) => Err(Self::Error::custom("Unsupported simple type")), + Value::Simple(v) => { + let access = SimpleTypeAccess::new(self, *v); + visitor.visit_enum(access) + } Value::Tag(t, v) => { let parent: Deserializer<&Value> = Deserializer(v); @@ -493,6 +525,18 @@ impl<'a, 'de> de::Deserializer<'de> for Deserializer<&'a Value> { let parent: Deserializer<&Value> = Deserializer(val); let access = TagAccess::new(parent, tag); return visitor.visit_enum(access); + } else if name == "@@ST@@" { + use serde::ser::Error as _; + return match self.0 { + Value::Simple(v @ simple::UNDEFINED) => { + visitor.visit_enum(SimpleTypeAccess::new(Deserializer(self.0), *v)) + } + Value::Simple(0..=31) => return Err(Error::custom("Unsupported simple type")), + Value::Simple(v) => { + visitor.visit_enum(SimpleTypeAccess::new(Deserializer(self.0), *v)) + } + _ => Err(Error::custom("Implementation error for simple type")), + }; } match self.0 { diff --git a/ciborium/src/value/mod.rs b/ciborium/src/value/mod.rs index 7233026..9f8aab8 100644 --- a/ciborium/src/value/mod.rs +++ b/ciborium/src/value/mod.rs @@ -45,6 +45,9 @@ pub enum Value { /// A map Map(Vec<(Value, Value)>), + + /// A CBOR "Simple Value" other than true, false, or null + Simple(u8), } impl Value { @@ -594,6 +597,47 @@ impl Value { other => Err(other), } } + + /// Returns true if the `Value` is an `SimpleType`. Returns false otherwise. + /// + /// ``` + /// # use ciborium::Value; + /// # + /// assert!(Value::Simple(59).is_simple()); + /// ``` + pub fn is_simple(&self) -> bool { + self.as_simple().is_some() + } + + /// If the `Value` is a `SimpleType`. The value can only be in [0..23] or [32..255]. + /// + /// ``` + /// # use ciborium::Value; + /// # + /// assert_eq!(59, Value::Simple(59).as_simple().unwrap()); + /// ``` + pub fn as_simple(&self) -> Option { + match self { + Value::Simple(int) => Some(*int), + _ => None, + } + } + + /// If the `Value` is a `SimpleType`. The value can only be in [0..23] or [32..255]. + /// + /// ``` + /// # use ciborium::{Value, value::Integer}; + /// # + /// assert_eq!(Value::Simple(59).into_simple(), Ok(59)); + /// + /// assert_eq!(Value::Bool(true).into_simple(), Err(Value::Bool(true))); + /// ``` + pub fn into_simple(self) -> Result { + match self { + Value::Simple(int) => Ok(int), + other => Err(other), + } + } } macro_rules! implfrom { diff --git a/ciborium/src/value/ser.rs b/ciborium/src/value/ser.rs index 99e6587..833b93e 100644 --- a/ciborium/src/value/ser.rs +++ b/ciborium/src/value/ser.rs @@ -5,15 +5,23 @@ use super::{Error, Value}; use alloc::{vec, vec::Vec}; use ::serde::ser::{self, SerializeMap as _, SerializeSeq as _, SerializeTupleVariant as _}; +use ciborium_ll::simple; impl ser::Serialize for Value { #[inline] fn serialize(&self, serializer: S) -> Result { + use serde::ser::Error as _; + match self { Value::Bytes(x) => serializer.serialize_bytes(x), Value::Bool(x) => serializer.serialize_bool(*x), Value::Text(x) => serializer.serialize_str(x), Value::Null => serializer.serialize_unit(), + Value::Simple(x @ simple::UNDEFINED) => { + serializer.serialize_newtype_struct("@@SIMPLETYPE@@", x) + } + Value::Simple(0..=31) => Err(S::Error::custom("Unsupported simple type")), + Value::Simple(x) => serializer.serialize_newtype_struct("@@SIMPLETYPE@@", x), Value::Tag(t, v) => { let mut acc = serializer.serialize_tuple_variant("@@TAG@@", 0, "@@TAGGED@@", 2)?; @@ -174,9 +182,19 @@ impl ser::Serializer for Serializer<()> { #[inline] fn serialize_newtype_struct( self, - _name: &'static str, + name: &'static str, value: &U, ) -> Result { + if name == "@@SIMPLETYPE@@" { + use serde::ser::Error as _; + + let v = Value::serialized(value)?; + let v = v + .as_integer() + .ok_or_else(|| Error::custom("Internal error handling simple types"))?; + let v = u8::try_from(v).map_err(Error::custom)?; + return Ok(Value::Simple(v)); + } value.serialize(self) } @@ -190,6 +208,16 @@ impl ser::Serializer for Serializer<()> { ) -> Result { Ok(match (name, variant) { ("@@TAG@@", "@@UNTAGGED@@") => Value::serialized(value)?, + ("@@ST@@", "@@SIMPLETYPE@@") => { + use serde::ser::Error as _; + + let v = Value::serialized(value)?; + let v = v + .as_integer() + .ok_or_else(|| Error::custom("Internal error handling simple types"))?; + let v = u8::try_from(v).map_err(Error::custom)?; + Value::Simple(v) + } _ => vec![(variant.into(), Value::serialized(value)?)].into(), }) } diff --git a/ciborium/tests/simple_type.rs b/ciborium/tests/simple_type.rs new file mode 100644 index 0000000..912f67c --- /dev/null +++ b/ciborium/tests/simple_type.rs @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: Apache-2.0 + +extern crate alloc; + +use ciborium::{de::from_reader, ser::into_writer, simple_type::SimpleType, value::Value}; +use rstest::rstest; +use serde::{de::DeserializeOwned, Serialize}; + +use core::fmt::Debug; +use std::collections::HashMap; + +#[rstest(item, bytes, value, encode, success, + case(SimpleType(0), "e0", Value::Simple(0), true, false), // Registered via Standard Actions + case(SimpleType(19), "f3", Value::Simple(19), true, false), // Registered via Standard Actions + case(SimpleType(23), "f7", Value::Simple(23), true, true), // CBOR simple value "undefined" + case(SimpleType(32), "f820", Value::Simple(32), true, true), + case(SimpleType(59), "f83b", Value::Simple(59), true, true), + case(SimpleType(255), "f8ff", Value::Simple(255), true, true), + case(vec![SimpleType(255)], "81f8ff", Value::Array(vec![Value::Simple(255)]), true, true), + case(HashMap::::from_iter([(SimpleType(59), 0)]), "a1f83b00", Value::Map(vec![(Value::Simple(59), Value::Integer(0.into()))]), true, true), +)] +fn test( + item: T, + bytes: &str, + value: Value, + encode: bool, + success: bool, +) { + let bytes = hex::decode(bytes).unwrap(); + + if encode { + // Encode into bytes + let mut encoded = Vec::new(); + into_writer(&item, &mut encoded).unwrap(); + assert_eq!(bytes, encoded); + + // Encode into value + assert_eq!(value, Value::serialized(&item).unwrap()); + } + + // Decode from bytes + match from_reader(&bytes[..]) { + Ok(x) if success => assert_eq!(item, x), + Ok(..) => panic!("unexpected success"), + Err(e) if success => panic!("{:?}", e), + Err(..) => (), + } + + // Decode from value + match value.deserialized() { + Ok(x) if success => assert_eq!(item, x), + Ok(..) => panic!("unexpected success"), + Err(e) if success => panic!("{:?}", e), + Err(..) => (), + } +} + +#[test] +fn value_serialized() { + let st = Value::Simple(59); + assert_eq!(st.clone(), Value::serialized(&st).unwrap()); + + let map_as_key = Value::Map(vec![(st.clone(), Value::Integer(0.into()))]); + assert_eq!(map_as_key, Value::serialized(&map_as_key).unwrap()); + + let map_as_value = Value::Map(vec![(Value::Integer(0.into()), st.clone())]); + assert_eq!(map_as_value, Value::serialized(&map_as_value).unwrap()); + + let array = Value::Array(vec![st]); + assert_eq!(array, Value::serialized(&array).unwrap()); +} + +#[test] +fn value_deserialize() { + let st = Value::Simple(59); + + let in_map_as_label = Value::Map(vec![(st.clone(), Value::Integer(0.into()))]); + Value::deserialized::(&in_map_as_label).unwrap(); + + let in_map_as_value = Value::Map(vec![(Value::Integer(0.into()), st.clone())]); + Value::deserialized::(&in_map_as_value).unwrap(); + + let in_array = Value::Array(vec![st.clone()]); + Value::deserialized::(&in_array).unwrap(); +} + +#[test] +fn should_roundtrip() { + let st = Value::Simple(59); + + let map = Value::Map(vec![(st.clone(), Value::Array(vec![]))]); + + let mut encoded = vec![]; + into_writer(&map, &mut encoded).unwrap(); + + let decoded = from_reader::(&encoded[..]).unwrap(); + assert_eq!(decoded, map); +}