diff --git a/engine/src/functions/all.rs b/engine/src/functions/all.rs new file mode 100644 index 00000000..1dfc0223 --- /dev/null +++ b/engine/src/functions/all.rs @@ -0,0 +1,132 @@ +use crate::{ + FunctionArgKind, FunctionArgs, FunctionDefinition, FunctionDefinitionContext, FunctionParam, + FunctionParamError, GetType, LhsValue, ParserSettings, Type, +}; +use std::iter::once; + +#[inline] +fn all_impl<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + let arg = args.next().expect("expected 1 argument, got 0"); + if args.next().is_some() { + panic!("expected 1 argument, got {}", 2 + args.count()); + } + match arg { + Ok(LhsValue::Array(arr)) => Some(LhsValue::Bool( + arr.into_iter().all(|lhs| bool::try_from(lhs).unwrap()), + )), + Err(Type::Array(ref arr)) if arr.get_type() == Type::Bool => None, + _ => unreachable!(), + } +} + +/// A function which, given an array of bool, returns true if all of the +/// arguments are true, otherwise false. +/// +/// It expects one argument and will error if given an incorrect number of +/// arguments or an argument of invalid type. +#[derive(Debug, Default)] +pub struct AllFunction {} + +impl FunctionDefinition for AllFunction { + fn check_param( + &self, + _: &ParserSettings, + params: &mut dyn ExactSizeIterator>, + next_param: &FunctionParam<'_>, + _: Option<&mut FunctionDefinitionContext>, + ) -> Result<(), FunctionParamError> { + match params.len() { + 0 => { + next_param.expect_arg_kind(FunctionArgKind::Field)?; + next_param.expect_val_type(once(Type::Array(Type::Bool.into()).into()))?; + } + _ => unreachable!(), + } + + Ok(()) + } + + fn return_type( + &self, + _: &mut dyn ExactSizeIterator>, + _: Option<&FunctionDefinitionContext>, + ) -> Type { + Type::Bool + } + + fn arg_count(&self) -> (usize, Option) { + (1, Some(0)) + } + + fn compile<'s>( + &'s self, + _: &mut dyn ExactSizeIterator>, + _: Option, + ) -> Box Fn(FunctionArgs<'_, 'a>) -> Option> + Sync + Send + 's> { + Box::new(all_impl) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Array; + + #[test] + fn test_all_fn() { + // assert that all([]) is true + let arr = LhsValue::Array(Array::new(Type::Bool)); + let mut args = vec![Ok(arr)].into_iter(); + assert_eq!(Some(LhsValue::from(true)), all_impl(&mut args)); + + // assert that all([true]) is true + let arr = LhsValue::Array(Array::from_iter([true])); + let mut args = vec![Ok(arr)].into_iter(); + assert_eq!(Some(LhsValue::from(true)), all_impl(&mut args)); + + // assert that all([false]) is false + let arr = LhsValue::Array(Array::from_iter([false])); + let mut args = vec![Ok(arr)].into_iter(); + assert_eq!(Some(LhsValue::from(false)), all_impl(&mut args)); + + // assert that all([false, true]) is true + let arr = LhsValue::Array(Array::from_iter([false, true])); + let mut args = vec![Ok(arr)].into_iter(); + assert_eq!(Some(LhsValue::from(false)), all_impl(&mut args)); + + // assert that all([true, true]) is true + let arr = LhsValue::Array(Array::from_iter([true, true])); + let mut args = vec![Ok(arr)].into_iter(); + assert_eq!(Some(LhsValue::from(true)), all_impl(&mut args)); + } + + #[test] + #[should_panic(expected = "expected 1 argument, got 0")] + fn test_all_fn_no_args() { + let mut args = vec![].into_iter(); + all_impl(&mut args); + } + + #[test] + #[should_panic(expected = "expected 1 argument, got 2")] + fn test_all_fn_too_many_args() { + let arr = LhsValue::Array(Array::new(Type::Bool)); + let mut args = vec![Ok(arr.clone()), Ok(arr.clone())].into_iter(); + all_impl(&mut args); + } + + #[test] + #[should_panic] + fn test_all_fn_bad_lhs_value() { + let mut args = vec![Ok(LhsValue::from(false))].into_iter(); + all_impl(&mut args); + } + + #[test] + #[should_panic] + fn test_all_fn_bad_lhs_arr_value() { + let arr = LhsValue::Array(Array::from_iter(["hello"])); + let mut args = vec![Ok(arr)].into_iter(); + all_impl(&mut args); + } +} diff --git a/engine/src/functions/any.rs b/engine/src/functions/any.rs new file mode 100644 index 00000000..c1d19885 --- /dev/null +++ b/engine/src/functions/any.rs @@ -0,0 +1,132 @@ +use crate::{ + FunctionArgKind, FunctionArgs, FunctionDefinition, FunctionDefinitionContext, FunctionParam, + FunctionParamError, GetType, LhsValue, ParserSettings, Type, +}; +use std::iter::once; + +#[inline] +fn any_impl<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + let arg = args.next().expect("expected 1 argument, got 0"); + if args.next().is_some() { + panic!("expected 1 argument, got {}", 2 + args.count()); + } + match arg { + Ok(LhsValue::Array(arr)) => Some(LhsValue::Bool( + arr.into_iter().any(|lhs| bool::try_from(lhs).unwrap()), + )), + Err(Type::Array(ref arr)) if arr.get_type() == Type::Bool => None, + _ => unreachable!(), + } +} + +/// A function which, given an array of bool, returns true if any one of the +/// arguments is true, otherwise false. +/// +/// It expects one argument and will error if given an incorrect number of +/// arguments or an argument of invalid type. +#[derive(Debug, Default)] +pub struct AnyFunction {} + +impl FunctionDefinition for AnyFunction { + fn check_param( + &self, + _: &ParserSettings, + params: &mut dyn ExactSizeIterator>, + next_param: &FunctionParam<'_>, + _: Option<&mut FunctionDefinitionContext>, + ) -> Result<(), FunctionParamError> { + match params.len() { + 0 => { + next_param.expect_arg_kind(FunctionArgKind::Field)?; + next_param.expect_val_type(once(Type::Array(Type::Bool.into()).into()))?; + } + _ => unreachable!(), + } + + Ok(()) + } + + fn return_type( + &self, + _: &mut dyn ExactSizeIterator>, + _: Option<&FunctionDefinitionContext>, + ) -> Type { + Type::Bool + } + + fn arg_count(&self) -> (usize, Option) { + (1, Some(0)) + } + + fn compile<'s>( + &'s self, + _: &mut dyn ExactSizeIterator>, + _: Option, + ) -> Box Fn(FunctionArgs<'_, 'a>) -> Option> + Sync + Send + 's> { + Box::new(any_impl) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Array; + + #[test] + fn test_any_fn() { + // assert that any([]) is false + let arr = LhsValue::Array(Array::new(Type::Bool)); + let mut args = vec![Ok(arr)].into_iter(); + assert_eq!(Some(LhsValue::from(false)), any_impl(&mut args)); + + // assert that any([true]) is true + let arr = LhsValue::Array(Array::from_iter([true])); + let mut args = vec![Ok(arr)].into_iter(); + assert_eq!(Some(LhsValue::from(true)), any_impl(&mut args)); + + // assert that any([false]) is false + let arr = LhsValue::Array(Array::from_iter([false])); + let mut args = vec![Ok(arr)].into_iter(); + assert_eq!(Some(LhsValue::from(false)), any_impl(&mut args)); + + // assert that any([false, true]) is true + let arr = LhsValue::Array(Array::from_iter([false, true])); + let mut args = vec![Ok(arr)].into_iter(); + assert_eq!(Some(LhsValue::from(true)), any_impl(&mut args)); + + // assert that any([true, true]) is true + let arr = LhsValue::Array(Array::from_iter([true, true])); + let mut args = vec![Ok(arr)].into_iter(); + assert_eq!(Some(LhsValue::from(true)), any_impl(&mut args)); + } + + #[test] + #[should_panic(expected = "expected 1 argument, got 0")] + fn test_any_fn_no_args() { + let mut args = vec![].into_iter(); + any_impl(&mut args); + } + + #[test] + #[should_panic(expected = "expected 1 argument, got 2")] + fn test_any_fn_too_many_args() { + let arr = LhsValue::Array(Array::new(Type::Bool)); + let mut args = vec![Ok(arr.clone()), Ok(arr.clone())].into_iter(); + any_impl(&mut args); + } + + #[test] + #[should_panic] + fn test_any_fn_bad_lhs_value() { + let mut args = vec![Ok(LhsValue::from(false))].into_iter(); + any_impl(&mut args); + } + + #[test] + #[should_panic] + fn test_any_fn_bad_lhs_arr_value() { + let arr = LhsValue::Array(Array::from_iter(["hello"])); + let mut args = vec![Ok(arr)].into_iter(); + any_impl(&mut args); + } +} diff --git a/engine/src/functions/concat.rs b/engine/src/functions/concat.rs new file mode 100644 index 00000000..d66ffcfe --- /dev/null +++ b/engine/src/functions/concat.rs @@ -0,0 +1,208 @@ +use crate::{ + Array, ExpectedType, FunctionArgs, FunctionDefinition, FunctionDefinitionContext, + FunctionParam, FunctionParamError, GetType, LhsValue, ParserSettings, Type, +}; +use std::{borrow::Cow, iter::once}; + +/// A function which, given one or more arrays or byte-strings, returns the +/// concatenation of each of them. +/// +/// It expects at least two arguments and will error if given no arguments +/// or the arguments are of different types. +#[derive(Debug, Default)] +pub struct ConcatFunction {} + +impl ConcatFunction { + /// Creates a new definition for the `concat` function. + pub const fn new() -> Self { + Self {} + } +} + +fn concat_array<'a>(mut accumulator: Array<'a>, args: FunctionArgs<'_, 'a>) -> Array<'a> { + for arg in args { + match arg { + Ok(LhsValue::Array(value)) => accumulator.try_extend(value).unwrap(), + Err(Type::Array(_)) => (), + _ => (), + }; + } + accumulator +} + +fn concat_bytes<'a>(mut accumulator: Cow<'a, [u8]>, args: FunctionArgs<'_, 'a>) -> Cow<'a, [u8]> { + for arg in args { + match arg { + Ok(LhsValue::Bytes(value)) => accumulator.to_mut().extend(value.iter()), + Err(Type::Bytes) => (), + _ => (), + } + } + accumulator +} + +const EXPECTED_TYPES: [ExpectedType; 2] = [ExpectedType::Array, ExpectedType::Type(Type::Bytes)]; + +impl FunctionDefinition for ConcatFunction { + fn check_param( + &self, + _: &ParserSettings, + params: &mut dyn ExactSizeIterator>, + next_param: &FunctionParam<'_>, + _: Option<&mut FunctionDefinitionContext>, + ) -> Result<(), FunctionParamError> { + match params.next() { + // the next argument must have the same type + Some(param) => { + next_param.expect_val_type(once(param.get_type().into()))?; + } + None => { + next_param.expect_val_type(EXPECTED_TYPES.iter().cloned())?; + } + } + + Ok(()) + } + + fn return_type( + &self, + params: &mut dyn ExactSizeIterator>, + _: Option<&FunctionDefinitionContext>, + ) -> Type { + params.next().unwrap().get_type() + } + + fn arg_count(&self) -> (usize, Option) { + (2, None) + } + + fn compile<'s>( + &'s self, + _: &mut dyn ExactSizeIterator>, + _: Option, + ) -> Box Fn(FunctionArgs<'_, 'a>) -> Option> + Sync + Send + 's> { + Box::new(|args| { + while let Some(arg) = args.next() { + match arg { + Ok(LhsValue::Array(array)) => { + return Some(LhsValue::Array(concat_array(array, args))) + } + Ok(LhsValue::Bytes(bytes)) => { + return Some(LhsValue::Bytes(concat_bytes(bytes, args))) + } + Err(_) => (), + _ => unreachable!(), + } + } + None + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TypeMismatchError; + + pub static CONCAT_FN: ConcatFunction = ConcatFunction::new(); + + #[test] + fn test_concat_bytes() { + let mut args = vec![ + Ok(LhsValue::Bytes(Cow::Borrowed(b"hello"))), + Ok(LhsValue::Bytes(Cow::Borrowed(b"world"))), + ] + .into_iter(); + assert_eq!( + Some(LhsValue::Bytes(Cow::Borrowed(b"helloworld"))), + CONCAT_FN.compile(&mut std::iter::empty(), None)(&mut args) + ); + } + + #[test] + fn test_concat_many_bytes() { + let mut args = vec![ + Ok(LhsValue::Bytes(Cow::Borrowed(b"hello"))), + Ok(LhsValue::Bytes(Cow::Borrowed(b"world"))), + Ok(LhsValue::Bytes(Cow::Borrowed(b"hello2"))), + Ok(LhsValue::Bytes(Cow::Borrowed(b"world2"))), + ] + .into_iter(); + assert_eq!( + Some(LhsValue::Bytes(Cow::Borrowed(b"helloworldhello2world2"))), + CONCAT_FN.compile(&mut std::iter::empty(), None)(&mut args) + ); + } + + #[test] + fn test_concat_function() { + let arg1 = LhsValue::Array(Array::from_iter([1, 2, 3])); + let arg2 = LhsValue::Array(Array::from_iter([4, 5, 6])); + let mut args = vec![Ok(arg1), Ok(arg2)].into_iter(); + assert_eq!( + Some(LhsValue::Array(Array::from_iter([1, 2, 3, 4, 5, 6]))), + CONCAT_FN.compile(&mut std::iter::empty(), None)(&mut args) + ); + } + + #[test] + #[should_panic] + fn test_concat_function_bad_arg_type() { + let mut args = vec![Ok(LhsValue::from(2))].into_iter(); + CONCAT_FN.compile(&mut std::iter::empty(), None)(&mut args); + } + + #[test] + fn test_concat_function_check_param() { + let settings = ParserSettings::default(); + + let arg1 = FunctionParam::Variable(Type::Array(Type::Bytes.into())); + assert_eq!( + Ok(()), + CONCAT_FN.check_param(&settings, &mut vec![].into_iter(), &arg1, None) + ); + + let arg2 = FunctionParam::Variable(Type::Array(Type::Bytes.into())); + assert_eq!( + Ok(()), + CONCAT_FN.check_param(&settings, &mut vec![arg1.clone()].into_iter(), &arg2, None) + ); + + let arg3 = FunctionParam::Variable(Type::Int); + + assert_eq!( + Err(FunctionParamError::TypeMismatch(TypeMismatchError { + expected: Type::Array(Type::Bytes.into()).into(), + actual: Type::Int, + })), + CONCAT_FN.check_param( + &settings, + &mut vec![arg1.clone(), arg2.clone()].into_iter(), + &arg3, + None + ) + ); + + assert_eq!( + Err(FunctionParamError::TypeMismatch(TypeMismatchError { + expected: [ExpectedType::Array, ExpectedType::Type(Type::Bytes)] + .into_iter() + .into(), + actual: Type::Int, + })), + CONCAT_FN.check_param(&settings, &mut vec![].into_iter(), &arg3, None) + ); + + let arg_literal = FunctionParam::Variable(Type::Bytes); + + assert_eq!( + Ok(()), + CONCAT_FN.check_param( + &settings, + &mut vec![arg_literal.clone()].into_iter(), + &arg_literal, + None + ) + ); + } +} diff --git a/engine/src/functions.rs b/engine/src/functions/mod.rs similarity index 99% rename from engine/src/functions.rs rename to engine/src/functions/mod.rs index 476b06d1..f3b68dca 100644 --- a/engine/src/functions.rs +++ b/engine/src/functions/mod.rs @@ -1,8 +1,15 @@ +mod all; +mod any; +mod concat; + use crate::{ filter::CompiledValueResult, types::{ExpectedType, ExpectedTypeList, GetType, LhsValue, RhsValue, Type, TypeMismatchError}, ParserSettings, }; +pub use all::AllFunction; +pub use any::AnyFunction; +pub use concat::ConcatFunction; use std::any::Any; use std::convert::TryFrom; use std::fmt::{self, Debug}; diff --git a/engine/src/lib.rs b/engine/src/lib.rs index 88c9a315..346c095a 100644 --- a/engine/src/lib.rs +++ b/engine/src/lib.rs @@ -96,10 +96,10 @@ pub use self::{ CompiledExpr, CompiledOneExpr, CompiledValueExpr, CompiledVecExpr, Filter, FilterValue, }, functions::{ - FunctionArgInvalidConstantError, FunctionArgKind, FunctionArgKindMismatchError, - FunctionArgs, FunctionDefinition, FunctionDefinitionContext, FunctionParam, - FunctionParamError, SimpleFunctionDefinition, SimpleFunctionImpl, SimpleFunctionOptParam, - SimpleFunctionParam, + AllFunction, AnyFunction, ConcatFunction, FunctionArgInvalidConstantError, FunctionArgKind, + FunctionArgKindMismatchError, FunctionArgs, FunctionDefinition, FunctionDefinitionContext, + FunctionParam, FunctionParamError, SimpleFunctionDefinition, SimpleFunctionImpl, + SimpleFunctionOptParam, SimpleFunctionParam, }, lex::LexErrorKind, lhs_types::{Array, ArrayMut, Map, MapIter, MapMut, TypedArray, TypedMap},