diff --git a/engine/src/ast/field_expr.rs b/engine/src/ast/field_expr.rs index f238fde5..cab7b7e4 100644 --- a/engine/src/ast/field_expr.rs +++ b/engine/src/ast/field_expr.rs @@ -788,7 +788,7 @@ impl Expr for ComparisonExpr { mod tests { use super::*; use crate::{ - BytesFormat, FieldRef, LhsValue, ParserSettings, + BytesFormat, FieldRef, LhsValue, ParserSettings, TypedMap, ast::{ function_expr::{FunctionCallArgExpr, FunctionCallExpr}, logical_expr::LogicalExpr, @@ -1586,18 +1586,18 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - let headers = LhsValue::Map({ - let mut map = Map::new(Type::Bytes); - map.insert(b"host", "example.org").unwrap(); + let headers = LhsValue::from({ + let mut map = TypedMap::new(); + map.insert(b"host".to_vec().into(), "example.org"); map }); ctx.set_field_value(field("http.headers"), headers).unwrap(); assert_eq!(expr.execute_one(ctx), false); - let headers = LhsValue::Map({ - let mut map = Map::new(Type::Bytes); - map.insert(b"host", "abc.net.au").unwrap(); + let headers = LhsValue::from({ + let mut map = TypedMap::new(); + map.insert(b"host".to_vec().into(), "abc.net.au"); map }); @@ -2186,11 +2186,11 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - let headers = LhsValue::Map({ - let mut map = Map::new(Type::Bytes); - map.insert(b"0", "one").unwrap(); - map.insert(b"1", "two").unwrap(); - map.insert(b"2", "three").unwrap(); + let headers = LhsValue::from({ + let mut map = TypedMap::new(); + map.insert(b"0".to_vec().into(), "one"); + map.insert(b"1".to_vec().into(), "two"); + map.insert(b"2".to_vec().into(), "three"); map }); ctx.set_field_value(field("http.headers"), headers).unwrap(); @@ -2267,11 +2267,11 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - let headers = LhsValue::Map({ - let mut map = Map::new(Type::Bytes); - map.insert(b"0", "one").unwrap(); - map.insert(b"1", "two").unwrap(); - map.insert(b"2", "three").unwrap(); + let headers = LhsValue::from({ + let mut map = TypedMap::new(); + map.insert(b"0".to_vec().into(), "one"); + map.insert(b"1".to_vec().into(), "two"); + map.insert(b"2".to_vec().into(), "three"); map }); ctx.set_field_value(field("http.headers"), headers).unwrap(); diff --git a/engine/src/execution_context.rs b/engine/src/execution_context.rs index e5f390ae..6fc40999 100644 --- a/engine/src/execution_context.rs +++ b/engine/src/execution_context.rs @@ -425,8 +425,7 @@ fn test_scheme_mismatch() { #[test] fn test_serde() { - use crate::lhs_types::{Array, Map}; - use crate::types::Type; + use crate::lhs_types::{Array, TypedMap}; use std::net::IpAddr; use std::str::FromStr; @@ -491,9 +490,9 @@ fn test_serde() { assert_eq!( ctx.set_field_value(scheme.get_field("map").unwrap(), { - let mut map = Map::new(Type::Int); - map.insert(b"leet", 1337).unwrap(); - map.insert(b"tabs", 25).unwrap(); + let mut map = TypedMap::::new(); + map.insert(b"leet".to_vec().into(), 1337); + map.insert(b"tabs".to_vec().into(), 25); map }), Ok(None), @@ -535,16 +534,16 @@ fn test_serde() { assert_eq!( ctx.set_field_value(scheme.get_field("map").unwrap(), { - let mut map = Map::new(Type::Int); - map.insert(b"leet", 1337).unwrap(); - map.insert(b"tabs", 25).unwrap(); - map.insert(b"a\xFF\xFFb", 17).unwrap(); + let mut map = TypedMap::::new(); + map.insert(b"leet".to_vec().into(), 1337); + map.insert(b"tabs".to_vec().into(), 25); + map.insert(b"a\xFF\xFFb".to_vec().into(), 17); map }), Ok(Some({ - let mut map = Map::new(Type::Int); - map.insert(b"leet", 1337).unwrap(); - map.insert(b"tabs", 25).unwrap(); + let mut map = TypedMap::::new(); + map.insert(b"leet".to_vec().into(), 1337); + map.insert(b"tabs".to_vec().into(), 25); map.into() })), ); diff --git a/engine/src/lhs_types/array.rs b/engine/src/lhs_types/array.rs index ead1273f..cf4ce14a 100644 --- a/engine/src/lhs_types/array.rs +++ b/engine/src/lhs_types/array.rs @@ -1,9 +1,6 @@ use crate::{ lhs_types::AsRefIterator, - types::{ - CompoundType, GetType, IntoValue, LhsValue, LhsValueMut, LhsValueSeed, Type, - TypeMismatchError, - }, + types::{CompoundType, GetType, IntoValue, LhsValue, LhsValueSeed, Type, TypeMismatchError}, }; use serde::{ Serialize, Serializer, @@ -53,11 +50,6 @@ impl<'a> InnerArray<'a> { self.as_vec().get_mut(idx) } - #[inline] - fn insert(&mut self, idx: usize, value: LhsValue<'a>) { - self.as_vec().insert(idx, value) - } - #[inline] fn push(&mut self, value: LhsValue<'a>) { self.as_vec().push(value) @@ -121,43 +113,6 @@ impl<'a> Array<'a> { self.data.get(idx) } - /// Get a mutable reference to an element if it exists - pub fn get_mut(&mut self, idx: usize) -> Option> { - self.data.get_mut(idx).map(LhsValueMut::from) - } - - /// Inserts an element at index `idx` - pub fn insert( - &mut self, - idx: usize, - value: impl Into>, - ) -> Result<(), TypeMismatchError> { - let value = value.into(); - let value_type = value.get_type(); - if value_type != self.val_type.into() { - return Err(TypeMismatchError { - expected: Type::from(self.val_type).into(), - actual: value_type, - }); - } - self.data.insert(idx, value); - Ok(()) - } - - /// Push an element to the back of the array - pub fn push(&mut self, value: impl Into>) -> Result<(), TypeMismatchError> { - let value = value.into(); - let value_type = value.get_type(); - if value_type != self.val_type.into() { - return Err(TypeMismatchError { - expected: Type::from(self.val_type).into(), - actual: value_type, - }); - } - self.data.push(value); - Ok(()) - } - pub(crate) fn as_ref(&'a self) -> Array<'a> { Array { val_type: self.val_type, @@ -445,51 +400,6 @@ impl<'de> DeserializeSeed<'de> for &mut Array<'de> { } } -/// Wrapper type around mutable `Array` to prevent -/// illegal operations like changing the type of -/// its values. -pub struct ArrayMut<'a, 'b>(&'a mut Array<'b>); - -impl<'a, 'b> ArrayMut<'a, 'b> { - /// Push an element to the back of the array - #[inline] - pub fn push(&mut self, value: impl Into>) -> Result<(), TypeMismatchError> { - self.0.push(value) - } - - /// Inserts an element at index `idx` - #[inline] - pub fn insert( - &mut self, - idx: usize, - value: impl Into>, - ) -> Result<(), TypeMismatchError> { - self.0.insert(idx, value) - } - - /// Get a mutable reference to an element if it exists - #[inline] - pub fn get_mut(&'a mut self, idx: usize) -> Option> { - self.0.get_mut(idx) - } -} - -impl<'b> Deref for ArrayMut<'_, 'b> { - type Target = Array<'b>; - - #[inline] - fn deref(&self) -> &Self::Target { - self.0 - } -} - -impl<'a, 'b> From<&'a mut Array<'b>> for ArrayMut<'a, 'b> { - #[inline] - fn from(arr: &'a mut Array<'b>) -> Self { - Self(arr) - } -} - /// Typed wrapper over an `Array` which provides /// infaillible operations. #[repr(transparent)] @@ -703,11 +613,11 @@ mod tests { #[test] fn test_borrowed_eq_owned() { - let mut owned = Array::new(Type::Bytes); + let mut arr = TypedArray::new(); + + arr.push("borrowed"); - owned - .push(LhsValue::Bytes("borrowed".as_bytes().into())) - .unwrap(); + let owned = Array::from(arr); let borrowed = owned.as_ref(); diff --git a/engine/src/lhs_types/map.rs b/engine/src/lhs_types/map.rs index ebcd7bfa..8850b229 100644 --- a/engine/src/lhs_types/map.rs +++ b/engine/src/lhs_types/map.rs @@ -1,9 +1,7 @@ use crate::{ + TypeMismatchError, lhs_types::AsRefIterator, - types::{ - BytesOrString, CompoundType, GetType, IntoValue, LhsValue, LhsValueMut, LhsValueSeed, Type, - TypeMismatchError, - }, + types::{BytesOrString, CompoundType, GetType, IntoValue, LhsValue, LhsValueSeed, Type}, }; use serde::{ Serialize, Serializer, @@ -102,46 +100,6 @@ impl<'a> Map<'a> { self.data.get(key.as_ref()) } - /// Get a mutable reference to an element if it exists - pub fn get_mut>(&mut self, key: K) -> Option> { - self.data.get_mut(key.as_ref()).map(LhsValueMut::from) - } - - /// Inserts an element, overwriting if one already exists - pub fn insert( - &mut self, - key: &[u8], - value: impl Into>, - ) -> Result<(), TypeMismatchError> { - let value = value.into(); - let value_type = value.get_type(); - if value_type != self.val_type.into() { - return Err(TypeMismatchError { - expected: Type::from(self.val_type).into(), - actual: value_type, - }); - } - self.data.insert(key.into(), value); - Ok(()) - } - - /// Inserts `value` if `key` is missing, then returns a mutable reference to the contained value. - pub fn get_or_insert( - &mut self, - key: Box<[u8]>, - value: impl Into>, - ) -> Result, TypeMismatchError> { - let value = value.into(); - let value_type = value.get_type(); - if value_type != self.val_type.into() { - return Err(TypeMismatchError { - expected: Type::from(self.val_type).into(), - actual: value_type, - }); - } - Ok(LhsValueMut::from(self.data.get_or_insert(key, value))) - } - pub(crate) fn as_ref(&'a self) -> Map<'a> { Map { val_type: self.val_type, @@ -210,6 +168,34 @@ impl<'a> Map<'a> { pub fn iter(&self) -> MapIter<'a, '_> { MapIter(self.data.iter()) } + + /// Creates a new map from the specified iterator. + pub fn try_from_iter, V: Into>>( + val_type: impl Into, + iter: impl IntoIterator, V), E>>, + ) -> Result { + let val_type = val_type.into(); + iter.into_iter() + .map(|key_val| { + key_val.and_then(|(key, val)| { + let elem = val.into(); + let elem_type = elem.get_type(); + if val_type != elem_type.into() { + Err(E::from(TypeMismatchError { + expected: Type::from(val_type).into(), + actual: elem_type, + })) + } else { + Ok((key, elem)) + } + }) + }) + .collect::, _>>() + .map(|map| Map { + val_type, + data: InnerMap::Owned(map), + }) + } } impl<'a> PartialEq for Map<'a> { @@ -402,14 +388,19 @@ impl<'de> DeserializeSeed<'de> for &mut Map<'de> { where M: MapAccess<'de>, { + let value_type = self.0.value_type(); while let Some(key) = access.next_key::>()? { - let value = access.next_value_seed(LhsValueSeed(&self.0.value_type()))?; - self.0.insert(key.as_bytes(), value).map_err(|e| { - de::Error::custom(format!( + let value = access.next_value_seed(LhsValueSeed(&value_type))?; + if value.get_type() != value_type { + return Err(de::Error::custom(format!( "invalid type: {:?}, expected {:?}", - e.actual, e.expected - )) - })?; + value.get_type(), + value_type + ))); + } + self.0 + .data + .insert(key.into_owned().into_bytes().into(), value); } Ok(()) @@ -419,13 +410,16 @@ impl<'de> DeserializeSeed<'de> for &mut Map<'de> { where V: SeqAccess<'de>, { - while let Some(entry) = seq.next_element_seed(MapEntrySeed(&self.0.value_type()))? { - self.0.insert(&entry.0, entry.1).map_err(|e| { - de::Error::custom(format!( + let value_type = self.0.value_type(); + while let Some((key, value)) = seq.next_element_seed(MapEntrySeed(&value_type))? { + if value.get_type() != value_type { + return Err(de::Error::custom(format!( "invalid type: {:?}, expected {:?}", - e.actual, e.expected - )) - })?; + value.get_type(), + value_type + ))); + } + self.0.data.insert(key.into_owned().into(), value); } Ok(()) } @@ -435,55 +429,6 @@ impl<'de> DeserializeSeed<'de> for &mut Map<'de> { } } -/// Wrapper type around mutable `Map` to prevent -/// illegal operations like changing the type of -/// its values. -pub struct MapMut<'a, 'b>(&'a mut Map<'b>); - -impl<'a, 'b> MapMut<'a, 'b> { - /// Get a mutable reference to an element if it exists - #[inline] - pub fn get_mut(&'a mut self, key: &[u8]) -> Option> { - self.0.get_mut(key) - } - - /// Inserts an element, overwriting if one already exists - #[inline] - pub fn insert( - &mut self, - key: &[u8], - value: impl Into>, - ) -> Result<(), TypeMismatchError> { - self.0.insert(key, value) - } - - /// Inserts `value` if `key` is missing, then returns a mutable reference to the contained value. - #[inline] - pub fn get_or_insert( - &'a mut self, - key: Box<[u8]>, - value: impl Into>, - ) -> Result, TypeMismatchError> { - self.0.get_or_insert(key, value) - } -} - -impl<'b> Deref for MapMut<'_, 'b> { - type Target = Map<'b>; - - #[inline] - fn deref(&self) -> &Self::Target { - self.0 - } -} - -impl<'a, 'b> From<&'a mut Map<'b>> for MapMut<'a, 'b> { - #[inline] - fn from(map: &'a mut Map<'b>) -> Self { - Self(map) - } -} - /// Typed wrapper over a `Map` which provides /// infaillible operations. #[repr(transparent)] @@ -711,11 +656,11 @@ mod tests { #[test] fn test_borrowed_eq_owned() { - let mut owned = Map::new(Type::Bytes); + let mut map = TypedMap::new(); + + map.insert("key".as_bytes().to_vec().into(), "borrowed"); - owned - .insert(b"key", LhsValue::Bytes("borrowed".as_bytes().into())) - .unwrap(); + let owned = Map::from(map); let borrowed = owned.as_ref(); diff --git a/engine/src/lhs_types/mod.rs b/engine/src/lhs_types/mod.rs index 4d6c99a2..93172b41 100644 --- a/engine/src/lhs_types/mod.rs +++ b/engine/src/lhs_types/mod.rs @@ -4,8 +4,8 @@ mod map; use crate::types::LhsValue; pub use self::{ - array::{Array, ArrayIterator, ArrayMut, TypedArray}, - map::{Map, MapIter, MapMut, MapValuesIntoIter, TypedMap}, + array::{Array, ArrayIterator, TypedArray}, + map::{Map, MapIter, MapValuesIntoIter, TypedMap}, }; pub struct AsRefIterator<'a, T: Iterator>>(T); diff --git a/engine/src/lib.rs b/engine/src/lib.rs index 8c03b14b..10fadd50 100644 --- a/engine/src/lib.rs +++ b/engine/src/lib.rs @@ -102,7 +102,7 @@ pub use self::{ SimpleFunctionOptParam, SimpleFunctionParam, }, lex::LexErrorKind, - lhs_types::{Array, ArrayMut, Map, MapIter, MapMut, TypedArray, TypedMap}, + lhs_types::{Array, Map, MapIter, TypedArray, TypedMap}, list_matcher::{ AlwaysList, AlwaysListMatcher, ListDefinition, ListMatcher, NeverList, NeverListMatcher, }, @@ -120,7 +120,7 @@ pub use self::{ SchemeBuilder, SchemeMismatchError, UnknownFieldError, }, types::{ - CompoundType, ExpectedType, ExpectedTypeList, GetType, LhsValue, LhsValueMut, RhsValue, - RhsValues, Type, TypeMismatchError, + CompoundType, ExpectedType, ExpectedTypeList, GetType, LhsValue, RhsValue, RhsValues, Type, + TypeMismatchError, }, }; diff --git a/engine/src/types.rs b/engine/src/types.rs index aa00f17f..a0a6289f 100644 --- a/engine/src/types.rs +++ b/engine/src/types.rs @@ -1,6 +1,6 @@ use crate::{ lex::{Lex, LexResult, LexWith, expect, skip_space}, - lhs_types::{Array, ArrayIterator, ArrayMut, Map, MapIter, MapMut, MapValuesIntoIter}, + lhs_types::{Array, ArrayIterator, Map, MapIter, MapValuesIntoIter}, rhs_types::{Bytes, IntRange, IpRange, UninhabitedArray, UninhabitedBool, UninhabitedMap}, scheme::{FieldIndex, IndexAccessError}, strict_partial_ord::StrictPartialOrd, @@ -138,15 +138,6 @@ pub struct TypeMismatchError { pub actual: Type, } -/// An error that occurs on a type mismatch. -#[derive(Debug, PartialEq, Eq, Error)] -pub enum SetValueError { - #[error("{0}")] - TypeMismatch(#[source] TypeMismatchError), - #[error("{0}")] - IndexAccess(#[source] IndexAccessError), -} - macro_rules! replace_underscore { ($name:ident ($val_ty:ty)) => { Type::$name(_) @@ -418,6 +409,14 @@ impl Type { pub fn map(ty: impl Into) -> Self { Self::Map(ty.into()) } + + /// Deserializes a value based on its type. + pub fn deserialize_value<'de, D>(&self, deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + LhsValueSeed(self).deserialize(deserializer) + } } impl std::fmt::Display for Type { @@ -853,42 +852,6 @@ impl<'a> LhsValue<'a> { .try_fold(self, |value, idx| value.extract(idx).unwrap()) } - /// Set an element in an LhsValue given a path item and a specified value. - /// Returns a TypeMismatchError error if current type does not support - /// nested element or if value type is invalid. - /// Only LhsValyue::Map supports nested elements for now. - pub fn set>>( - &mut self, - item: FieldIndex, - value: V, - ) -> Result<(), SetValueError> { - let value = value.into(); - match item { - FieldIndex::ArrayIndex(idx) => match self { - LhsValue::Array(arr) => arr - .insert(idx as usize, value) - .map_err(SetValueError::TypeMismatch), - _ => Err(SetValueError::IndexAccess(IndexAccessError { - index: item, - actual: self.get_type(), - })), - }, - FieldIndex::MapKey(name) => match self { - LhsValue::Map(map) => map - .insert(name.as_bytes(), value) - .map_err(SetValueError::TypeMismatch), - _ => Err(SetValueError::IndexAccess(IndexAccessError { - index: FieldIndex::MapKey(name), - actual: self.get_type(), - })), - }, - FieldIndex::MapEach => Err(SetValueError::IndexAccess(IndexAccessError { - index: item, - actual: self.get_type(), - })), - } - } - /// Returns an iterator over the Map or Array pub fn iter(&'a self) -> Option> { match self { @@ -1185,38 +1148,6 @@ declare_types!( Map[CompoundType](#[serde(skip_deserializing)] Map<'a> | UninhabitedMap | UninhabitedMap), ); -/// Wrapper type around mutable `LhsValue` to prevent -/// illegal operations like changing the type of values -/// in an `Array` or a `Map`. -pub enum LhsValueMut<'a, 'b> { - /// A mutable boolean. - Bool(&'a mut bool), - /// A mutable 32-bit integer number. - Int(&'a mut i64), - /// A mutable IPv4 or IPv6 address. - Ip(&'a mut IpAddr), - /// A mutable byte string. - Bytes(&'a mut Cow<'b, [u8]>), - /// A mutable array. - Array(ArrayMut<'a, 'b>), - /// A mutable map. - Map(MapMut<'a, 'b>), -} - -impl<'a, 'b> From<&'a mut LhsValue<'b>> for LhsValueMut<'a, 'b> { - #[inline] - fn from(value: &'a mut LhsValue<'b>) -> Self { - match value { - LhsValue::Bool(b) => LhsValueMut::Bool(b), - LhsValue::Int(i) => LhsValueMut::Int(i), - LhsValue::Ip(ip) => LhsValueMut::Ip(ip), - LhsValue::Bytes(b) => LhsValueMut::Bytes(b), - LhsValue::Array(arr) => LhsValueMut::Array(arr.into()), - LhsValue::Map(map) => LhsValueMut::Map(map.into()), - } - } -} - #[test] fn test_lhs_value_deserialize() { use std::str::FromStr; diff --git a/ffi/include/wirefilter.h b/ffi/include/wirefilter.h index a50b99a2..23c01510 100644 --- a/ffi/include/wirefilter.h +++ b/ffi/include/wirefilter.h @@ -47,16 +47,12 @@ enum wirefilter_status { WIREFILTER_STATUS_PANIC, }; -struct wirefilter_array; - struct wirefilter_execution_context; struct wirefilter_filter; struct wirefilter_filter_ast; -struct wirefilter_map; - struct wirefilter_scheme; struct wirefilter_scheme_builder; @@ -175,6 +171,12 @@ bool wirefilter_deserialize_json_to_execution_context(struct wirefilter_executio void wirefilter_free_execution_context(struct wirefilter_execution_context *exec_context); +bool wirefilter_add_json_value_to_execution_context(struct wirefilter_execution_context *exec_context, + const char *name_ptr, + size_t name_len, + const uint8_t *json_ptr, + size_t json_len); + bool wirefilter_add_int_value_to_execution_context(struct wirefilter_execution_context *exec_context, const char *name_ptr, size_t name_len, @@ -201,87 +203,6 @@ bool wirefilter_add_bool_value_to_execution_context(struct wirefilter_execution_ size_t name_len, bool value); -bool wirefilter_add_map_value_to_execution_context(struct wirefilter_execution_context *exec_context, - const char *name_ptr, - size_t name_len, - struct wirefilter_map *value); - -bool wirefilter_add_array_value_to_execution_context(struct wirefilter_execution_context *exec_context, - const char *name_ptr, - size_t name_len, - struct wirefilter_array *value); - -struct wirefilter_map *wirefilter_create_map(struct wirefilter_type ty); - -bool wirefilter_add_int_value_to_map(struct wirefilter_map *map, - const uint8_t *name_ptr, - size_t name_len, - int64_t value); - -bool wirefilter_add_bytes_value_to_map(struct wirefilter_map *map, - const uint8_t *name_ptr, - size_t name_len, - const uint8_t *value_ptr, - size_t value_len); - -bool wirefilter_add_ipv6_value_to_map(struct wirefilter_map *map, - const uint8_t *name_ptr, - size_t name_len, - const uint8_t (*value)[16]); - -bool wirefilter_add_ipv4_value_to_map(struct wirefilter_map *map, - const uint8_t *name_ptr, - size_t name_len, - const uint8_t (*value)[4]); - -bool wirefilter_add_bool_value_to_map(struct wirefilter_map *map, - const uint8_t *name_ptr, - size_t name_len, - bool value); - -bool wirefilter_add_map_value_to_map(struct wirefilter_map *map, - const uint8_t *name_ptr, - size_t name_len, - struct wirefilter_map *value); - -bool wirefilter_add_array_value_to_map(struct wirefilter_map *map, - const uint8_t *name_ptr, - size_t name_len, - struct wirefilter_array *value); - -void wirefilter_free_map(struct wirefilter_map *map); - -struct wirefilter_array *wirefilter_create_array(struct wirefilter_type ty); - -bool wirefilter_add_int_value_to_array(struct wirefilter_array *array, - uint32_t index, - int64_t value); - -bool wirefilter_add_bytes_value_to_array(struct wirefilter_array *array, - uint32_t index, - const uint8_t *value_ptr, - size_t value_len); - -bool wirefilter_add_ipv6_value_to_array(struct wirefilter_array *array, - uint32_t index, - const uint8_t (*value)[16]); - -bool wirefilter_add_ipv4_value_to_array(struct wirefilter_array *array, - uint32_t index, - const uint8_t (*value)[4]); - -bool wirefilter_add_bool_value_to_array(struct wirefilter_array *array, uint32_t index, bool value); - -bool wirefilter_add_map_value_to_array(struct wirefilter_array *array, - uint32_t index, - struct wirefilter_map *value); - -bool wirefilter_add_array_value_to_array(struct wirefilter_array *array, - uint32_t index, - struct wirefilter_array *value); - -void wirefilter_free_array(struct wirefilter_array *array); - struct wirefilter_compiling_result wirefilter_compile_filter(struct wirefilter_filter_ast *filter_ast); struct wirefilter_matching_result wirefilter_match(const struct wirefilter_filter *filter, diff --git a/ffi/src/lib.rs b/ffi/src/lib.rs index 3c1c8925..ccc0f258 100644 --- a/ffi/src/lib.rs +++ b/ffi/src/lib.rs @@ -17,7 +17,7 @@ use std::{ io::{self, Write}, net::IpAddr, }; -use wirefilter::{AlwaysList, LhsValue, NeverList, Type, catch_panic}; +use wirefilter::{AlwaysList, GetType, NeverList, Type, catch_panic}; const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -192,30 +192,6 @@ pub struct ExecutionContext<'s>(wirefilter::ExecutionContext<'s>); wrap_type!(ExecutionContext<'s>); -#[derive(Debug, PartialEq)] -#[repr(Rust)] -pub struct Array<'s>(wirefilter::Array<'s>); - -wrap_type!(Array<'s>); - -impl<'s> From> for LhsValue<'s> { - fn from(array: Array<'s>) -> Self { - Self::Array(array.into()) - } -} - -#[derive(Debug, PartialEq)] -#[repr(Rust)] -pub struct Map<'s>(wirefilter::Map<'s>); - -wrap_type!(Map<'s>); - -impl<'s> From> for LhsValue<'s> { - fn from(map: Map<'s>) -> Self { - Self::Map(map.into()) - } -} - #[derive(Debug, PartialEq)] #[repr(Rust)] pub struct FilterAst(wirefilter::FilterAst); @@ -548,7 +524,7 @@ pub extern "C" fn wirefilter_deserialize_json_to_execution_context( ) -> bool { assert!(!json_ptr.is_null()); let json = unsafe { std::slice::from_raw_parts(json_ptr, json_len) }; - let mut deserializer = serde_json::Deserializer::from_slice(json); + let mut deserializer = serde_json::Deserializer::from_reader(json); match exec_context.deserialize(&mut deserializer) { Ok(_) => true, Err(err) => { @@ -563,6 +539,39 @@ pub extern "C" fn wirefilter_free_execution_context(exec_context: Box, + name_ptr: *const c_char, + name_len: usize, + json_ptr: *const u8, + json_len: usize, +) -> bool { + let name = to_str!(name_ptr, name_len); + let json = unsafe { std::slice::from_raw_parts(json_ptr, json_len) }; + let ty = match exec_context.scheme().get_field(name) { + Ok(field) => field.get_type(), + Err(err) => { + write_last_error!("{}", err); + return false; + } + }; + let value = match ty.deserialize_value(&mut serde_json::Deserializer::from_reader(json)) { + Ok(value) => value, + Err(err) => { + write_last_error!("{}", err); + return false; + } + }; + match exec_context.set_field_value_from_name(name, value) { + Ok(_) => true, + Err(err) => { + write_last_error!("{}", err); + false + } + } +} + #[unsafe(no_mangle)] pub extern "C" fn wirefilter_add_int_value_to_execution_context( exec_context: &mut ExecutionContext<'_>, @@ -625,205 +634,6 @@ pub extern "C" fn wirefilter_add_bool_value_to_execution_context( exec_context.set_field_value_from_name(name, value).is_ok() } -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_map_value_to_execution_context<'a>( - exec_context: &mut ExecutionContext<'a>, - name_ptr: *const c_char, - name_len: usize, - value: Box>, -) -> bool { - let name = to_str!(name_ptr, name_len); - exec_context.set_field_value_from_name(name, *value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_array_value_to_execution_context<'a>( - exec_context: &mut ExecutionContext<'a>, - name_ptr: *const c_char, - name_len: usize, - value: Box>, -) -> bool { - let name = to_str!(name_ptr, name_len); - exec_context.set_field_value_from_name(name, *value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_create_map<'a>(ty: CType) -> Box> { - Box::new(Map(wirefilter::Map::new(Type::from(ty)))) -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_int_value_to_map( - map: &mut Map<'_>, - name_ptr: *const u8, - name_len: usize, - value: i64, -) -> bool { - assert!(!name_ptr.is_null()); - let name = unsafe { std::slice::from_raw_parts(name_ptr, name_len) }; - map.insert(name, value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_bytes_value_to_map( - map: &mut Map<'_>, - name_ptr: *const u8, - name_len: usize, - value_ptr: *const u8, - value_len: usize, -) -> bool { - assert!(!name_ptr.is_null()); - let name = unsafe { std::slice::from_raw_parts(name_ptr, name_len) }; - assert!(!value_ptr.is_null()); - let value = unsafe { std::slice::from_raw_parts(value_ptr, value_len) }; - map.insert(name, value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_ipv6_value_to_map( - map: &mut Map<'_>, - name_ptr: *const u8, - name_len: usize, - value: &[u8; 16], -) -> bool { - assert!(!name_ptr.is_null()); - let name = unsafe { std::slice::from_raw_parts(name_ptr, name_len) }; - let value = IpAddr::from(*value); - map.insert(name, value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_ipv4_value_to_map( - map: &mut Map<'_>, - name_ptr: *const u8, - name_len: usize, - value: &[u8; 4], -) -> bool { - assert!(!name_ptr.is_null()); - let name = unsafe { std::slice::from_raw_parts(name_ptr, name_len) }; - let value = IpAddr::from(*value); - map.insert(name, value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_bool_value_to_map( - map: &mut Map<'_>, - name_ptr: *const u8, - name_len: usize, - value: bool, -) -> bool { - assert!(!name_ptr.is_null()); - let name = unsafe { std::slice::from_raw_parts(name_ptr, name_len) }; - map.insert(name, value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_map_value_to_map<'a>( - map: &mut Map<'a>, - name_ptr: *const u8, - name_len: usize, - value: Box>, -) -> bool { - assert!(!name_ptr.is_null()); - let name = unsafe { std::slice::from_raw_parts(name_ptr, name_len) }; - map.insert(name, *value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_array_value_to_map<'a>( - map: &mut Map<'a>, - name_ptr: *const u8, - name_len: usize, - value: Box>, -) -> bool { - assert!(!name_ptr.is_null()); - let name = unsafe { std::slice::from_raw_parts(name_ptr, name_len) }; - map.insert(name, *value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_free_map(map: Box>) { - drop(map) -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_create_array<'a>(ty: CType) -> Box> { - Box::new(Array(wirefilter::Array::new(Type::from(ty)))) -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_int_value_to_array( - array: &mut Array<'_>, - index: u32, - value: i64, -) -> bool { - array.insert(index.try_into().unwrap(), value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_bytes_value_to_array( - array: &mut Array<'_>, - index: u32, - value_ptr: *const u8, - value_len: usize, -) -> bool { - assert!(!value_ptr.is_null()); - let value = unsafe { std::slice::from_raw_parts(value_ptr, value_len) }; - array.insert(index.try_into().unwrap(), value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_ipv6_value_to_array( - array: &mut Array<'_>, - index: u32, - value: &[u8; 16], -) -> bool { - let value = IpAddr::from(*value); - array.insert(index.try_into().unwrap(), value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_ipv4_value_to_array( - array: &mut Array<'_>, - index: u32, - value: &[u8; 4], -) -> bool { - let value = IpAddr::from(*value); - array.insert(index.try_into().unwrap(), value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_bool_value_to_array( - array: &mut Array<'_>, - index: u32, - value: bool, -) -> bool { - array.insert(index.try_into().unwrap(), value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_map_value_to_array<'a>( - array: &mut Array<'a>, - index: u32, - value: Box>, -) -> bool { - array.insert(index.try_into().unwrap(), *value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_add_array_value_to_array<'a>( - array: &mut Array<'a>, - index: u32, - value: Box>, -) -> bool { - array.insert(index.try_into().unwrap(), *value).is_ok() -} - -#[unsafe(no_mangle)] -pub extern "C" fn wirefilter_free_array(array: Box>) { - drop(array) -} - #[derive(Debug)] #[repr(C)] pub struct CompilingResult { @@ -1021,6 +831,7 @@ pub extern "C" fn wirefilter_get_version() -> StaticRustAllocatedString { mod ffi_test { use super::*; use regex_automata::meta::Regex; + use serde_json::json; use std::ffi::CStr; impl RustAllocatedString { @@ -1137,49 +948,27 @@ mod ffi_test { 1337, )); - let mut map1 = wirefilter_create_map(Type::Int.into()); - - let key = b"key"; - wirefilter_add_int_value_to_map(&mut map1, key.as_ptr(), key.len(), 42); - - wirefilter_add_int_value_to_map(&mut map1, invalid_key.as_ptr(), invalid_key.len(), 42); + let json = json!([["key", 42], [invalid_key, 42]]).to_string(); let field = "map1"; - wirefilter_add_map_value_to_execution_context( + assert!(wirefilter_add_json_value_to_execution_context( &mut exec_context, field.as_ptr().cast(), field.len(), - map1, - ); - - let mut map2 = wirefilter_create_map(Type::Bytes.into()); - - let key = b"key"; - let value = "value"; - wirefilter_add_bytes_value_to_map( - &mut map2, - key.as_ptr(), - key.len(), - value.as_ptr(), - value.len(), - ); + json.as_bytes().as_ptr(), + json.len(), + )); - let value = "value"; - wirefilter_add_bytes_value_to_map( - &mut map2, - invalid_key.as_ptr(), - invalid_key.len(), - value.as_ptr(), - value.len(), - ); + let json = json!([["key", "value"], [invalid_key, "value"]]).to_string(); let field = "map2"; - wirefilter_add_map_value_to_execution_context( + assert!(wirefilter_add_json_value_to_execution_context( &mut exec_context, field.as_ptr().cast(), field.len(), - map2, - ); + json.as_ptr().cast(), + json.len(), + )); exec_context } diff --git a/ffi/tests/ctests/src/tests.c b/ffi/tests/ctests/src/tests.c index e7fba272..cab13e8c 100644 --- a/ffi/tests/ctests/src/tests.c +++ b/ffi/tests/ctests/src/tests.c @@ -461,24 +461,6 @@ void wirefilter_ffi_ctest_add_values_to_execution_context_errors() { 80 ) == false, "managed to set value for non-existent int field"); - struct wirefilter_map *more_http_headers = wirefilter_create_map( - WIREFILTER_TYPE_BYTES - ); - rust_assert(wirefilter_add_map_value_to_execution_context( - exec_ctx, - STRING("doesnotexist"), - more_http_headers - ) == false, "managed to set value for non-existent map field"); - - struct wirefilter_array *http_cookies = wirefilter_create_array( - WIREFILTER_TYPE_BYTES - ); - rust_assert(wirefilter_add_array_value_to_execution_context( - exec_ctx, - STRING("doesnotexist"), - http_cookies - ) == false, "managed to set value for non-existent array field"); - wirefilter_free_execution_context(exec_ctx); wirefilter_free_scheme(scheme); @@ -701,22 +683,16 @@ void wirefilter_ffi_ctest_match_map() { 80 ); - struct wirefilter_map *http_headers = wirefilter_create_map( - WIREFILTER_TYPE_BYTES + const char *json = "{\"host\":\"www.cloudflare.com\"}"; + rust_assert( + wirefilter_add_json_value_to_execution_context( + exec_ctx, + STRING("http.headers"), + BYTES(json) + ) == true, + "could not set value for map field http.headers" ); - rust_assert(wirefilter_add_bytes_value_to_map( - http_headers, - BYTES("host"), - BYTES("www.cloudflare.com") - ), "could not add bytes value to map"); - - rust_assert(wirefilter_add_map_value_to_execution_context( - exec_ctx, - STRING("http.headers"), - http_headers - ) == true, "could not set value for map field http.headers"); - struct wirefilter_matching_result matching_result = wirefilter_match(filter, exec_ctx); rust_assert(matching_result.status == WIREFILTER_STATUS_SUCCESS, "could not match filter"); @@ -773,34 +749,16 @@ void wirefilter_ffi_ctest_match_array() { 80 ); - struct wirefilter_array *http_cookies = wirefilter_create_array( - WIREFILTER_TYPE_BYTES + const char *json = "[\"one\", \"two\", \"www.cloudflare.com\"]"; + rust_assert( + wirefilter_add_json_value_to_execution_context( + exec_ctx, + STRING("http.cookies"), + BYTES(json) + ) == true, + "could not set value for map field http.cookies" ); - rust_assert(wirefilter_add_bytes_value_to_array( - http_cookies, - 0, - BYTES("one") - ), "could not add bytes value to array"); - - rust_assert(wirefilter_add_bytes_value_to_array( - http_cookies, - 1, - BYTES("two") - ), "could not add bytes value to array"); - - rust_assert(wirefilter_add_bytes_value_to_array( - http_cookies, - 2, - BYTES("www.cloudflare.com") - ), "could not add bytes value to array"); - - rust_assert(wirefilter_add_array_value_to_execution_context( - exec_ctx, - STRING("http.cookies"), - http_cookies - ) == true, "could not set value for map field http.cookies"); - struct wirefilter_matching_result matching_result = wirefilter_match(filter, exec_ctx); rust_assert(matching_result.status == WIREFILTER_STATUS_SUCCESS, "could not match filter");