diff --git a/examples/text_embedding/text_embedding.py b/examples/text_embedding/text_embedding.py index 7b7e491a6..510340f70 100644 --- a/examples/text_embedding/text_embedding.py +++ b/examples/text_embedding/text_embedding.py @@ -26,6 +26,10 @@ def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind cocoindex.functions.SplitRecursively( language="markdown", chunk_size=300, chunk_overlap=100)) + doc["chunks"] = flow_builder.call( + cocoindex.functions.SplitRecursively(), + doc["content"], language="markdown", chunk_size=300, chunk_overlap=100); + with doc["chunks"].row() as chunk: chunk["embedding"] = text_to_embedding(chunk["text"]) doc_embeddings.collect(filename=doc["filename"], location=chunk["location"], diff --git a/src/ops/factory_bases.rs b/src/ops/factory_bases.rs index 5a3a73bd1..40e18634d 100644 --- a/src/ops/factory_bases.rs +++ b/src/ops/factory_bases.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::fmt::Debug; use std::hash::Hash; use std::sync::Arc; @@ -9,10 +10,142 @@ use serde::Serialize; use super::interface::*; use super::registry::*; +use crate::api_bail; +use crate::api_error; use crate::base::schema::*; use crate::base::spec::*; +use crate::base::value; +use crate::builder::plan::AnalyzedValueMapping; use crate::setup; // SourceFactoryBase +pub struct ResolvedOpArg { + pub name: String, + pub typ: EnrichedValueType, + pub idx: usize, +} + +impl ResolvedOpArg { + pub fn expect_type(self, expected_type: &ValueType) -> Result { + if &self.typ.typ != expected_type { + api_bail!( + "Expected argument `{}` to be of type `{}`, got `{}`", + self.name, + expected_type, + self.typ.typ + ); + } + Ok(self) + } + + pub fn value<'a>(&self, args: &'a Vec) -> Result<&'a value::Value> { + if self.idx >= args.len() { + api_bail!( + "Two few arguments, {} provided, expected at least {} for `{}`", + args.len(), + self.idx + 1, + self.name + ); + } + Ok(&args[self.idx]) + } + + pub fn take_value(&self, args: &mut Vec) -> Result { + if self.idx >= args.len() { + api_bail!( + "Two few arguments, {} provided, expected at least {} for `{}`", + args.len(), + self.idx + 1, + self.name + ); + } + Ok(std::mem::take(&mut args[self.idx])) + } +} + +pub struct OpArgsResolver<'a> { + args: &'a [OpArgSchema], + num_positional_args: usize, + next_positional_idx: usize, + remaining_kwargs: HashMap<&'a str, usize>, +} + +impl<'a> OpArgsResolver<'a> { + pub fn new(args: &'a [OpArgSchema]) -> Result { + let mut num_positional_args = 0; + let mut kwargs = HashMap::new(); + for (idx, arg) in args.iter().enumerate() { + if let Some(name) = &arg.name.0 { + kwargs.insert(name.as_str(), idx); + } else { + if !kwargs.is_empty() { + api_bail!("Positional arguments must be provided before keyword arguments"); + } + num_positional_args += 1; + } + } + Ok(Self { + args, + num_positional_args, + next_positional_idx: 0, + remaining_kwargs: kwargs, + }) + } + + pub fn next_optional_arg(&mut self, name: &str) -> Result> { + let idx = if let Some(idx) = self.remaining_kwargs.remove(name) { + if self.next_positional_idx < self.num_positional_args { + api_bail!("`{name}` is provided as both positional and keyword arguments"); + } else { + Some(idx) + } + } else { + if self.next_positional_idx < self.num_positional_args { + let idx = self.next_positional_idx; + self.next_positional_idx += 1; + Some(idx) + } else { + None + } + }; + Ok(idx.map(|idx| ResolvedOpArg { + name: name.to_string(), + typ: self.args[idx].value_type.clone(), + idx, + })) + } + + pub fn next_arg(&mut self, name: &str) -> Result { + Ok(self + .next_optional_arg(name)? + .ok_or_else(|| api_error!("Required argument `{name}` is missing",))?) + } + + pub fn done(self) -> Result<()> { + if self.next_positional_idx < self.num_positional_args { + api_bail!( + "Expected {} positional arguments, got {}", + self.next_positional_idx, + self.num_positional_args + ); + } + if !self.remaining_kwargs.is_empty() { + api_bail!( + "Unexpected keyword arguments: {}", + self.remaining_kwargs + .keys() + .map(|k| format!("`{k}`")) + .collect::>() + .join(", ") + ) + } + Ok(()) + } + + pub fn get_analyze_value(&self, resolved_arg: &ResolvedOpArg) -> &AnalyzedValueMapping { + &self.args[resolved_arg.idx].analyzed_value + } +} + #[async_trait] pub trait SourceFactoryBase: SourceFactory + Send + Sync + 'static { type Spec: DeserializeOwned + Send + Sync; @@ -63,20 +196,21 @@ impl SourceFactory for T { #[async_trait] pub trait SimpleFunctionFactoryBase: SimpleFunctionFactory + Send + Sync + 'static { type Spec: DeserializeOwned + Send + Sync; + type ResolvedArgs: Send + Sync; fn name(&self) -> &str; - fn get_output_schema( - &self, - spec: &Self::Spec, - input_schema: &Vec, + fn resolve_schema<'a>( + &'a self, + spec: &'a Self::Spec, + args_resolver: &mut OpArgsResolver<'a>, context: &FlowInstanceContext, - ) -> Result; + ) -> Result<(Self::ResolvedArgs, EnrichedValueType)>; async fn build_executor( self: Arc, spec: Self::Spec, - input_schema: Vec, + resolved_input_schema: Self::ResolvedArgs, context: Arc, ) -> Result>; @@ -102,8 +236,11 @@ impl SimpleFunctionFactory for T { ExecutorFuture<'static, Box>, )> { let spec: T::Spec = serde_json::from_value(spec)?; - let output_schema = self.get_output_schema(&spec, &input_schema, &context)?; - let executor = self.build_executor(spec, input_schema, context); + let mut args_resolver = OpArgsResolver::new(&input_schema)?; + let (resolved_input_schema, output_schema) = + self.resolve_schema(&spec, &mut args_resolver, &context)?; + args_resolver.done()?; + let executor = self.build_executor(spec, resolved_input_schema, context); Ok((output_schema, executor)) } } diff --git a/src/ops/functions/extract_by_llm.rs b/src/ops/functions/extract_by_llm.rs index 2dfe88a63..35a06474d 100644 --- a/src/ops/functions/extract_by_llm.rs +++ b/src/ops/functions/extract_by_llm.rs @@ -17,7 +17,12 @@ pub struct Spec { instruction: Option, } +pub struct Args { + text: ResolvedOpArg, +} + struct Executor { + args: Args, client: Box, output_json_schema: SchemaObject, output_type: EnrichedValueType, @@ -41,8 +46,9 @@ Output only the JSON without any additional messages or explanations." } impl Executor { - async fn new(spec: Spec) -> Result { + async fn new(spec: Spec, args: Args) -> Result { Ok(Self { + args, client: new_llm_generation_client(spec.llm_spec).await?, output_json_schema: spec.output_type.to_json_schema(), output_type: spec.output_type, @@ -62,7 +68,7 @@ impl SimpleFunctionExecutor for Executor { } async fn evaluate(&self, input: Vec) -> Result { - let text = input.iter().next().unwrap().as_str()?; + let text = self.args.text.value(&input)?.as_str()?; let req = LlmGenerateRequest { system_prompt: Some(Cow::Borrowed(&self.system_prompt)), user_prompt: Cow::Borrowed(text), @@ -83,32 +89,34 @@ pub struct Factory; #[async_trait] impl SimpleFunctionFactoryBase for Factory { type Spec = Spec; + type ResolvedArgs = Args; fn name(&self) -> &str { "ExtractByLlm" } - fn get_output_schema( + fn resolve_schema( &self, spec: &Spec, - input_schema: &Vec, + args_resolver: &mut OpArgsResolver<'_>, _context: &FlowInstanceContext, - ) -> Result { - match &expect_input_1(input_schema)?.value_type.typ { - ValueType::Basic(BasicValueType::Str) => {} - t => { - api_bail!("Expect String as input type, got {}", t) - } - } - Ok(spec.output_type.clone()) + ) -> Result<(Args, EnrichedValueType)> { + Ok(( + Args { + text: args_resolver + .next_arg("text")? + .expect_type(&ValueType::Basic(BasicValueType::Str))?, + }, + spec.output_type.clone(), + )) } async fn build_executor( self: Arc, spec: Spec, - _input_schema: Vec, + resolved_input_schema: Args, _context: Arc, ) -> Result> { - Ok(Box::new(Executor::new(spec).await?)) + Ok(Box::new(Executor::new(spec, resolved_input_schema).await?)) } } diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs index 9dbdbf581..798f876fe 100644 --- a/src/ops/functions/split_recursively.rs +++ b/src/ops/functions/split_recursively.rs @@ -16,6 +16,10 @@ pub struct Spec { chunk_overlap: usize, } +pub struct Args { + text: ResolvedOpArg, +} + static DEFAULT_SEPARATORS: LazyLock> = LazyLock::new(|| { [r"\n\n+", r"\n", r"\s+"] .into_iter() @@ -95,11 +99,12 @@ static SEPARATORS_BY_LANG: LazyLock>> = LazyLoc struct Executor { spec: Spec, + args: Args, separators: &'static [Regex], } impl Executor { - fn new(spec: Spec) -> Result { + fn new(spec: Spec, args: Args) -> Result { let separators = spec .language .as_ref() @@ -109,7 +114,11 @@ impl Executor { .map(|v| v.as_slice()) }) .unwrap_or(DEFAULT_SEPARATORS.as_slice()); - Ok(Self { spec, separators }) + Ok(Self { + spec, + args, + separators, + }) } fn add_output<'s>(pos: usize, text: &'s str, output: &mut Vec<(RangeValue, &'s str)>) { @@ -220,14 +229,12 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator) -> Result { - let str_value = input.into_iter().next().unwrap(); - let str_value = str_value.as_str().unwrap(); - + let text = self.args.text.value(&input)?.as_str()?; let mut output = Vec::new(); - self.split_substring(str_value, 0, 0, &mut output); + self.split_substring(text, 0, 0, &mut output); translate_bytes_to_chars( - str_value, + text, output .iter_mut() .map(|(range, _)| [&mut range.start, &mut range.end].into_iter()) @@ -248,24 +255,24 @@ pub struct Factory; #[async_trait] impl SimpleFunctionFactoryBase for Factory { type Spec = Spec; + type ResolvedArgs = Args; fn name(&self) -> &str { "SplitRecursively" } - fn get_output_schema( + fn resolve_schema( &self, _spec: &Spec, - input_schema: &Vec, + args_resolver: &mut OpArgsResolver<'_>, _context: &FlowInstanceContext, - ) -> Result { - match &expect_input_1(input_schema)?.value_type.typ { - ValueType::Basic(BasicValueType::Str) => {} - t => { - api_bail!("Expect String as input type, got {}", t) - } - } - Ok(make_output_type(CollectionSchema::new( + ) -> Result<(Args, EnrichedValueType)> { + let args = Args { + text: args_resolver + .next_arg("text")? + .expect_type(&ValueType::Basic(BasicValueType::Str))?, + }; + let output_schema = make_output_type(CollectionSchema::new( CollectionKind::Table, vec![ FieldSchema::new("location", make_output_type(BasicValueType::Range)), @@ -274,16 +281,17 @@ impl SimpleFunctionFactoryBase for Factory { )) .with_attr( field_attrs::CHUNK_BASE_TEXT, - serde_json::to_value(&input_schema[0].analyzed_value)?, - )) + serde_json::to_value(&args_resolver.get_analyze_value(&args.text))?, + ); + Ok((args, output_schema)) } async fn build_executor( self: Arc, spec: Spec, - _input_schema: Vec, + args: Args, _context: Arc, ) -> Result> { - Ok(Box::new(Executor::new(spec)?)) + Ok(Box::new(Executor::new(spec, args)?)) } } diff --git a/src/ops/sdk.rs b/src/ops/sdk.rs index f65fa6baf..8631ac500 100644 --- a/src/ops/sdk.rs +++ b/src/ops/sdk.rs @@ -1,6 +1,5 @@ pub use super::factory_bases::*; pub use super::interface::*; -pub use crate::api_bail; pub use crate::base::schema::*; pub use crate::base::value::*; pub use anyhow::Result; @@ -41,27 +40,6 @@ pub fn make_output_type(value_type: Type) -> EnrichedValueType { #[derive(Debug, Deserialize)] pub struct EmptySpec {} -pub fn expect_input_0(input_schema: &Vec) -> Result<()> { - if input_schema.len() != 1 { - api_bail!("Expected 0 input field, got {}", input_schema.len()); - } - Ok(()) -} - -pub fn expect_input_1(input_schema: &Vec) -> Result<&OpArgSchema> { - if input_schema.len() != 1 { - api_bail!("Expected 1 input field, got {}", input_schema.len()); - } - Ok(&input_schema[0]) -} - -pub fn expect_input_2(input_schema: &Vec) -> Result<(&OpArgSchema, &OpArgSchema)> { - if input_schema.len() != 2 { - api_bail!("Expected 2 input fields, got {}", input_schema.len()); - } - Ok((&input_schema[0], &input_schema[1])) -} - #[macro_export] macro_rules! fields_value { ($($field:expr), +) => {