diff --git a/engine/src/scheme.rs b/engine/src/scheme.rs index 03c3795c..33dcb339 100644 --- a/engine/src/scheme.rs +++ b/engine/src/scheme.rs @@ -9,6 +9,7 @@ use crate::{ types::{GetType, RhsValue, Type}, }; use fnv::FnvBuildHasher; +use serde::de::Visitor; use serde::ser::SerializeMap; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::collections::hash_map::Entry; @@ -139,7 +140,7 @@ impl<'s> FieldRef<'s> { /// Returns the field's name as recorded in the [`Scheme`](struct@Scheme). #[inline] pub fn name(&self) -> &'s str { - &self.scheme.inner.fields[self.index].0 + &self.scheme.inner.fields[self.index].name } /// Get the field's index in the [`Scheme`](struct@Scheme) identifier's list. @@ -183,7 +184,7 @@ impl<'s> FieldRef<'s> { impl GetType for FieldRef<'_> { #[inline] fn get_type(&self) -> Type { - self.scheme.inner.fields[self.index].1 + self.scheme.inner.fields[self.index].ty } } @@ -217,7 +218,7 @@ impl Field { /// Returns the field's name as recorded in the [`Scheme`](struct@Scheme). #[inline] pub fn name(&self) -> &str { - &self.scheme.inner.fields[self.index].0 + &self.scheme.inner.fields[self.index].name } /// Get the field's index in the [`Scheme`](struct@Scheme) identifier's list. @@ -245,7 +246,7 @@ impl Field { impl GetType for Field { #[inline] fn get_type(&self) -> Type { - self.scheme.inner.fields[self.index].1 + self.scheme.inner.fields[self.index].ty } } @@ -609,10 +610,16 @@ pub struct ListRedefinitionError(Type); type IdentifierName = Arc; +#[derive(Debug, PartialEq)] +struct FieldDefinition { + name: IdentifierName, + ty: Type, +} + /// A builder for a [`Scheme`]. #[derive(Default, Debug)] pub struct SchemeBuilder { - fields: Vec<(IdentifierName, Type)>, + fields: Vec, functions: Vec<(IdentifierName, Box)>, items: HashMap, @@ -643,7 +650,10 @@ impl SchemeBuilder { }, Entry::Vacant(entry) => { let index = self.fields.len(); - self.fields.push((entry.key().clone(), ty)); + self.fields.push(FieldDefinition { + name: entry.key().clone(), + ty, + }); entry.insert(SchemeItem::Field(index)); Ok(()) } @@ -737,14 +747,21 @@ impl Hash for Scheme { } } +#[derive(Deserialize, Serialize)] +struct SerdeField { + #[serde(rename = "type")] + ty: Type, +} + impl Serialize for Scheme { fn serialize(&self, serializer: S) -> Result where S: Serializer, { - let mut map = serializer.serialize_map(Some(self.field_count()))?; - for f in self.fields() { - map.serialize_entry(f.name(), &f.get_type())?; + let fields = self.fields(); + let mut map = serializer.serialize_map(Some(fields.len()))?; + for f in fields { + map.serialize_entry(f.name(), &SerdeField { ty: f.get_type() })?; } map.end() } @@ -757,12 +774,31 @@ impl<'de> Deserialize<'de> for Scheme { { use serde::de::Error; - let mut builder = SchemeBuilder::new(); - let map: HashMap = HashMap::::deserialize(deserializer)?; - for (name, ty) in map { - builder.add_field(&name, ty).map_err(D::Error::custom)?; + struct FieldMapVisitor; + + impl<'de> Visitor<'de> for FieldMapVisitor { + type Value = SchemeBuilder; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("a wirefilter scheme") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut builder = SchemeBuilder::new(); + while let Some((name, SerdeField { ty })) = map.next_entry::<&str, SerdeField>()? { + builder.add_field(name, ty).map_err(A::Error::custom)?; + } + + Ok(builder) + } } - Ok(builder.build()) + + deserializer + .deserialize_map(FieldMapVisitor) + .map(|builder| builder.build()) } } @@ -1743,3 +1779,24 @@ fn test_scheme_iter_fields() { ] ); } + +#[test] +fn test_scheme_json_serialization() { + let scheme = Scheme! { + bytes: Bytes, + int: Int, + bool: Bool, + ip: Ip, + map_of_bytes: Map(Bytes), + map_of_array_of_bytes: Map(Array(Bytes)), + array_of_bytes: Array(Bytes), + array_of_map_of_bytes: Array(Map(Bytes)), + } + .build(); + + let json = serde_json::to_string(&scheme).unwrap(); + + let new_scheme = serde_json::from_str::(&json).unwrap(); + + assert_eq!(scheme.inner.fields, new_scheme.inner.fields); +} diff --git a/ffi/tests/ctests/src/tests.c b/ffi/tests/ctests/src/tests.c index cab13e8c..8eded4b3 100644 --- a/ffi/tests/ctests/src/tests.c +++ b/ffi/tests/ctests/src/tests.c @@ -305,7 +305,7 @@ void wirefilter_ffi_ctest_scheme_serialize() { rust_assert(json.ptr != NULL && json.len > 0, "could not serialize scheme to JSON"); rust_assert( - strncmp(json.ptr, "{\"http.host\":\"Bytes\",\"ip.src\":\"Ip\",\"ip.dst\":\"Ip\",\"ssl\":\"Bool\",\"tcp.port\":\"Int\",\"http.headers\":{\"Map\":\"Bytes\"},\"http.cookies\":{\"Array\":\"Bytes\"}}", json.len) == 0, + strncmp(json.ptr, "{\"http.host\":{\"type\":\"Bytes\"},\"ip.src\":{\"type\":\"Ip\"},\"ip.dst\":{\"type\":\"Ip\"},\"ssl\":{\"type\":\"Bool\"},\"tcp.port\":{\"type\":\"Int\"},\"http.headers\":{\"type\":{\"Map\":\"Bytes\"}},\"http.cookies\":{\"type\":{\"Array\":\"Bytes\"}}}", json.len) == 0, "invalid JSON serialization" );