diff --git a/src/ops/factory_bases.rs b/src/ops/factory_bases.rs index 2a4b6143..880f4444 100644 --- a/src/ops/factory_bases.rs +++ b/src/ops/factory_bases.rs @@ -6,12 +6,64 @@ use std::hash::Hash; use super::interface::*; use super::registry::*; use crate::api_bail; -use crate::api_error; use crate::base::schema::*; use crate::base::spec::*; use crate::builder::plan::AnalyzedValueMapping; use crate::setup; // SourceFactoryBase +pub struct OpArgResolver<'arg> { + name: String, + resolved_op_arg: Option<(usize, EnrichedValueType)>, + nonnull_args_idx: &'arg mut Vec, + may_nullify_output: &'arg mut bool, +} + +impl<'arg> OpArgResolver<'arg> { + pub fn expect_nullable_type(self, expected_type: &ValueType) -> Result { + let Some((_, typ)) = &self.resolved_op_arg else { + return Ok(self); + }; + if &typ.typ != expected_type { + api_bail!( + "Expected argument `{}` to be of type `{}`, got `{}`", + self.name, + expected_type, + typ.typ + ); + } + Ok(self) + } + pub fn expect_type(self, expected_type: &ValueType) -> Result { + let resolver = self.expect_nullable_type(expected_type)?; + resolver.resolved_op_arg.as_ref().map(|(idx, typ)| { + resolver.nonnull_args_idx.push(*idx); + if typ.nullable { + *resolver.may_nullify_output = true; + } + }); + Ok(resolver) + } + + pub fn optional(self) -> Option { + return self.resolved_op_arg.map(|(idx, typ)| ResolvedOpArg { + name: self.name, + typ, + idx, + }); + } + + pub fn required(self) -> Result { + let Some((idx, typ)) = self.resolved_op_arg else { + api_bail!("Required argument `{}` is missing", self.name); + }; + Ok(ResolvedOpArg { + name: self.name, + typ, + idx, + }) + } +} + pub struct ResolvedOpArg { pub name: String, pub typ: EnrichedValueType, @@ -19,24 +71,11 @@ pub struct ResolvedOpArg { } pub trait ResolvedOpArgExt: Sized { - fn expect_type(self, expected_type: &ValueType) -> Result; fn value<'a>(&self, args: &'a [value::Value]) -> Result<&'a value::Value>; fn take_value(&self, args: &mut [value::Value]) -> Result; } impl ResolvedOpArgExt for ResolvedOpArg { - fn expect_type(self, expected_type: &ValueType) -> Result { - if &self.typ.typ != expected_type { - api_bail!( - "Expected argument `{}` to be of type `{}`, got `{}`", - self.name, - expected_type, - self.typ.typ - ); - } - Ok(self) - } - fn value<'a>(&self, args: &'a [value::Value]) -> Result<&'a value::Value> { if self.idx >= args.len() { api_bail!( @@ -63,10 +102,6 @@ impl ResolvedOpArgExt for ResolvedOpArg { } impl ResolvedOpArgExt for Option { - fn expect_type(self, expected_type: &ValueType) -> Result { - self.map(|arg| arg.expect_type(expected_type)).transpose() - } - fn value<'a>(&self, args: &'a [value::Value]) -> Result<&'a value::Value> { Ok(self .as_ref() @@ -89,11 +124,16 @@ pub struct OpArgsResolver<'a> { num_positional_args: usize, next_positional_idx: usize, remaining_kwargs: HashMap<&'a str, usize>, - required_args_idx: &'a mut Vec, + nonnull_args_idx: &'a mut Vec, + may_nullify_output: &'a mut bool, } impl<'a> OpArgsResolver<'a> { - pub fn new(args: &'a [OpArgSchema], required_args_idx: &'a mut Vec) -> Result { + pub fn new( + args: &'a [OpArgSchema], + nonnull_args_idx: &'a mut Vec, + may_nullify_output: &'a mut bool, + ) -> Result { let mut num_positional_args = 0; let mut kwargs = HashMap::new(); for (idx, arg) in args.iter().enumerate() { @@ -111,11 +151,12 @@ impl<'a> OpArgsResolver<'a> { num_positional_args, next_positional_idx: 0, remaining_kwargs: kwargs, - required_args_idx, + nonnull_args_idx, + may_nullify_output, }) } - pub fn next_optional_arg(&mut self, name: &str) -> Result> { + pub fn next_arg<'arg>(&'arg mut self, name: &str) -> Result> { let idx = if let Some(idx) = self.remaining_kwargs.remove(name) { if self.next_positional_idx < self.num_positional_args { api_bail!("`{name}` is provided as both positional and keyword arguments"); @@ -129,19 +170,12 @@ impl<'a> OpArgsResolver<'a> { } else { None }; - Ok(idx.map(|idx| ResolvedOpArg { + Ok(OpArgResolver { name: name.to_string(), - typ: self.args[idx].value_type.clone(), - idx, - })) - } - - pub fn next_arg(&mut self, name: &str) -> Result { - let arg = self - .next_optional_arg(name)? - .ok_or_else(|| api_error!("Required argument `{name}` is missing",))?; - self.required_args_idx.push(arg.idx); - Ok(arg) + resolved_op_arg: idx.map(|idx| (idx, self.args[idx].value_type.clone())), + nonnull_args_idx: self.nonnull_args_idx, + may_nullify_output: self.may_nullify_output, + }) } pub fn done(self) -> Result<()> { @@ -252,13 +286,13 @@ pub trait SimpleFunctionFactoryBase: SimpleFunctionFactory + Send + Sync + 'stat struct FunctionExecutorWrapper { executor: E, - required_args_idx: Vec, + nonnull_args_idx: Vec, } #[async_trait] impl SimpleFunctionExecutor for FunctionExecutorWrapper { async fn evaluate(&self, args: Vec) -> Result { - for idx in &self.required_args_idx { + for idx in &self.nonnull_args_idx { if args[*idx].is_null() { return Ok(value::Value::Null); } @@ -287,28 +321,29 @@ impl SimpleFunctionFactory for T { BoxFuture<'static, Result>>, )> { let spec: T::Spec = serde_json::from_value(spec)?; - let mut required_args_idx = vec![]; - let mut args_resolver = OpArgsResolver::new(&input_schema, &mut required_args_idx)?; + let mut nonnull_args_idx = vec![]; + let mut may_nullify_output = false; + let mut args_resolver = OpArgsResolver::new( + &input_schema, + &mut nonnull_args_idx, + &mut may_nullify_output, + )?; let (resolved_input_schema, mut output_schema) = self .resolve_schema(&spec, &mut args_resolver, &context) .await?; + args_resolver.done()?; // If any required argument is nullable, the output schema should be nullable. - if args_resolver - .required_args_idx - .iter() - .any(|idx| input_schema[*idx].value_type.nullable) - { + if may_nullify_output { output_schema.nullable = true; } - args_resolver.done()?; let executor = async move { Ok(Box::new(FunctionExecutorWrapper { executor: self .build_executor(spec, resolved_input_schema, context) .await?, - required_args_idx, + nonnull_args_idx, }) as Box) }; Ok((output_schema, Box::pin(executor))) diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs index 46be3a16..23da89f1 100644 --- a/src/ops/functions/embed_text.rs +++ b/src/ops/functions/embed_text.rs @@ -69,7 +69,10 @@ impl SimpleFunctionFactoryBase for Factory { args_resolver: &mut OpArgsResolver<'a>, _context: &FlowInstanceContext, ) -> Result<(Self::ResolvedArgs, EnrichedValueType)> { - let text = args_resolver.next_arg("text")?; + let text = args_resolver + .next_arg("text")? + .expect_type(&ValueType::Basic(BasicValueType::Str))? + .required()?; let client = new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_config.clone()) .await?; diff --git a/src/ops/functions/extract_by_llm.rs b/src/ops/functions/extract_by_llm.rs index a8fddcfc..627cb93c 100644 --- a/src/ops/functions/extract_by_llm.rs +++ b/src/ops/functions/extract_by_llm.rs @@ -138,11 +138,13 @@ impl SimpleFunctionFactoryBase for Factory { ) -> Result<(Args, EnrichedValueType)> { let args = Args { text: args_resolver - .next_optional_arg("text")? - .expect_type(&ValueType::Basic(BasicValueType::Str))?, + .next_arg("text")? + .expect_nullable_type(&ValueType::Basic(BasicValueType::Str))? + .optional(), image: args_resolver - .next_optional_arg("image")? - .expect_type(&ValueType::Basic(BasicValueType::Bytes))?, + .next_arg("image")? + .expect_nullable_type(&ValueType::Basic(BasicValueType::Bytes))? + .optional(), }; if args.text.is_none() && args.image.is_none() { diff --git a/src/ops/functions/parse_json.rs b/src/ops/functions/parse_json.rs index a88f2fff..84900710 100644 --- a/src/ops/functions/parse_json.rs +++ b/src/ops/functions/parse_json.rs @@ -83,10 +83,12 @@ impl SimpleFunctionFactoryBase for Factory { let args = Args { text: args_resolver .next_arg("text")? - .expect_type(&ValueType::Basic(BasicValueType::Str))?, + .expect_type(&ValueType::Basic(BasicValueType::Str))? + .required()?, language: args_resolver - .next_optional_arg("language")? - .expect_type(&ValueType::Basic(BasicValueType::Str))?, + .next_arg("language")? + .expect_nullable_type(&ValueType::Basic(BasicValueType::Str))? + .optional(), }; let output_schema = make_output_type(BasicValueType::Json); diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs index c740571d..b1a35c97 100644 --- a/src/ops/functions/split_recursively.rs +++ b/src/ops/functions/split_recursively.rs @@ -963,19 +963,24 @@ impl SimpleFunctionFactoryBase for Factory { let args = Args { text: args_resolver .next_arg("text")? - .expect_type(&ValueType::Basic(BasicValueType::Str))?, + .expect_type(&ValueType::Basic(BasicValueType::Str))? + .required()?, chunk_size: args_resolver .next_arg("chunk_size")? - .expect_type(&ValueType::Basic(BasicValueType::Int64))?, + .expect_type(&ValueType::Basic(BasicValueType::Int64))? + .required()?, min_chunk_size: args_resolver - .next_optional_arg("min_chunk_size")? - .expect_type(&ValueType::Basic(BasicValueType::Int64))?, + .next_arg("min_chunk_size")? + .expect_nullable_type(&ValueType::Basic(BasicValueType::Int64))? + .optional(), chunk_overlap: args_resolver - .next_optional_arg("chunk_overlap")? - .expect_type(&ValueType::Basic(BasicValueType::Int64))?, + .next_arg("chunk_overlap")? + .expect_nullable_type(&ValueType::Basic(BasicValueType::Int64))? + .optional(), language: args_resolver - .next_optional_arg("language")? - .expect_type(&ValueType::Basic(BasicValueType::Str))?, + .next_arg("language")? + .expect_nullable_type(&ValueType::Basic(BasicValueType::Str))? + .optional(), }; let pos_struct = schema::ValueType::Struct(schema::StructSchema {