Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 208 additions & 1 deletion engine/src/ast/field_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ impl Expr for ComparisonExpr {
mod tests {
use super::*;
use crate::{
BytesFormat, FieldRef, LhsValue, ParserSettings, TypedMap,
BytesFormat, FieldRef, LhsValue, ParserSettings, SchemeBuilder, TypedMap,
ast::{
function_expr::{FunctionCallArgExpr, FunctionCallExpr},
logical_expr::LogicalExpr,
Expand Down Expand Up @@ -3040,4 +3040,211 @@ mod tests {
assert_eq!(expr.execute_one(ctx), expected, "failed test case {t:?}");
}
}

#[test]
fn test_optional_fields() {
let mut builder = SchemeBuilder::new();
builder
.add_optional_field("tcp.srcport", Type::Int)
.unwrap();
builder
.add_optional_field("tcp.dstport", Type::Int)
.unwrap();
builder
.add_optional_field("tcp.flags.syn", Type::Bool)
.unwrap();
builder
.add_optional_field("udp.srcport", Type::Int)
.unwrap();
builder
.add_optional_field("udp.dstport", Type::Int)
.unwrap();
builder.set_nil_not_equal_behavior(false);
let scheme = builder.build();

macro_rules! test_case {
($filter:ident {$($name:ident $(. $suffix:ident)*: $value:literal),*} => $outcome:literal) => {{
#[allow(unused_mut)]
let mut ctx = ExecutionContext::<()>::new(&scheme);
$(
ctx.set_field_value(scheme.get_field(stringify!($name $(. $suffix)*)).unwrap(), $value).unwrap();
)*

assert_eq!($filter.execute(&ctx), Ok($outcome));
}};
}

let filter = scheme
.parse("(tcp.dstport != 80) or (udp.dstport != 80)")
.unwrap()
.compile();

test_case!(filter { tcp.dstport: 443 } => true);

test_case!(filter { tcp.dstport: 80 } => false);

test_case!(filter { udp.dstport: 53 } => true);

test_case!(filter { udp.dstport: 80 } => false);

test_case!(filter {} => false);

let filter = scheme
.parse("(tcp.dstport != 80) and (udp.dstport != 80)")
.unwrap()
.compile();

test_case!(filter { tcp.dstport: 443 } => false);

test_case!(filter { tcp.dstport: 80 } => false);

test_case!(filter { udp.dstport: 53 } => false);

test_case!(filter { udp.dstport: 80 } => false);

test_case!(filter {} => false);

let filter = scheme
.parse("(tcp.srcport == 1337) or ((tcp.dstport != 80) or (udp.dstport != 80))")
.unwrap()
.compile();

test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80 } => true);

test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443 } => true);

test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80 } => false);

test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443 } => true);

test_case!(filter { udp.dstport: 80 } => false);

test_case!(filter { udp.dstport: 444 } => true);

test_case!(filter {} => false);

let filter = scheme
.parse("(tcp.srcport == 1337) and ((tcp.dstport != 80) or (udp.dstport != 80))")
.unwrap()
.compile();

test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80 } => false);

test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443 } => true);

test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80 } => false);

test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443 } => false);

test_case!(filter { udp.dstport: 80 } => false);

test_case!(filter { udp.dstport: 444 } => false);

test_case!(filter {} => false);

let filter = scheme
.parse("(tcp.srcport == 1337) or ((tcp.dstport != 80) and (udp.dstport != 80))")
.unwrap()
.compile();

test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80 } => true);

test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443 } => true);

test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80 } => false);

test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443 } => false);

test_case!(filter { udp.dstport: 80 } => false);

test_case!(filter { udp.dstport: 444 } => false);

test_case!(filter {} => false);

let filter = scheme
.parse("(tcp.srcport == 1337) and ((tcp.dstport != 80) and (udp.dstport != 80))")
.unwrap()
.compile();

test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80 } => false);

test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443 } => false);

test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80 } => false);

test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443 } => false);

test_case!(filter { udp.dstport: 80 } => false);

test_case!(filter { udp.dstport: 444 } => false);

test_case!(filter {} => false);

let filter = scheme
.parse("(tcp.srcport == 1337) and ((tcp.dstport != 80) and ((tcp.flags.syn) or (udp.dstport != 80)))")
.unwrap()
.compile();

test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80, tcp.flags.syn: true } => false);

test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443, tcp.flags.syn: true } => true);

test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80, tcp.flags.syn: true } => false);

test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443, tcp.flags.syn: true } => false);

test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80, tcp.flags.syn: false } => false);

test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443, tcp.flags.syn: false } => false);

test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80, tcp.flags.syn: false } => false);

test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443, tcp.flags.syn: false } => false);

test_case!(filter { udp.dstport: 80 } => false);

test_case!(filter { udp.dstport: 444 } => false);

test_case!(filter {} => false);

let filter = scheme.parse("tcp.flags.syn").unwrap().compile();

test_case!(filter { tcp.flags.syn: true } => true);

test_case!(filter { tcp.flags.syn: false } => false);

test_case!(filter {} => false);

let filter = scheme.parse("not tcp.flags.syn").unwrap().compile();

test_case!(filter { tcp.flags.syn: true } => false);

test_case!(filter { tcp.flags.syn: false } => true);

test_case!(filter {} => true);

let filter = scheme.parse("not (not tcp.flags.syn)").unwrap().compile();

test_case!(filter { tcp.flags.syn: true } => true);

test_case!(filter { tcp.flags.syn: false } => false);

test_case!(filter {} => false);

let filter = scheme.parse("not (tcp.dstport eq 80)").unwrap().compile();

test_case!(filter { tcp.dstport: 80 } => false);

test_case!(filter { tcp.dstport: 443 } => true);

test_case!(filter {} => true);

let filter = scheme.parse("not (tcp.dstport ne 80)").unwrap().compile();

test_case!(filter { tcp.dstport: 80 } => true);

test_case!(filter { tcp.dstport: 443 } => false);

test_case!(filter {} => true);
}
}
28 changes: 20 additions & 8 deletions engine/src/ast/index_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ impl ValueExpr for IndexExpr {
// Fast path
match identifier {
IdentifierExpr::Field(f) => CompiledValueExpr::new(move |ctx| {
Ok(ctx.get_field_value_unchecked(&f).as_ref())
ctx.get_field_value_unchecked(&f)
.map(LhsValue::as_ref)
.ok_or(ty)
}),
IdentifierExpr::FunctionCallExpr(call) => compiler.compile_function_call_expr(call),
}
Expand All @@ -88,7 +90,7 @@ impl ValueExpr for IndexExpr {
match identifier {
IdentifierExpr::Field(f) => CompiledValueExpr::new(move |ctx| {
ctx.get_field_value_unchecked(&f)
.get_nested(&indexes[..last])
.and_then(|value| value.get_nested(&indexes[..last]))
.map(LhsValue::as_ref)
.ok_or(ty)
}),
Expand All @@ -103,18 +105,23 @@ impl ValueExpr for IndexExpr {
}
}
} else {
let return_type = Type::Array(ty.into());
// Slow path
match identifier {
IdentifierExpr::Field(f) => CompiledValueExpr::new(move |ctx| {
let mut iter = MapEachIterator::from_indexes(&indexes[..]);
iter.reset(ctx.get_field_value_unchecked(&f).as_ref());
iter.reset(
ctx.get_field_value_unchecked(&f)
.map(LhsValue::as_ref)
.ok_or(return_type)?,
);
Ok(LhsValue::Array(Array::try_from_iter(ty, iter).unwrap()))
}),
IdentifierExpr::FunctionCallExpr(call) => {
let call = compiler.compile_function_call_expr(call);
CompiledValueExpr::new(move |ctx| {
let mut iter = MapEachIterator::from_indexes(&indexes[..]);
iter.reset(call.execute(ctx).map_err(|_| Type::Array(ty.into()))?);
iter.reset(call.execute(ctx).map_err(|_| return_type)?);
Ok(LhsValue::Array(Array::try_from_iter(ty, iter).unwrap()))
})
}
Expand Down Expand Up @@ -174,12 +181,14 @@ impl IndexExpr {
IdentifierExpr::Field(f) => {
if indexes.is_empty() {
CompiledOneExpr::new(move |ctx| {
comp.compare(ctx.get_field_value_unchecked(&f), ctx)
ctx.get_field_value_unchecked(&f)
.map(|value| comp.compare(value, ctx))
.unwrap_or(default)
})
} else {
CompiledOneExpr::new(move |ctx| {
ctx.get_field_value_unchecked(&f)
.get_nested(&indexes)
.and_then(|value| value.get_nested(&indexes))
.map_or(
default,
#[inline]
Expand Down Expand Up @@ -222,7 +231,7 @@ impl IndexExpr {
IdentifierExpr::Field(f) => CompiledVecExpr::new(move |ctx| {
let comp = &comp;
ctx.get_field_value_unchecked(&f)
.get_nested(&indexes)
.and_then(|value| value.get_nested(&indexes))
.map_or(
BOOL_ARRAY,
#[inline]
Expand All @@ -248,7 +257,10 @@ impl IndexExpr {
match identifier {
IdentifierExpr::Field(f) => CompiledVecExpr::new(move |ctx| {
let mut iter = MapEachIterator::from_indexes(&indexes[..]);
iter.reset(ctx.get_field_value_unchecked(&f).as_ref());
match ctx.get_field_value_unchecked(&f) {
Some(value) => iter.reset(value.as_ref()),
None => return TypedArray::default(),
};
TypedArray::from_iter(iter.map(|item| comp.compare(&item, ctx)))
}),
IdentifierExpr::FunctionCallExpr(call) => {
Expand Down
21 changes: 14 additions & 7 deletions engine/src/execution_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl<'e, U> ExecutionContext<'e, U> {
}

#[inline]
pub(crate) fn get_field_value_unchecked(&self, field: &Field) -> &LhsValue<'_> {
pub(crate) fn get_field_value_unchecked(&self, field: &Field) -> Option<&LhsValue<'_>> {
// This is safe because this code is reachable only from Filter::execute
// which already performs the scheme compatibility check, but check that
// invariant holds in the future at least in the debug mode.
Expand All @@ -138,12 +138,19 @@ impl<'e, U> ExecutionContext<'e, U> {
// For now we panic in this, but later we are going to align behaviour
// with wireshark: resolve all subexpressions that don't have RHS value
// to `false`.
self.values[field.index()].as_ref().unwrap_or_else(|| {
panic!(
"Field {} was registered but not given a value",
field.name()
);
})
match self.values[field.index()].as_ref() {
Some(value) => Some(value),
None => {
if field.optional() {
None
} else {
panic!(
"Field {} was registered as mandatory but not given a value",
field.name()
);
}
}
}
}

/// Get the value of a field.
Expand Down
Loading