Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/text_embedding/text_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
cocoindex.functions.SplitRecursively(
language="markdown", chunk_size=300, chunk_overlap=100))

doc["chunks"] = flow_builder.call(
cocoindex.functions.SplitRecursively(),
doc["content"], language="markdown", chunk_size=300, chunk_overlap=100);

with doc["chunks"].row() as chunk:
chunk["embedding"] = text_to_embedding(chunk["text"])
doc_embeddings.collect(filename=doc["filename"], location=chunk["location"],
Expand Down
153 changes: 145 additions & 8 deletions src/ops/factory_bases.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;
Expand All @@ -9,10 +10,142 @@ use serde::Serialize;

use super::interface::*;
use super::registry::*;
use crate::api_bail;
use crate::api_error;
use crate::base::schema::*;
use crate::base::spec::*;
use crate::base::value;
use crate::builder::plan::AnalyzedValueMapping;
use crate::setup;
// SourceFactoryBase
/// An operation argument that has been matched to a declared parameter name.
///
/// Instances are produced by `OpArgsResolver::next_arg` /
/// `OpArgsResolver::next_optional_arg`.
pub struct ResolvedOpArg {
/// Parameter name this argument was resolved under; used in error messages.
pub name: String,
/// Value type of the argument, cloned from the matching argument schema.
pub typ: EnrichedValueType,
/// Position of the argument in the argument list; used to index both the
/// schema slice and the runtime value list.
pub idx: usize,
}

impl ResolvedOpArg {
    /// Verifies that the argument's type is exactly `expected_type`.
    ///
    /// Consumes `self` and hands it back on success, so the call can be
    /// chained directly after the argument has been resolved.
    ///
    /// # Errors
    /// Fails with a message naming the argument, the expected type and the
    /// actual type when they differ.
    pub fn expect_type(self, expected_type: &ValueType) -> Result<Self> {
        if &self.typ.typ != expected_type {
            api_bail!(
                "Expected argument `{}` to be of type `{}`, got `{}`",
                self.name,
                expected_type,
                self.typ.typ
            );
        }
        Ok(self)
    }

    /// Ensures a value list of length `num_args` contains this argument's slot.
    ///
    /// Shared by `value` and `take_value` so both report the same error.
    fn ensure_provided(&self, num_args: usize) -> Result<()> {
        if self.idx >= num_args {
            api_bail!(
                "Too few arguments, {} provided, expected at least {} for `{}`",
                num_args,
                self.idx + 1,
                self.name
            );
        }
        Ok(())
    }

    /// Borrows this argument's value out of the runtime argument list.
    ///
    /// # Errors
    /// Fails when `args` is too short to contain index `self.idx`.
    pub fn value<'a>(&self, args: &'a [value::Value]) -> Result<&'a value::Value> {
        self.ensure_provided(args.len())?;
        Ok(&args[self.idx])
    }

    /// Moves this argument's value out of the runtime argument list, leaving
    /// the type's default value in its slot.
    ///
    /// # Errors
    /// Fails when `args` is too short to contain index `self.idx`.
    pub fn take_value(&self, args: &mut [value::Value]) -> Result<value::Value> {
        self.ensure_provided(args.len())?;
        Ok(std::mem::take(&mut args[self.idx]))
    }
}

/// Matches the arguments supplied to an operation against the parameter
/// names the operation asks for, one at a time.
pub struct OpArgsResolver<'a> {
/// Schemas of all supplied arguments, positional ones first.
args: &'a [OpArgSchema],
/// Number of supplied arguments that carry no name.
num_positional_args: usize,
/// Index of the next positional argument that has not been consumed yet.
next_positional_idx: usize,
/// Keyword arguments not consumed yet, mapped from name to index in `args`.
remaining_kwargs: HashMap<&'a str, usize>,
}

impl<'a> OpArgsResolver<'a> {
pub fn new(args: &'a [OpArgSchema]) -> Result<Self> {
let mut num_positional_args = 0;
let mut kwargs = HashMap::new();
for (idx, arg) in args.iter().enumerate() {
if let Some(name) = &arg.name.0 {
kwargs.insert(name.as_str(), idx);
} else {
if !kwargs.is_empty() {
api_bail!("Positional arguments must be provided before keyword arguments");
}
num_positional_args += 1;
}
}
Ok(Self {
args,
num_positional_args,
next_positional_idx: 0,
remaining_kwargs: kwargs,
})
}

pub fn next_optional_arg(&mut self, name: &str) -> Result<Option<ResolvedOpArg>> {
let idx = if let Some(idx) = self.remaining_kwargs.remove(name) {
if self.next_positional_idx < self.num_positional_args {
api_bail!("`{name}` is provided as both positional and keyword arguments");
} else {
Some(idx)
}
} else {
if self.next_positional_idx < self.num_positional_args {
let idx = self.next_positional_idx;
self.next_positional_idx += 1;
Some(idx)
} else {
None
}
};
Ok(idx.map(|idx| ResolvedOpArg {
name: name.to_string(),
typ: self.args[idx].value_type.clone(),
idx,
}))
}

pub fn next_arg(&mut self, name: &str) -> Result<ResolvedOpArg> {
Ok(self
.next_optional_arg(name)?
.ok_or_else(|| api_error!("Required argument `{name}` is missing",))?)
}

pub fn done(self) -> Result<()> {
if self.next_positional_idx < self.num_positional_args {
api_bail!(
"Expected {} positional arguments, got {}",
self.next_positional_idx,
self.num_positional_args
);
}
if !self.remaining_kwargs.is_empty() {
api_bail!(
"Unexpected keyword arguments: {}",
self.remaining_kwargs
.keys()
.map(|k| format!("`{k}`"))
.collect::<Vec<_>>()
.join(", ")
)
}
Ok(())
}

pub fn get_analyze_value(&self, resolved_arg: &ResolvedOpArg) -> &AnalyzedValueMapping {
&self.args[resolved_arg.idx].analyzed_value
}
}

#[async_trait]
pub trait SourceFactoryBase: SourceFactory + Send + Sync + 'static {
type Spec: DeserializeOwned + Send + Sync;
Expand Down Expand Up @@ -63,20 +196,21 @@ impl<T: SourceFactoryBase> SourceFactory for T {
#[async_trait]
pub trait SimpleFunctionFactoryBase: SimpleFunctionFactory + Send + Sync + 'static {
type Spec: DeserializeOwned + Send + Sync;
type ResolvedArgs: Send + Sync;

fn name(&self) -> &str;

fn get_output_schema(
&self,
spec: &Self::Spec,
input_schema: &Vec<OpArgSchema>,
fn resolve_schema<'a>(
&'a self,
spec: &'a Self::Spec,
args_resolver: &mut OpArgsResolver<'a>,
context: &FlowInstanceContext,
) -> Result<EnrichedValueType>;
) -> Result<(Self::ResolvedArgs, EnrichedValueType)>;

async fn build_executor(
self: Arc<Self>,
spec: Self::Spec,
input_schema: Vec<OpArgSchema>,
resolved_input_schema: Self::ResolvedArgs,
context: Arc<FlowInstanceContext>,
) -> Result<Box<dyn SimpleFunctionExecutor>>;

Expand All @@ -102,8 +236,11 @@ impl<T: SimpleFunctionFactoryBase> SimpleFunctionFactory for T {
ExecutorFuture<'static, Box<dyn SimpleFunctionExecutor>>,
)> {
let spec: T::Spec = serde_json::from_value(spec)?;
let output_schema = self.get_output_schema(&spec, &input_schema, &context)?;
let executor = self.build_executor(spec, input_schema, context);
let mut args_resolver = OpArgsResolver::new(&input_schema)?;
let (resolved_input_schema, output_schema) =
self.resolve_schema(&spec, &mut args_resolver, &context)?;
args_resolver.done()?;
let executor = self.build_executor(spec, resolved_input_schema, context);
Ok((output_schema, executor))
}
}
Expand Down
36 changes: 22 additions & 14 deletions src/ops/functions/extract_by_llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ pub struct Spec {
instruction: Option<String>,
}

/// Resolved arguments of the function.
pub struct Args {
/// The `text` input; resolved under the name "text" and required to be a string.
text: ResolvedOpArg,
}

struct Executor {
args: Args,
client: Box<dyn LlmGenerationClient>,
output_json_schema: SchemaObject,
output_type: EnrichedValueType,
Expand All @@ -41,8 +46,9 @@ Output only the JSON without any additional messages or explanations."
}

impl Executor {
async fn new(spec: Spec) -> Result<Self> {
async fn new(spec: Spec, args: Args) -> Result<Self> {
Ok(Self {
args,
client: new_llm_generation_client(spec.llm_spec).await?,
output_json_schema: spec.output_type.to_json_schema(),
output_type: spec.output_type,
Expand All @@ -62,7 +68,7 @@ impl SimpleFunctionExecutor for Executor {
}

async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
let text = input.iter().next().unwrap().as_str()?;
let text = self.args.text.value(&input)?.as_str()?;
let req = LlmGenerateRequest {
system_prompt: Some(Cow::Borrowed(&self.system_prompt)),
user_prompt: Cow::Borrowed(text),
Expand All @@ -83,32 +89,34 @@ pub struct Factory;
#[async_trait]
impl SimpleFunctionFactoryBase for Factory {
type Spec = Spec;
type ResolvedArgs = Args;

fn name(&self) -> &str {
"ExtractByLlm"
}

fn get_output_schema(
fn resolve_schema(
&self,
spec: &Spec,
input_schema: &Vec<OpArgSchema>,
args_resolver: &mut OpArgsResolver<'_>,
_context: &FlowInstanceContext,
) -> Result<EnrichedValueType> {
match &expect_input_1(input_schema)?.value_type.typ {
ValueType::Basic(BasicValueType::Str) => {}
t => {
api_bail!("Expect String as input type, got {}", t)
}
}
Ok(spec.output_type.clone())
) -> Result<(Args, EnrichedValueType)> {
Ok((
Args {
text: args_resolver
.next_arg("text")?
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
},
spec.output_type.clone(),
))
}

async fn build_executor(
self: Arc<Self>,
spec: Spec,
_input_schema: Vec<OpArgSchema>,
resolved_input_schema: Args,
_context: Arc<FlowInstanceContext>,
) -> Result<Box<dyn SimpleFunctionExecutor>> {
Ok(Box::new(Executor::new(spec).await?))
Ok(Box::new(Executor::new(spec, resolved_input_schema).await?))
}
}
50 changes: 29 additions & 21 deletions src/ops/functions/split_recursively.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ pub struct Spec {
chunk_overlap: usize,
}

/// Resolved arguments of the function.
pub struct Args {
/// The `text` input to split; resolved under the name "text" and required to be a string.
text: ResolvedOpArg,
}

static DEFAULT_SEPARATORS: LazyLock<Vec<Regex>> = LazyLock::new(|| {
[r"\n\n+", r"\n", r"\s+"]
.into_iter()
Expand Down Expand Up @@ -95,11 +99,12 @@ static SEPARATORS_BY_LANG: LazyLock<HashMap<&'static str, Vec<Regex>>> = LazyLoc

struct Executor {
spec: Spec,
args: Args,
separators: &'static [Regex],
}

impl Executor {
fn new(spec: Spec) -> Result<Self> {
fn new(spec: Spec, args: Args) -> Result<Self> {
let separators = spec
.language
.as_ref()
Expand All @@ -109,7 +114,11 @@ impl Executor {
.map(|v| v.as_slice())
})
.unwrap_or(DEFAULT_SEPARATORS.as_slice());
Ok(Self { spec, separators })
Ok(Self {
spec,
args,
separators,
})
}

fn add_output<'s>(pos: usize, text: &'s str, output: &mut Vec<(RangeValue, &'s str)>) {
Expand Down Expand Up @@ -220,14 +229,12 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator<Item = &'a mu
#[async_trait]
impl SimpleFunctionExecutor for Executor {
async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
let str_value = input.into_iter().next().unwrap();
let str_value = str_value.as_str().unwrap();

let text = self.args.text.value(&input)?.as_str()?;
let mut output = Vec::new();
self.split_substring(str_value, 0, 0, &mut output);
self.split_substring(text, 0, 0, &mut output);

translate_bytes_to_chars(
str_value,
text,
output
.iter_mut()
.map(|(range, _)| [&mut range.start, &mut range.end].into_iter())
Expand All @@ -248,24 +255,24 @@ pub struct Factory;
#[async_trait]
impl SimpleFunctionFactoryBase for Factory {
type Spec = Spec;
type ResolvedArgs = Args;

fn name(&self) -> &str {
"SplitRecursively"
}

fn get_output_schema(
fn resolve_schema(
&self,
_spec: &Spec,
input_schema: &Vec<OpArgSchema>,
args_resolver: &mut OpArgsResolver<'_>,
_context: &FlowInstanceContext,
) -> Result<EnrichedValueType> {
match &expect_input_1(input_schema)?.value_type.typ {
ValueType::Basic(BasicValueType::Str) => {}
t => {
api_bail!("Expect String as input type, got {}", t)
}
}
Ok(make_output_type(CollectionSchema::new(
) -> Result<(Args, EnrichedValueType)> {
let args = Args {
text: args_resolver
.next_arg("text")?
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
};
let output_schema = make_output_type(CollectionSchema::new(
CollectionKind::Table,
vec![
FieldSchema::new("location", make_output_type(BasicValueType::Range)),
Expand All @@ -274,16 +281,17 @@ impl SimpleFunctionFactoryBase for Factory {
))
.with_attr(
field_attrs::CHUNK_BASE_TEXT,
serde_json::to_value(&input_schema[0].analyzed_value)?,
))
serde_json::to_value(&args_resolver.get_analyze_value(&args.text))?,
);
Ok((args, output_schema))
}

async fn build_executor(
self: Arc<Self>,
spec: Spec,
_input_schema: Vec<OpArgSchema>,
args: Args,
_context: Arc<FlowInstanceContext>,
) -> Result<Box<dyn SimpleFunctionExecutor>> {
Ok(Box::new(Executor::new(spec)?))
Ok(Box::new(Executor::new(spec, args)?))
}
}
Loading