Skip to content

Commit 590ceec

Browse files
committed
Add support for optional fields
1 parent b0341cb commit 590ceec

File tree

5 files changed

+293
-24
lines changed

5 files changed

+293
-24
lines changed

engine/src/ast/field_expr.rs

Lines changed: 208 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,7 @@ impl Expr for ComparisonExpr {
799799
mod tests {
800800
use super::*;
801801
use crate::{
802-
BytesFormat, FieldRef, LhsValue, ParserSettings, TypedMap,
802+
BytesFormat, FieldRef, LhsValue, ParserSettings, SchemeBuilder, TypedMap,
803803
ast::{
804804
function_expr::{FunctionCallArgExpr, FunctionCallExpr},
805805
logical_expr::LogicalExpr,
@@ -3040,4 +3040,211 @@ mod tests {
30403040
assert_eq!(expr.execute_one(ctx), expected, "failed test case {t:?}");
30413041
}
30423042
}
3043+
3044+
#[test]
3045+
fn test_optional_fields() {
3046+
let mut builder = SchemeBuilder::new();
3047+
builder
3048+
.add_optional_field("tcp.srcport", Type::Int)
3049+
.unwrap();
3050+
builder
3051+
.add_optional_field("tcp.dstport", Type::Int)
3052+
.unwrap();
3053+
builder
3054+
.add_optional_field("tcp.flags.syn", Type::Bool)
3055+
.unwrap();
3056+
builder
3057+
.add_optional_field("udp.srcport", Type::Int)
3058+
.unwrap();
3059+
builder
3060+
.add_optional_field("udp.dstport", Type::Int)
3061+
.unwrap();
3062+
builder.set_nil_not_equal_behavior(false);
3063+
let scheme = builder.build();
3064+
3065+
macro_rules! test_case {
3066+
($filter:ident {$($name:ident $(. $suffix:ident)*: $value:literal),*} => $outcome:literal) => {{
3067+
#[allow(unused_mut)]
3068+
let mut ctx = ExecutionContext::<()>::new(&scheme);
3069+
$(
3070+
ctx.set_field_value(scheme.get_field(stringify!($name $(. $suffix)*)).unwrap(), $value).unwrap();
3071+
)*
3072+
3073+
assert_eq!($filter.execute(&ctx), Ok($outcome));
3074+
}};
3075+
}
3076+
3077+
let filter = scheme
3078+
.parse("(tcp.dstport != 80) or (udp.dstport != 80)")
3079+
.unwrap()
3080+
.compile();
3081+
3082+
test_case!(filter { tcp.dstport: 443 } => true);
3083+
3084+
test_case!(filter { tcp.dstport: 80 } => false);
3085+
3086+
test_case!(filter { udp.dstport: 53 } => true);
3087+
3088+
test_case!(filter { udp.dstport: 80 } => false);
3089+
3090+
test_case!(filter {} => false);
3091+
3092+
let filter = scheme
3093+
.parse("(tcp.dstport != 80) and (udp.dstport != 80)")
3094+
.unwrap()
3095+
.compile();
3096+
3097+
test_case!(filter { tcp.dstport: 443 } => false);
3098+
3099+
test_case!(filter { tcp.dstport: 80 } => false);
3100+
3101+
test_case!(filter { udp.dstport: 53 } => false);
3102+
3103+
test_case!(filter { udp.dstport: 80 } => false);
3104+
3105+
test_case!(filter {} => false);
3106+
3107+
let filter = scheme
3108+
.parse("(tcp.srcport == 1337) or ((tcp.dstport != 80) or (udp.dstport != 80))")
3109+
.unwrap()
3110+
.compile();
3111+
3112+
test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80 } => true);
3113+
3114+
test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443 } => true);
3115+
3116+
test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80 } => false);
3117+
3118+
test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443 } => true);
3119+
3120+
test_case!(filter { udp.dstport: 80 } => false);
3121+
3122+
test_case!(filter { udp.dstport: 444 } => true);
3123+
3124+
test_case!(filter {} => false);
3125+
3126+
let filter = scheme
3127+
.parse("(tcp.srcport == 1337) and ((tcp.dstport != 80) or (udp.dstport != 80))")
3128+
.unwrap()
3129+
.compile();
3130+
3131+
test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80 } => false);
3132+
3133+
test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443 } => true);
3134+
3135+
test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80 } => false);
3136+
3137+
test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443 } => false);
3138+
3139+
test_case!(filter { udp.dstport: 80 } => false);
3140+
3141+
test_case!(filter { udp.dstport: 444 } => false);
3142+
3143+
test_case!(filter {} => false);
3144+
3145+
let filter = scheme
3146+
.parse("(tcp.srcport == 1337) or ((tcp.dstport != 80) and (udp.dstport != 80))")
3147+
.unwrap()
3148+
.compile();
3149+
3150+
test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80 } => true);
3151+
3152+
test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443 } => true);
3153+
3154+
test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80 } => false);
3155+
3156+
test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443 } => false);
3157+
3158+
test_case!(filter { udp.dstport: 80 } => false);
3159+
3160+
test_case!(filter { udp.dstport: 444 } => false);
3161+
3162+
test_case!(filter {} => false);
3163+
3164+
let filter = scheme
3165+
.parse("(tcp.srcport == 1337) and ((tcp.dstport != 80) and (udp.dstport != 80))")
3166+
.unwrap()
3167+
.compile();
3168+
3169+
test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80 } => false);
3170+
3171+
test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443 } => false);
3172+
3173+
test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80 } => false);
3174+
3175+
test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443 } => false);
3176+
3177+
test_case!(filter { udp.dstport: 80 } => false);
3178+
3179+
test_case!(filter { udp.dstport: 444 } => false);
3180+
3181+
test_case!(filter {} => false);
3182+
3183+
let filter = scheme
3184+
.parse("(tcp.srcport == 1337) and ((tcp.dstport != 80) and ((tcp.flags.syn) or (udp.dstport != 80)))")
3185+
.unwrap()
3186+
.compile();
3187+
3188+
test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80, tcp.flags.syn: true } => false);
3189+
3190+
test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443, tcp.flags.syn: true } => true);
3191+
3192+
test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80, tcp.flags.syn: true } => false);
3193+
3194+
test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443, tcp.flags.syn: true } => false);
3195+
3196+
test_case!(filter { tcp.srcport: 1337, tcp.dstport: 80, tcp.flags.syn: false } => false);
3197+
3198+
test_case!(filter { tcp.srcport: 1337, tcp.dstport: 443, tcp.flags.syn: false } => false);
3199+
3200+
test_case!(filter { tcp.srcport: 1234, tcp.dstport: 80, tcp.flags.syn: false } => false);
3201+
3202+
test_case!(filter { tcp.srcport: 1234, tcp.dstport: 443, tcp.flags.syn: false } => false);
3203+
3204+
test_case!(filter { udp.dstport: 80 } => false);
3205+
3206+
test_case!(filter { udp.dstport: 444 } => false);
3207+
3208+
test_case!(filter {} => false);
3209+
3210+
let filter = scheme.parse("tcp.flags.syn").unwrap().compile();
3211+
3212+
test_case!(filter { tcp.flags.syn: true } => true);
3213+
3214+
test_case!(filter { tcp.flags.syn: false } => false);
3215+
3216+
test_case!(filter {} => false);
3217+
3218+
let filter = scheme.parse("not tcp.flags.syn").unwrap().compile();
3219+
3220+
test_case!(filter { tcp.flags.syn: true } => false);
3221+
3222+
test_case!(filter { tcp.flags.syn: false } => true);
3223+
3224+
test_case!(filter {} => true);
3225+
3226+
let filter = scheme.parse("not (not tcp.flags.syn)").unwrap().compile();
3227+
3228+
test_case!(filter { tcp.flags.syn: true } => true);
3229+
3230+
test_case!(filter { tcp.flags.syn: false } => false);
3231+
3232+
test_case!(filter {} => false);
3233+
3234+
let filter = scheme.parse("not (tcp.dstport eq 80)").unwrap().compile();
3235+
3236+
test_case!(filter { tcp.dstport: 80 } => false);
3237+
3238+
test_case!(filter { tcp.dstport: 443 } => true);
3239+
3240+
test_case!(filter {} => true);
3241+
3242+
let filter = scheme.parse("not (tcp.dstport ne 80)").unwrap().compile();
3243+
3244+
test_case!(filter { tcp.dstport: 80 } => true);
3245+
3246+
test_case!(filter { tcp.dstport: 443 } => false);
3247+
3248+
test_case!(filter {} => true);
3249+
}
30433250
}

engine/src/ast/index_expr.rs

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ impl ValueExpr for IndexExpr {
7979
// Fast path
8080
match identifier {
8181
IdentifierExpr::Field(f) => CompiledValueExpr::new(move |ctx| {
82-
Ok(ctx.get_field_value_unchecked(&f).as_ref())
82+
ctx.get_field_value_unchecked(&f)
83+
.map(LhsValue::as_ref)
84+
.ok_or(ty)
8385
}),
8486
IdentifierExpr::FunctionCallExpr(call) => compiler.compile_function_call_expr(call),
8587
}
@@ -88,7 +90,7 @@ impl ValueExpr for IndexExpr {
8890
match identifier {
8991
IdentifierExpr::Field(f) => CompiledValueExpr::new(move |ctx| {
9092
ctx.get_field_value_unchecked(&f)
91-
.get_nested(&indexes[..last])
93+
.and_then(|value| value.get_nested(&indexes[..last]))
9294
.map(LhsValue::as_ref)
9395
.ok_or(ty)
9496
}),
@@ -103,18 +105,23 @@ impl ValueExpr for IndexExpr {
103105
}
104106
}
105107
} else {
108+
let return_type = Type::Array(ty.into());
106109
// Slow path
107110
match identifier {
108111
IdentifierExpr::Field(f) => CompiledValueExpr::new(move |ctx| {
109112
let mut iter = MapEachIterator::from_indexes(&indexes[..]);
110-
iter.reset(ctx.get_field_value_unchecked(&f).as_ref());
113+
iter.reset(
114+
ctx.get_field_value_unchecked(&f)
115+
.map(LhsValue::as_ref)
116+
.ok_or(return_type)?,
117+
);
111118
Ok(LhsValue::Array(Array::try_from_iter(ty, iter).unwrap()))
112119
}),
113120
IdentifierExpr::FunctionCallExpr(call) => {
114121
let call = compiler.compile_function_call_expr(call);
115122
CompiledValueExpr::new(move |ctx| {
116123
let mut iter = MapEachIterator::from_indexes(&indexes[..]);
117-
iter.reset(call.execute(ctx).map_err(|_| Type::Array(ty.into()))?);
124+
iter.reset(call.execute(ctx).map_err(|_| return_type)?);
118125
Ok(LhsValue::Array(Array::try_from_iter(ty, iter).unwrap()))
119126
})
120127
}
@@ -174,12 +181,14 @@ impl IndexExpr {
174181
IdentifierExpr::Field(f) => {
175182
if indexes.is_empty() {
176183
CompiledOneExpr::new(move |ctx| {
177-
comp.compare(ctx.get_field_value_unchecked(&f), ctx)
184+
ctx.get_field_value_unchecked(&f)
185+
.map(|value| comp.compare(value, ctx))
186+
.unwrap_or(default)
178187
})
179188
} else {
180189
CompiledOneExpr::new(move |ctx| {
181190
ctx.get_field_value_unchecked(&f)
182-
.get_nested(&indexes)
191+
.and_then(|value| value.get_nested(&indexes))
183192
.map_or(
184193
default,
185194
#[inline]
@@ -222,7 +231,7 @@ impl IndexExpr {
222231
IdentifierExpr::Field(f) => CompiledVecExpr::new(move |ctx| {
223232
let comp = &comp;
224233
ctx.get_field_value_unchecked(&f)
225-
.get_nested(&indexes)
234+
.and_then(|value| value.get_nested(&indexes))
226235
.map_or(
227236
BOOL_ARRAY,
228237
#[inline]
@@ -248,7 +257,10 @@ impl IndexExpr {
248257
match identifier {
249258
IdentifierExpr::Field(f) => CompiledVecExpr::new(move |ctx| {
250259
let mut iter = MapEachIterator::from_indexes(&indexes[..]);
251-
iter.reset(ctx.get_field_value_unchecked(&f).as_ref());
260+
match ctx.get_field_value_unchecked(&f) {
261+
Some(value) => iter.reset(value.as_ref()),
262+
None => return TypedArray::default(),
263+
};
252264
TypedArray::from_iter(iter.map(|item| comp.compare(&item, ctx)))
253265
}),
254266
IdentifierExpr::FunctionCallExpr(call) => {

engine/src/execution_context.rs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ impl<'e, U> ExecutionContext<'e, U> {
129129
}
130130

131131
#[inline]
132-
pub(crate) fn get_field_value_unchecked(&self, field: &Field) -> &LhsValue<'_> {
132+
pub(crate) fn get_field_value_unchecked(&self, field: &Field) -> Option<&LhsValue<'_>> {
133133
// This is safe because this code is reachable only from Filter::execute
134134
// which already performs the scheme compatibility check, but check that
135135
// invariant holds in the future at least in the debug mode.
@@ -138,12 +138,19 @@ impl<'e, U> ExecutionContext<'e, U> {
138138
// For now we panic in this, but later we are going to align behaviour
139139
// with wireshark: resolve all subexpressions that don't have RHS value
140140
// to `false`.
141-
self.values[field.index()].as_ref().unwrap_or_else(|| {
142-
panic!(
143-
"Field {} was registered but not given a value",
144-
field.name()
145-
);
146-
})
141+
match self.values[field.index()].as_ref() {
142+
Some(value) => Some(value),
143+
None => {
144+
if field.optional() {
145+
None
146+
} else {
147+
panic!(
148+
"Field {} was registered as mandatory but not given a value",
149+
field.name()
150+
);
151+
}
152+
}
153+
}
147154
}
148155

149156
/// Get the value of a field.

0 commit comments

Comments
 (0)