diff --git a/engine/src/ast/index_expr.rs b/engine/src/ast/index_expr.rs index 41ac7232..3500669e 100644 --- a/engine/src/ast/index_expr.rs +++ b/engine/src/ast/index_expr.rs @@ -7,9 +7,7 @@ use super::{ use crate::{ compiler::Compiler, execution_context::ExecutionContext, - filter::{ - CompiledExpr, CompiledOneExpr, CompiledValueExpr, CompiledVecExpr, CompiledVecExprResult, - }, + filter::{CompiledExpr, CompiledOneExpr, CompiledValueExpr, CompiledVecExpr}, lex::{Lex, LexErrorKind, LexResult, LexWith, expect, skip_space, span}, lhs_types::{Array, Map, TypedArray}, scheme::{FieldIndex, IndexAccessError}, @@ -17,6 +15,8 @@ use crate::{ }; use serde::{Serialize, Serializer, ser::SerializeSeq}; +const BOOL_ARRAY: TypedArray<'_, bool> = TypedArray::new(); + /// IndexExpr is an expr that destructures an index into an IdentifierExpr. /// /// For example, given a scheme which declares a field `http.request.headers`, @@ -31,45 +31,13 @@ pub struct IndexExpr { pub indexes: Vec, } -fn index_access_one<'s, 'e, U, F>( - indexes: &[FieldIndex], - first: Option<&'e LhsValue<'e>>, - default: bool, - ctx: &'e ExecutionContext<'e, U>, - func: F, -) -> bool -where - F: Fn(&LhsValue<'_>, &ExecutionContext<'_, U>) -> bool + Sync + Send + 's, -{ - indexes - .iter() - .fold(first, |value, idx| { - value.and_then(|val| val.get(idx).unwrap()) - }) - .map_or_else( - || default, - #[inline] - |val| func(val, ctx), - ) -} - -fn index_access_vec<'s, 'e, U, F>( - indexes: &[FieldIndex], - first: Option<&'e LhsValue<'e>>, - ctx: &'e ExecutionContext<'e, U>, - func: F, -) -> CompiledVecExprResult -where - F: Fn(&LhsValue<'_>, &ExecutionContext<'_, U>) -> bool + Sync + Send + 's, -{ - indexes - .iter() - .fold(first, |value, idx| { - value.and_then(|val| val.get(idx).unwrap()) - }) - .map_or(const { TypedArray::new() }, move |val: &LhsValue<'_>| { - TypedArray::from_iter(val.iter().unwrap().map(|item| func(item, ctx))) - }) +#[allow(clippy::manual_ok_err)] +#[inline] +pub fn ok_ref(result: &Result) -> Option<&T> { + match result { + Ok(x) => Some(x), + Err(_) => None, + } } impl ValueExpr for IndexExpr { @@ -119,23 +87,17 @@ impl ValueExpr for IndexExpr { // Average path match identifier { IdentifierExpr::Field(f) => CompiledValueExpr::new(move |ctx| { - indexes[..last] - .iter() - .try_fold(ctx.get_field_value_unchecked(&f), |value, index| { - value.get(index).unwrap() - }) + ctx.get_field_value_unchecked(&f) + .get_nested(&indexes[..last]) .map(LhsValue::as_ref) .ok_or(ty) }), IdentifierExpr::FunctionCallExpr(call) => { let call = compiler.compile_function_call_expr(call); CompiledValueExpr::new(move |ctx| { - let result = call.execute(ctx).ok(); - indexes[..last] - .iter() - .fold(result, |value, index| { - value.and_then(|val| val.extract(index).unwrap()) - }) + call.execute(ctx) + .ok() + .and_then(|val| val.extract_nested(&indexes[..last])) .ok_or(ty) }) } @@ -192,14 +154,13 @@ impl IndexExpr { }) } else { CompiledOneExpr::new(move |ctx| { - index_access_one( - &indexes, - call.execute(ctx).as_ref().ok(), - default, - ctx, - #[inline] - |val, ctx| func(val, ctx), - ) + ok_ref(&call.execute(ctx)) + .and_then(|val| val.get_nested(&indexes)) + .map_or( + default, + #[inline] + |val| func(val, ctx), + ) }) } } @@ -208,14 +169,13 @@ impl IndexExpr { CompiledOneExpr::new(move |ctx| func(ctx.get_field_value_unchecked(&f), ctx)) } else { CompiledOneExpr::new(move |ctx| { - index_access_one( - &indexes, - Some(ctx.get_field_value_unchecked(&f)), - default, - ctx, - #[inline] - |val, ctx| func(val, ctx), - ) + ctx.get_field_value_unchecked(&f) + .get_nested(&indexes) + .map_or( + default, + #[inline] + |val| func(val, ctx), + ) }) } } @@ -239,23 +199,21 @@ impl IndexExpr { IdentifierExpr::FunctionCallExpr(call) => { let call = compiler.compile_function_call_expr(call); CompiledVecExpr::new(move |ctx| { - index_access_vec( - &indexes, - call.execute(ctx).as_ref().ok(), - ctx, - #[inline] - |val, ctx| func(val, ctx), - ) + let func = &func; + ok_ref(&call.execute(ctx)) + .and_then(|val| val.get_nested(&indexes)) + .map_or(BOOL_ARRAY, move |val: &LhsValue<'_>| { + TypedArray::from_iter(val.iter().unwrap().map(|item| func(item, ctx))) + }) }) } IdentifierExpr::Field(f) => CompiledVecExpr::new(move |ctx| { - index_access_vec( - &indexes, - Some(ctx.get_field_value_unchecked(&f)), - ctx, - #[inline] - |val, ctx| func(val, ctx), - ) + let func = &func; + ctx.get_field_value_unchecked(&f) + .get_nested(&indexes) + .map_or(BOOL_ARRAY, move |val: &LhsValue<'_>| { + TypedArray::from_iter(val.iter().unwrap().map(|item| func(item, ctx))) + }) }), } } diff --git a/engine/src/types.rs b/engine/src/types.rs index c64b731a..aa00f17f 100644 --- a/engine/src/types.rs +++ b/engine/src/types.rs @@ -813,6 +813,13 @@ impl<'a> LhsValue<'a> { } } + #[inline] + pub(crate) fn get_nested(&'a self, indexes: &[FieldIndex]) -> Option<&'a LhsValue<'a>> { + indexes + .iter() + .try_fold(self, |value, idx| value.get(idx).unwrap()) + } + pub(crate) fn extract( self, item: &FieldIndex, @@ -839,6 +846,13 @@ impl<'a> LhsValue<'a> { } } + #[inline] + pub(crate) fn extract_nested(self, indexes: &[FieldIndex]) -> Option> { + indexes + .iter() + .try_fold(self, |value, idx| value.extract(idx).unwrap()) + } + /// Set an element in an LhsValue given a path item and a specified value. /// Returns a TypeMismatchError error if current type does not support /// nested element or if value type is invalid.