diff --git a/engine/src/ast/field_expr.rs b/engine/src/ast/field_expr.rs index ca121b16..cf0e00de 100644 --- a/engine/src/ast/field_expr.rs +++ b/engine/src/ast/field_expr.rs @@ -799,7 +799,7 @@ impl Expr for ComparisonExpr { mod tests { use super::*; use crate::{ - BytesFormat, FieldRef, LhsValue, ParserSettings, TypedMap, + BytesFormat, FieldRef, LhsValue, ParserSettings, SchemeBuilder, TypedMap, ast::{ function_expr::{FunctionCallArgExpr, FunctionCallExpr}, logical_expr::LogicalExpr, @@ -3040,4 +3040,211 @@ mod tests { assert_eq!(expr.execute_one(ctx), expected, "failed test case {t:?}"); } } + + #[test] + fn test_optional_fields() { + let mut builder = SchemeBuilder::new(); + builder + .add_optional_field("tcp.srcport", Type::Int) + .unwrap(); + builder + .add_optional_field("tcp.dstport", Type::Int) + .unwrap(); + builder + .add_optional_field("tcp.flags.syn", Type::Bool) + .unwrap(); + builder + .add_optional_field("udp.srcport", Type::Int) + .unwrap(); + builder + .add_optional_field("udp.dstport", Type::Int) + .unwrap(); + builder.set_nil_not_equal_behavior(false); + let scheme = builder.build(); + + macro_rules! test_case { + ($filter:ident {$($name:ident $(. $suffix:ident)*: $value:literal),*} => $outcome:literal) => {{ + #[allow(unused_mut)] + let mut ctx = ExecutionContext::<()>::new(&scheme); + $( + ctx.set_field_value(scheme.get_field(stringify!($name $(. 
$suffix)*)).unwrap(), $value).unwrap(); + )* + + assert_eq!($filter.execute(&ctx), Ok($outcome)); + }}; + } + + let filter = scheme + .parse("(tcp.dstport != 80) or (udp.dstport != 80)") + .unwrap() + .compile(); + + test_case!(filter { tcp.dstport: 443 } => true); + + test_case!(filter { tcp.dstport: 80 } => false); + + test_case!(filter { udp.dstport: 53 } => true); + + test_case!(filter { udp.dstport: 80 } => false); + + test_case!(filter {} => false); + + let filter = scheme + .parse("(tcp.dstport != 80) and (udp.dstport != 80)") + .unwrap() + .compile(); + + test_case!(filter { tcp.dstport: 443 } => false); + + test_case!(filter { tcp.dstport: 80 } => false); + + test_case!(filter { udp.dstport: 53 } => false); + + test_case!(filter { udp.dstport: 80 } => false); + + test_case!(filter {} => false); + + let filter = scheme + .parse("(tcp.srcport == 1337) or ((tcp.dstport != 80) or (udp.dstport != 80))") + .unwrap() + .compile(); + + test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80 } => true); + + test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443 } => true); + + test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80 } => false); + + test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443 } => true); + + test_case!(filter { udp.dstport: 80 } => false); + + test_case!(filter { udp.dstport: 444 } => true); + + test_case!(filter {} => false); + + let filter = scheme + .parse("(tcp.srcport == 1337) and ((tcp.dstport != 80) or (udp.dstport != 80))") + .unwrap() + .compile(); + + test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80 } => false); + + test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443 } => true); + + test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80 } => false); + + test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443 } => false); + + test_case!(filter { udp.dstport: 80 } => false); + + test_case!(filter { udp.dstport: 444 } => false); + + test_case!(filter {} => false); + + let filter = scheme + .parse("(tcp.srcport == 1337) or 
((tcp.dstport != 80) and (udp.dstport != 80))") + .unwrap() + .compile(); + + test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80 } => true); + + test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443 } => true); + + test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80 } => false); + + test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443 } => false); + + test_case!(filter { udp.dstport: 80 } => false); + + test_case!(filter { udp.dstport: 444 } => false); + + test_case!(filter {} => false); + + let filter = scheme + .parse("(tcp.srcport == 1337) and ((tcp.dstport != 80) and (udp.dstport != 80))") + .unwrap() + .compile(); + + test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80 } => false); + + test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443 } => false); + + test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80 } => false); + + test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443 } => false); + + test_case!(filter { udp.dstport: 80 } => false); + + test_case!(filter { udp.dstport: 444 } => false); + + test_case!(filter {} => false); + + let filter = scheme + .parse("(tcp.srcport == 1337) and ((tcp.dstport != 80) and ((tcp.flags.syn) or (udp.dstport != 80)))") + .unwrap() + .compile(); + + test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80, tcp.flags.syn: true } => false); + + test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443, tcp.flags.syn: true } => true); + + test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80, tcp.flags.syn: true } => false); + + test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443, tcp.flags.syn: true } => false); + + test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80, tcp.flags.syn: false } => false); + + test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443, tcp.flags.syn: false } => false); + + test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80, tcp.flags.syn: false } => false); + + test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443, tcp.flags.syn: false } => false); + + test_case!(filter 
{ udp.dstport: 80 } => false); + + test_case!(filter { udp.dstport: 444 } => false); + + test_case!(filter {} => false); + + let filter = scheme.parse("tcp.flags.syn").unwrap().compile(); + + test_case!(filter { tcp.flags.syn: true } => true); + + test_case!(filter { tcp.flags.syn: false } => false); + + test_case!(filter {} => false); + + let filter = scheme.parse("not tcp.flags.syn").unwrap().compile(); + + test_case!(filter { tcp.flags.syn: true } => false); + + test_case!(filter { tcp.flags.syn: false } => true); + + test_case!(filter {} => true); + + let filter = scheme.parse("not (not tcp.flags.syn)").unwrap().compile(); + + test_case!(filter { tcp.flags.syn: true } => true); + + test_case!(filter { tcp.flags.syn: false } => false); + + test_case!(filter {} => false); + + let filter = scheme.parse("not (tcp.dstport eq 80)").unwrap().compile(); + + test_case!(filter { tcp.dstport: 80 } => false); + + test_case!(filter { tcp.dstport: 443 } => true); + + test_case!(filter {} => true); + + let filter = scheme.parse("not (tcp.dstport ne 80)").unwrap().compile(); + + test_case!(filter { tcp.dstport: 80 } => true); + + test_case!(filter { tcp.dstport: 443 } => false); + + test_case!(filter {} => true); + } } diff --git a/engine/src/ast/index_expr.rs b/engine/src/ast/index_expr.rs index 9f3a3540..b554845e 100644 --- a/engine/src/ast/index_expr.rs +++ b/engine/src/ast/index_expr.rs @@ -79,7 +79,9 @@ impl ValueExpr for IndexExpr { // Fast path match identifier { IdentifierExpr::Field(f) => CompiledValueExpr::new(move |ctx| { - Ok(ctx.get_field_value_unchecked(&f).as_ref()) + ctx.get_field_value_unchecked(&f) + .map(LhsValue::as_ref) + .ok_or(ty) }), IdentifierExpr::FunctionCallExpr(call) => compiler.compile_function_call_expr(call), } @@ -88,7 +90,7 @@ impl ValueExpr for IndexExpr { match identifier { IdentifierExpr::Field(f) => CompiledValueExpr::new(move |ctx| { ctx.get_field_value_unchecked(&f) - .get_nested(&indexes[..last]) + .and_then(|value| 
value.get_nested(&indexes[..last])) .map(LhsValue::as_ref) .ok_or(ty) }), @@ -103,18 +105,23 @@ impl ValueExpr for IndexExpr { } } } else { + let return_type = Type::Array(ty.into()); // Slow path match identifier { IdentifierExpr::Field(f) => CompiledValueExpr::new(move |ctx| { let mut iter = MapEachIterator::from_indexes(&indexes[..]); - iter.reset(ctx.get_field_value_unchecked(&f).as_ref()); + iter.reset( + ctx.get_field_value_unchecked(&f) + .map(LhsValue::as_ref) + .ok_or(return_type)?, + ); Ok(LhsValue::Array(Array::try_from_iter(ty, iter).unwrap())) }), IdentifierExpr::FunctionCallExpr(call) => { let call = compiler.compile_function_call_expr(call); CompiledValueExpr::new(move |ctx| { let mut iter = MapEachIterator::from_indexes(&indexes[..]); - iter.reset(call.execute(ctx).map_err(|_| Type::Array(ty.into()))?); + iter.reset(call.execute(ctx).map_err(|_| return_type)?); Ok(LhsValue::Array(Array::try_from_iter(ty, iter).unwrap())) }) } @@ -174,12 +181,14 @@ impl IndexExpr { IdentifierExpr::Field(f) => { if indexes.is_empty() { CompiledOneExpr::new(move |ctx| { - comp.compare(ctx.get_field_value_unchecked(&f), ctx) + ctx.get_field_value_unchecked(&f) + .map(|value| comp.compare(value, ctx)) + .unwrap_or(default) }) } else { CompiledOneExpr::new(move |ctx| { ctx.get_field_value_unchecked(&f) - .get_nested(&indexes) + .and_then(|value| value.get_nested(&indexes)) .map_or( default, #[inline] @@ -222,7 +231,7 @@ impl IndexExpr { IdentifierExpr::Field(f) => CompiledVecExpr::new(move |ctx| { let comp = &comp; ctx.get_field_value_unchecked(&f) - .get_nested(&indexes) + .and_then(|value| value.get_nested(&indexes)) .map_or( BOOL_ARRAY, #[inline] @@ -248,7 +257,10 @@ impl IndexExpr { match identifier { IdentifierExpr::Field(f) => CompiledVecExpr::new(move |ctx| { let mut iter = MapEachIterator::from_indexes(&indexes[..]); - iter.reset(ctx.get_field_value_unchecked(&f).as_ref()); + match ctx.get_field_value_unchecked(&f) { + Some(value) => iter.reset(value.as_ref()), + None
=> return TypedArray::default(), + }; TypedArray::from_iter(iter.map(|item| comp.compare(&item, ctx))) }), IdentifierExpr::FunctionCallExpr(call) => { diff --git a/engine/src/execution_context.rs b/engine/src/execution_context.rs index 6fc40999..92c8fd8e 100644 --- a/engine/src/execution_context.rs +++ b/engine/src/execution_context.rs @@ -129,7 +129,7 @@ impl<'e, U> ExecutionContext<'e, U> { } #[inline] - pub(crate) fn get_field_value_unchecked(&self, field: &Field) -> &LhsValue<'_> { + pub(crate) fn get_field_value_unchecked(&self, field: &Field) -> Option<&LhsValue<'_>> { // This is safe because this code is reachable only from Filter::execute // which already performs the scheme compatibility check, but check that // invariant holds in the future at least in the debug mode. @@ -138,12 +138,19 @@ impl<'e, U> ExecutionContext<'e, U> { // For now we panic in this, but later we are going to align behaviour // with wireshark: resolve all subexpressions that don't have RHS value // to `false`. - self.values[field.index()].as_ref().unwrap_or_else(|| { - panic!( - "Field {} was registered but not given a value", - field.name() - ); - }) + match self.values[field.index()].as_ref() { + Some(value) => Some(value), + None => { + if field.optional() { + None + } else { + panic!( + "Field {} was registered as mandatory but not given a value", + field.name() + ); + } + } + } } /// Get the value of a field. diff --git a/engine/src/scheme.rs b/engine/src/scheme.rs index 7e7a4468..48e55a15 100644 --- a/engine/src/scheme.rs +++ b/engine/src/scheme.rs @@ -149,6 +149,12 @@ impl<'s> FieldRef<'s> { self.index } + /// Returns whether the field value is optional. + #[inline] + pub fn optional(&self) -> bool { + self.scheme.inner.fields[self.index].optional + } + /// Returns the [`Scheme`](struct@Scheme) to which this field belongs to. #[inline] pub fn scheme(&self) -> &'s Scheme { @@ -227,6 +233,12 @@ impl Field { self.index } + /// Returns whether the field value is optional. 
+ #[inline] + pub fn optional(&self) -> bool { + self.scheme.inner.fields[self.index].optional + } + /// Returns the [`Scheme`](struct@Scheme) to which this field belongs to. #[inline] pub fn scheme(&self) -> &Scheme { @@ -614,6 +626,7 @@ type IdentifierName = Arc<str>; struct FieldDefinition { name: IdentifierName, ty: Type, + optional: bool, } /// A builder for a [`Scheme`]. @@ -635,13 +648,13 @@ impl SchemeBuilder { Default::default() } - /// Registers a field and its corresponding type. - pub fn add_field<N: AsRef<str>>( + fn add_field_full( &mut self, - name: N, + name: Arc<str>, ty: Type, + optional: bool, ) -> Result<(), IdentifierRedefinitionError> { - match self.items.entry(name.as_ref().into()) { + match self.items.entry(name) { Entry::Occupied(entry) => match entry.get() { SchemeItem::Field(_) => Err(IdentifierRedefinitionError::Field( FieldRedefinitionError(entry.key().to_string()), @@ -655,6 +668,7 @@ impl SchemeBuilder { self.fields.push(FieldDefinition { name: entry.key().clone(), ty, + optional, }); entry.insert(SchemeItem::Field(index)); Ok(()) @@ -662,6 +676,24 @@ impl SchemeBuilder { } } + /// Registers a field and its corresponding type. + pub fn add_field<N: AsRef<str>>( + &mut self, + name: N, + ty: Type, + ) -> Result<(), IdentifierRedefinitionError> { + self.add_field_full(name.as_ref().into(), ty, false) + } + + /// Registers an optional field and its corresponding type.
+ pub fn add_optional_field<N: AsRef<str>>( + &mut self, + name: N, + ty: Type, + ) -> Result<(), IdentifierRedefinitionError> { + self.add_field_full(name.as_ref().into(), ty, true) + } + /// Registers a function pub fn add_function<N: AsRef<str>>( &mut self, @@ -762,6 +794,7 @@ impl Hash for Scheme { struct SerdeField { #[serde(rename = "type")] ty: Type, + optional: bool, } impl Serialize for Scheme { @@ -772,7 +805,13 @@ impl Serialize for Scheme { let fields = self.fields(); let mut map = serializer.serialize_map(Some(fields.len()))?; for f in fields { - map.serialize_entry(f.name(), &SerdeField { ty: f.get_type() })?; + map.serialize_entry( + f.name(), + &SerdeField { + ty: f.get_type(), + optional: f.optional(), + }, + )?; } map.end() } @@ -799,8 +838,12 @@ impl<'de> Deserialize<'de> for Scheme { A: serde::de::MapAccess<'de>, { let mut builder = SchemeBuilder::new(); - while let Some((name, SerdeField { ty })) = map.next_entry::<&str, SerdeField>()? { - builder.add_field(name, ty).map_err(A::Error::custom)?; + while let Some((name, SerdeField { ty, optional })) = + map.next_entry::<&str, SerdeField>()?
+ { + builder + .add_field_full(name.into(), ty, optional) + .map_err(A::Error::custom)?; } Ok(builder) diff --git a/ffi/tests/ctests/src/tests.c b/ffi/tests/ctests/src/tests.c index 8eded4b3..c6ddf27c 100644 --- a/ffi/tests/ctests/src/tests.c +++ b/ffi/tests/ctests/src/tests.c @@ -305,7 +305,7 @@ void wirefilter_ffi_ctest_scheme_serialize() { rust_assert(json.ptr != NULL && json.len > 0, "could not serialize scheme to JSON"); rust_assert( - strncmp(json.ptr, "{\"http.host\":{\"type\":\"Bytes\"},\"ip.src\":{\"type\":\"Ip\"},\"ip.dst\":{\"type\":\"Ip\"},\"ssl\":{\"type\":\"Bool\"},\"tcp.port\":{\"type\":\"Int\"},\"http.headers\":{\"type\":{\"Map\":\"Bytes\"}},\"http.cookies\":{\"type\":{\"Array\":\"Bytes\"}}}", json.len) == 0, + strncmp(json.ptr, "{\"http.host\":{\"type\":\"Bytes\",\"optional\":false},\"ip.src\":{\"type\":\"Ip\",\"optional\":false},\"ip.dst\":{\"type\":\"Ip\",\"optional\":false},\"ssl\":{\"type\":\"Bool\",\"optional\":false},\"tcp.port\":{\"type\":\"Int\",\"optional\":false},\"http.headers\":{\"type\":{\"Map\":\"Bytes\"},\"optional\":false},\"http.cookies\":{\"type\":{\"Array\":\"Bytes\"},\"optional\":false}}", json.len) == 0, "invalid JSON serialization" );