Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 80 additions & 45 deletions src/ops/factory_bases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,76 @@ use std::hash::Hash;
use super::interface::*;
use super::registry::*;
use crate::api_bail;
use crate::api_error;
use crate::base::schema::*;
use crate::base::spec::*;
use crate::builder::plan::AnalyzedValueMapping;
use crate::setup;
// SourceFactoryBase
pub struct OpArgResolver<'arg> {
name: String,
resolved_op_arg: Option<(usize, EnrichedValueType)>,
nonnull_args_idx: &'arg mut Vec<usize>,
may_nullify_output: &'arg mut bool,
}

impl<'arg> OpArgResolver<'arg> {
pub fn expect_nullable_type(self, expected_type: &ValueType) -> Result<Self> {
let Some((_, typ)) = &self.resolved_op_arg else {
return Ok(self);
};
if &typ.typ != expected_type {
api_bail!(
"Expected argument `{}` to be of type `{}`, got `{}`",
self.name,
expected_type,
typ.typ
);
}
Ok(self)
}
pub fn expect_type(self, expected_type: &ValueType) -> Result<Self> {
let resolver = self.expect_nullable_type(expected_type)?;
resolver.resolved_op_arg.as_ref().map(|(idx, typ)| {
resolver.nonnull_args_idx.push(*idx);
if typ.nullable {
*resolver.may_nullify_output = true;
}
});
Ok(resolver)
}

pub fn optional(self) -> Option<ResolvedOpArg> {
return self.resolved_op_arg.map(|(idx, typ)| ResolvedOpArg {
name: self.name,
typ,
idx,
});
}

pub fn required(self) -> Result<ResolvedOpArg> {
let Some((idx, typ)) = self.resolved_op_arg else {
api_bail!("Required argument `{}` is missing", self.name);
};
Ok(ResolvedOpArg {
name: self.name,
typ,
idx,
})
}
}

pub struct ResolvedOpArg {
pub name: String,
pub typ: EnrichedValueType,
pub idx: usize,
}

pub trait ResolvedOpArgExt: Sized {
fn expect_type(self, expected_type: &ValueType) -> Result<Self>;
fn value<'a>(&self, args: &'a [value::Value]) -> Result<&'a value::Value>;
fn take_value(&self, args: &mut [value::Value]) -> Result<value::Value>;
}

impl ResolvedOpArgExt for ResolvedOpArg {
fn expect_type(self, expected_type: &ValueType) -> Result<Self> {
if &self.typ.typ != expected_type {
api_bail!(
"Expected argument `{}` to be of type `{}`, got `{}`",
self.name,
expected_type,
self.typ.typ
);
}
Ok(self)
}

fn value<'a>(&self, args: &'a [value::Value]) -> Result<&'a value::Value> {
if self.idx >= args.len() {
api_bail!(
Expand All @@ -63,10 +102,6 @@ impl ResolvedOpArgExt for ResolvedOpArg {
}

impl ResolvedOpArgExt for Option<ResolvedOpArg> {
fn expect_type(self, expected_type: &ValueType) -> Result<Self> {
self.map(|arg| arg.expect_type(expected_type)).transpose()
}

fn value<'a>(&self, args: &'a [value::Value]) -> Result<&'a value::Value> {
Ok(self
.as_ref()
Expand All @@ -89,11 +124,16 @@ pub struct OpArgsResolver<'a> {
num_positional_args: usize,
next_positional_idx: usize,
remaining_kwargs: HashMap<&'a str, usize>,
required_args_idx: &'a mut Vec<usize>,
nonnull_args_idx: &'a mut Vec<usize>,
may_nullify_output: &'a mut bool,
}

impl<'a> OpArgsResolver<'a> {
pub fn new(args: &'a [OpArgSchema], required_args_idx: &'a mut Vec<usize>) -> Result<Self> {
pub fn new(
args: &'a [OpArgSchema],
nonnull_args_idx: &'a mut Vec<usize>,
may_nullify_output: &'a mut bool,
) -> Result<Self> {
let mut num_positional_args = 0;
let mut kwargs = HashMap::new();
for (idx, arg) in args.iter().enumerate() {
Expand All @@ -111,11 +151,12 @@ impl<'a> OpArgsResolver<'a> {
num_positional_args,
next_positional_idx: 0,
remaining_kwargs: kwargs,
required_args_idx,
nonnull_args_idx,
may_nullify_output,
})
}

pub fn next_optional_arg(&mut self, name: &str) -> Result<Option<ResolvedOpArg>> {
pub fn next_arg<'arg>(&'arg mut self, name: &str) -> Result<OpArgResolver<'arg>> {
let idx = if let Some(idx) = self.remaining_kwargs.remove(name) {
if self.next_positional_idx < self.num_positional_args {
api_bail!("`{name}` is provided as both positional and keyword arguments");
Expand All @@ -129,19 +170,12 @@ impl<'a> OpArgsResolver<'a> {
} else {
None
};
Ok(idx.map(|idx| ResolvedOpArg {
Ok(OpArgResolver {
name: name.to_string(),
typ: self.args[idx].value_type.clone(),
idx,
}))
}

pub fn next_arg(&mut self, name: &str) -> Result<ResolvedOpArg> {
let arg = self
.next_optional_arg(name)?
.ok_or_else(|| api_error!("Required argument `{name}` is missing",))?;
self.required_args_idx.push(arg.idx);
Ok(arg)
resolved_op_arg: idx.map(|idx| (idx, self.args[idx].value_type.clone())),
nonnull_args_idx: self.nonnull_args_idx,
may_nullify_output: self.may_nullify_output,
})
}

pub fn done(self) -> Result<()> {
Expand Down Expand Up @@ -252,13 +286,13 @@ pub trait SimpleFunctionFactoryBase: SimpleFunctionFactory + Send + Sync + 'stat

struct FunctionExecutorWrapper<E: SimpleFunctionExecutor> {
executor: E,
required_args_idx: Vec<usize>,
nonnull_args_idx: Vec<usize>,
}

#[async_trait]
impl<E: SimpleFunctionExecutor> SimpleFunctionExecutor for FunctionExecutorWrapper<E> {
async fn evaluate(&self, args: Vec<value::Value>) -> Result<value::Value> {
for idx in &self.required_args_idx {
for idx in &self.nonnull_args_idx {
if args[*idx].is_null() {
return Ok(value::Value::Null);
}
Expand Down Expand Up @@ -287,28 +321,29 @@ impl<T: SimpleFunctionFactoryBase> SimpleFunctionFactory for T {
BoxFuture<'static, Result<Box<dyn SimpleFunctionExecutor>>>,
)> {
let spec: T::Spec = serde_json::from_value(spec)?;
let mut required_args_idx = vec![];
let mut args_resolver = OpArgsResolver::new(&input_schema, &mut required_args_idx)?;
let mut nonnull_args_idx = vec![];
let mut may_nullify_output = false;
let mut args_resolver = OpArgsResolver::new(
&input_schema,
&mut nonnull_args_idx,
&mut may_nullify_output,
)?;
let (resolved_input_schema, mut output_schema) = self
.resolve_schema(&spec, &mut args_resolver, &context)
.await?;
args_resolver.done()?;

// If any required argument is nullable, the output schema should be nullable.
if args_resolver
.required_args_idx
.iter()
.any(|idx| input_schema[*idx].value_type.nullable)
{
if may_nullify_output {
output_schema.nullable = true;
}

args_resolver.done()?;
let executor = async move {
Ok(Box::new(FunctionExecutorWrapper {
executor: self
.build_executor(spec, resolved_input_schema, context)
.await?,
required_args_idx,
nonnull_args_idx,
}) as Box<dyn SimpleFunctionExecutor>)
};
Ok((output_schema, Box::pin(executor)))
Expand Down
5 changes: 4 additions & 1 deletion src/ops/functions/embed_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ impl SimpleFunctionFactoryBase for Factory {
args_resolver: &mut OpArgsResolver<'a>,
_context: &FlowInstanceContext,
) -> Result<(Self::ResolvedArgs, EnrichedValueType)> {
let text = args_resolver.next_arg("text")?;
let text = args_resolver
.next_arg("text")?
.expect_type(&ValueType::Basic(BasicValueType::Str))?
.required()?;
let client =
new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_config.clone())
.await?;
Expand Down
10 changes: 6 additions & 4 deletions src/ops/functions/extract_by_llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,13 @@ impl SimpleFunctionFactoryBase for Factory {
) -> Result<(Args, EnrichedValueType)> {
let args = Args {
text: args_resolver
.next_optional_arg("text")?
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
.next_arg("text")?
.expect_nullable_type(&ValueType::Basic(BasicValueType::Str))?
.optional(),
image: args_resolver
.next_optional_arg("image")?
.expect_type(&ValueType::Basic(BasicValueType::Bytes))?,
.next_arg("image")?
.expect_nullable_type(&ValueType::Basic(BasicValueType::Bytes))?
.optional(),
};

if args.text.is_none() && args.image.is_none() {
Expand Down
8 changes: 5 additions & 3 deletions src/ops/functions/parse_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,12 @@ impl SimpleFunctionFactoryBase for Factory {
let args = Args {
text: args_resolver
.next_arg("text")?
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
.expect_type(&ValueType::Basic(BasicValueType::Str))?
.required()?,
language: args_resolver
.next_optional_arg("language")?
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
.next_arg("language")?
.expect_nullable_type(&ValueType::Basic(BasicValueType::Str))?
.optional(),
};

let output_schema = make_output_type(BasicValueType::Json);
Expand Down
21 changes: 13 additions & 8 deletions src/ops/functions/split_recursively.rs
Original file line number Diff line number Diff line change
Expand Up @@ -963,19 +963,24 @@ impl SimpleFunctionFactoryBase for Factory {
let args = Args {
text: args_resolver
.next_arg("text")?
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
.expect_type(&ValueType::Basic(BasicValueType::Str))?
.required()?,
chunk_size: args_resolver
.next_arg("chunk_size")?
.expect_type(&ValueType::Basic(BasicValueType::Int64))?,
.expect_type(&ValueType::Basic(BasicValueType::Int64))?
.required()?,
min_chunk_size: args_resolver
.next_optional_arg("min_chunk_size")?
.expect_type(&ValueType::Basic(BasicValueType::Int64))?,
.next_arg("min_chunk_size")?
.expect_nullable_type(&ValueType::Basic(BasicValueType::Int64))?
.optional(),
chunk_overlap: args_resolver
.next_optional_arg("chunk_overlap")?
.expect_type(&ValueType::Basic(BasicValueType::Int64))?,
.next_arg("chunk_overlap")?
.expect_nullable_type(&ValueType::Basic(BasicValueType::Int64))?
.optional(),
language: args_resolver
.next_optional_arg("language")?
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
.next_arg("language")?
.expect_nullable_type(&ValueType::Basic(BasicValueType::Str))?
.optional(),
};

let pos_struct = schema::ValueType::Struct(schema::StructSchema {
Expand Down
Loading