Skip to content

Commit a6f77f7

Browse files
committed
Use OpArgsResolver to make resolving multiple input args easier.
1 parent fea9c77 commit a6f77f7

File tree

5 files changed

+200
-65
lines changed

5 files changed

+200
-65
lines changed

examples/text_embedding/text_embedding.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
2626
cocoindex.functions.SplitRecursively(
2727
language="markdown", chunk_size=300, chunk_overlap=100))
2828

29+
doc["chunks"] = flow_builder.call(
30+
cocoindex.functions.SplitRecursively(),
31+
doc["content"], language="markdown", chunk_size=300, chunk_overlap=100);
32+
2933
with doc["chunks"].row() as chunk:
3034
chunk["embedding"] = text_to_embedding(chunk["text"])
3135
doc_embeddings.collect(filename=doc["filename"], location=chunk["location"],

src/ops/factory_bases.rs

Lines changed: 145 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::HashMap;
12
use std::fmt::Debug;
23
use std::hash::Hash;
34
use std::sync::Arc;
@@ -9,10 +10,142 @@ use serde::Serialize;
910

1011
use super::interface::*;
1112
use super::registry::*;
13+
use crate::api_bail;
14+
use crate::api_error;
1215
use crate::base::schema::*;
1316
use crate::base::spec::*;
17+
use crate::base::value;
18+
use crate::builder::plan::AnalyzedValueMapping;
1419
use crate::setup;
1520
// SourceFactoryBase
21+
pub struct ResolvedOpArg {
22+
pub name: String,
23+
pub typ: EnrichedValueType,
24+
pub idx: usize,
25+
}
26+
27+
impl ResolvedOpArg {
28+
pub fn expect_type(self, expected_type: &ValueType) -> Result<Self> {
29+
if &self.typ.typ != expected_type {
30+
api_bail!(
31+
"Expected argument `{}` to be of type `{}`, got `{}`",
32+
self.name,
33+
expected_type,
34+
self.typ.typ
35+
);
36+
}
37+
Ok(self)
38+
}
39+
40+
pub fn value<'a>(&self, args: &'a Vec<value::Value>) -> Result<&'a value::Value> {
41+
if self.idx >= args.len() {
42+
api_bail!(
43+
"Two few arguments, {} provided, expected at least {} for `{}`",
44+
args.len(),
45+
self.idx + 1,
46+
self.name
47+
);
48+
}
49+
Ok(&args[self.idx])
50+
}
51+
52+
pub fn take_value(&self, args: &mut Vec<value::Value>) -> Result<value::Value> {
53+
if self.idx >= args.len() {
54+
api_bail!(
55+
"Two few arguments, {} provided, expected at least {} for `{}`",
56+
args.len(),
57+
self.idx + 1,
58+
self.name
59+
);
60+
}
61+
Ok(std::mem::take(&mut args[self.idx]))
62+
}
63+
}
64+
65+
pub struct OpArgsResolver<'a> {
66+
args: &'a [OpArgSchema],
67+
num_positional_args: usize,
68+
next_positional_idx: usize,
69+
remaining_kwargs: HashMap<&'a str, usize>,
70+
}
71+
72+
impl<'a> OpArgsResolver<'a> {
73+
pub fn new(args: &'a [OpArgSchema]) -> Result<Self> {
74+
let mut num_positional_args = 0;
75+
let mut kwargs = HashMap::new();
76+
for (idx, arg) in args.iter().enumerate() {
77+
if let Some(name) = &arg.name.0 {
78+
kwargs.insert(name.as_str(), idx);
79+
} else {
80+
if !kwargs.is_empty() {
81+
api_bail!("Positional arguments must be provided before keyword arguments");
82+
}
83+
num_positional_args += 1;
84+
}
85+
}
86+
Ok(Self {
87+
args,
88+
num_positional_args,
89+
next_positional_idx: 0,
90+
remaining_kwargs: kwargs,
91+
})
92+
}
93+
94+
pub fn next_optional_arg(&mut self, name: &str) -> Result<Option<ResolvedOpArg>> {
95+
let idx = if let Some(idx) = self.remaining_kwargs.remove(name) {
96+
if self.next_positional_idx < self.num_positional_args {
97+
api_bail!("`{name}` is provided as both positional and keyword arguments");
98+
} else {
99+
Some(idx)
100+
}
101+
} else {
102+
if self.next_positional_idx < self.num_positional_args {
103+
let idx = self.next_positional_idx;
104+
self.next_positional_idx += 1;
105+
Some(idx)
106+
} else {
107+
None
108+
}
109+
};
110+
Ok(idx.map(|idx| ResolvedOpArg {
111+
name: name.to_string(),
112+
typ: self.args[idx].value_type.clone(),
113+
idx,
114+
}))
115+
}
116+
117+
pub fn next_arg(&mut self, name: &str) -> Result<ResolvedOpArg> {
118+
Ok(self
119+
.next_optional_arg(name)?
120+
.ok_or_else(|| api_error!("Required argument `{name}` is missing",))?)
121+
}
122+
123+
pub fn done(self) -> Result<()> {
124+
if self.next_positional_idx < self.num_positional_args {
125+
api_bail!(
126+
"Expected {} positional arguments, got {}",
127+
self.next_positional_idx,
128+
self.num_positional_args
129+
);
130+
}
131+
if !self.remaining_kwargs.is_empty() {
132+
api_bail!(
133+
"Unexpected keyword arguments: {}",
134+
self.remaining_kwargs
135+
.keys()
136+
.map(|k| format!("`{k}`"))
137+
.collect::<Vec<_>>()
138+
.join(", ")
139+
)
140+
}
141+
Ok(())
142+
}
143+
144+
pub fn get_analyze_value(&self, resolved_arg: &ResolvedOpArg) -> &AnalyzedValueMapping {
145+
&self.args[resolved_arg.idx].analyzed_value
146+
}
147+
}
148+
16149
#[async_trait]
17150
pub trait SourceFactoryBase: SourceFactory + Send + Sync + 'static {
18151
type Spec: DeserializeOwned + Send + Sync;
@@ -63,20 +196,21 @@ impl<T: SourceFactoryBase> SourceFactory for T {
63196
#[async_trait]
64197
pub trait SimpleFunctionFactoryBase: SimpleFunctionFactory + Send + Sync + 'static {
65198
type Spec: DeserializeOwned + Send + Sync;
199+
type ResolvedArgs: Send + Sync;
66200

67201
fn name(&self) -> &str;
68202

69-
fn get_output_schema(
70-
&self,
71-
spec: &Self::Spec,
72-
input_schema: &Vec<OpArgSchema>,
203+
fn resolve_schema<'a>(
204+
&'a self,
205+
spec: &'a Self::Spec,
206+
args_resolver: &mut OpArgsResolver<'a>,
73207
context: &FlowInstanceContext,
74-
) -> Result<EnrichedValueType>;
208+
) -> Result<(Self::ResolvedArgs, EnrichedValueType)>;
75209

76210
async fn build_executor(
77211
self: Arc<Self>,
78212
spec: Self::Spec,
79-
input_schema: Vec<OpArgSchema>,
213+
resolved_input_schema: Self::ResolvedArgs,
80214
context: Arc<FlowInstanceContext>,
81215
) -> Result<Box<dyn SimpleFunctionExecutor>>;
82216

@@ -102,8 +236,11 @@ impl<T: SimpleFunctionFactoryBase> SimpleFunctionFactory for T {
102236
ExecutorFuture<'static, Box<dyn SimpleFunctionExecutor>>,
103237
)> {
104238
let spec: T::Spec = serde_json::from_value(spec)?;
105-
let output_schema = self.get_output_schema(&spec, &input_schema, &context)?;
106-
let executor = self.build_executor(spec, input_schema, context);
239+
let mut args_resolver = OpArgsResolver::new(&input_schema)?;
240+
let (resolved_input_schema, output_schema) =
241+
self.resolve_schema(&spec, &mut args_resolver, &context)?;
242+
args_resolver.done()?;
243+
let executor = self.build_executor(spec, resolved_input_schema, context);
107244
Ok((output_schema, executor))
108245
}
109246
}

src/ops/functions/extract_by_llm.rs

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@ pub struct Spec {
1717
instruction: Option<String>,
1818
}
1919

20+
pub struct Args {
21+
text: ResolvedOpArg,
22+
}
23+
2024
struct Executor {
25+
args: Args,
2126
client: Box<dyn LlmGenerationClient>,
2227
output_json_schema: SchemaObject,
2328
output_type: EnrichedValueType,
@@ -41,8 +46,9 @@ Output only the JSON without any additional messages or explanations."
4146
}
4247

4348
impl Executor {
44-
async fn new(spec: Spec) -> Result<Self> {
49+
async fn new(spec: Spec, args: Args) -> Result<Self> {
4550
Ok(Self {
51+
args,
4652
client: new_llm_generation_client(spec.llm_spec).await?,
4753
output_json_schema: spec.output_type.to_json_schema(),
4854
output_type: spec.output_type,
@@ -62,7 +68,7 @@ impl SimpleFunctionExecutor for Executor {
6268
}
6369

6470
async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
65-
let text = input.iter().next().unwrap().as_str()?;
71+
let text = self.args.text.value(&input)?.as_str()?;
6672
let req = LlmGenerateRequest {
6773
system_prompt: Some(Cow::Borrowed(&self.system_prompt)),
6874
user_prompt: Cow::Borrowed(text),
@@ -83,32 +89,34 @@ pub struct Factory;
8389
#[async_trait]
8490
impl SimpleFunctionFactoryBase for Factory {
8591
type Spec = Spec;
92+
type ResolvedArgs = Args;
8693

8794
fn name(&self) -> &str {
8895
"ExtractByLlm"
8996
}
9097

91-
fn get_output_schema(
98+
fn resolve_schema(
9299
&self,
93100
spec: &Spec,
94-
input_schema: &Vec<OpArgSchema>,
101+
args_resolver: &mut OpArgsResolver<'_>,
95102
_context: &FlowInstanceContext,
96-
) -> Result<EnrichedValueType> {
97-
match &expect_input_1(input_schema)?.value_type.typ {
98-
ValueType::Basic(BasicValueType::Str) => {}
99-
t => {
100-
api_bail!("Expect String as input type, got {}", t)
101-
}
102-
}
103-
Ok(spec.output_type.clone())
103+
) -> Result<(Args, EnrichedValueType)> {
104+
Ok((
105+
Args {
106+
text: args_resolver
107+
.next_arg("text")?
108+
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
109+
},
110+
spec.output_type.clone(),
111+
))
104112
}
105113

106114
async fn build_executor(
107115
self: Arc<Self>,
108116
spec: Spec,
109-
_input_schema: Vec<OpArgSchema>,
117+
resolved_input_schema: Args,
110118
_context: Arc<FlowInstanceContext>,
111119
) -> Result<Box<dyn SimpleFunctionExecutor>> {
112-
Ok(Box::new(Executor::new(spec).await?))
120+
Ok(Box::new(Executor::new(spec, resolved_input_schema).await?))
113121
}
114122
}

src/ops/functions/split_recursively.rs

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ pub struct Spec {
1616
chunk_overlap: usize,
1717
}
1818

19+
pub struct Args {
20+
text: ResolvedOpArg,
21+
}
22+
1923
static DEFAULT_SEPARATORS: LazyLock<Vec<Regex>> = LazyLock::new(|| {
2024
[r"\n\n+", r"\n", r"\s+"]
2125
.into_iter()
@@ -95,11 +99,12 @@ static SEPARATORS_BY_LANG: LazyLock<HashMap<&'static str, Vec<Regex>>> = LazyLoc
9599

96100
struct Executor {
97101
spec: Spec,
102+
args: Args,
98103
separators: &'static [Regex],
99104
}
100105

101106
impl Executor {
102-
fn new(spec: Spec) -> Result<Self> {
107+
fn new(spec: Spec, args: Args) -> Result<Self> {
103108
let separators = spec
104109
.language
105110
.as_ref()
@@ -109,7 +114,11 @@ impl Executor {
109114
.map(|v| v.as_slice())
110115
})
111116
.unwrap_or(DEFAULT_SEPARATORS.as_slice());
112-
Ok(Self { spec, separators })
117+
Ok(Self {
118+
spec,
119+
args,
120+
separators,
121+
})
113122
}
114123

115124
fn add_output<'s>(pos: usize, text: &'s str, output: &mut Vec<(RangeValue, &'s str)>) {
@@ -220,14 +229,12 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator<Item = &'a mu
220229
#[async_trait]
221230
impl SimpleFunctionExecutor for Executor {
222231
async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
223-
let str_value = input.into_iter().next().unwrap();
224-
let str_value = str_value.as_str().unwrap();
225-
232+
let text = self.args.text.value(&input)?.as_str()?;
226233
let mut output = Vec::new();
227-
self.split_substring(str_value, 0, 0, &mut output);
234+
self.split_substring(text, 0, 0, &mut output);
228235

229236
translate_bytes_to_chars(
230-
str_value,
237+
text,
231238
output
232239
.iter_mut()
233240
.map(|(range, _)| [&mut range.start, &mut range.end].into_iter())
@@ -248,24 +255,24 @@ pub struct Factory;
248255
#[async_trait]
249256
impl SimpleFunctionFactoryBase for Factory {
250257
type Spec = Spec;
258+
type ResolvedArgs = Args;
251259

252260
fn name(&self) -> &str {
253261
"SplitRecursively"
254262
}
255263

256-
fn get_output_schema(
264+
fn resolve_schema(
257265
&self,
258266
_spec: &Spec,
259-
input_schema: &Vec<OpArgSchema>,
267+
args_resolver: &mut OpArgsResolver<'_>,
260268
_context: &FlowInstanceContext,
261-
) -> Result<EnrichedValueType> {
262-
match &expect_input_1(input_schema)?.value_type.typ {
263-
ValueType::Basic(BasicValueType::Str) => {}
264-
t => {
265-
api_bail!("Expect String as input type, got {}", t)
266-
}
267-
}
268-
Ok(make_output_type(CollectionSchema::new(
269+
) -> Result<(Args, EnrichedValueType)> {
270+
let args = Args {
271+
text: args_resolver
272+
.next_arg("text")?
273+
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
274+
};
275+
let output_schema = make_output_type(CollectionSchema::new(
269276
CollectionKind::Table,
270277
vec![
271278
FieldSchema::new("location", make_output_type(BasicValueType::Range)),
@@ -274,16 +281,17 @@ impl SimpleFunctionFactoryBase for Factory {
274281
))
275282
.with_attr(
276283
field_attrs::CHUNK_BASE_TEXT,
277-
serde_json::to_value(&input_schema[0].analyzed_value)?,
278-
))
284+
serde_json::to_value(&args_resolver.get_analyze_value(&args.text))?,
285+
);
286+
Ok((args, output_schema))
279287
}
280288

281289
async fn build_executor(
282290
self: Arc<Self>,
283291
spec: Spec,
284-
_input_schema: Vec<OpArgSchema>,
292+
args: Args,
285293
_context: Arc<FlowInstanceContext>,
286294
) -> Result<Box<dyn SimpleFunctionExecutor>> {
287-
Ok(Box::new(Executor::new(spec)?))
295+
Ok(Box::new(Executor::new(spec, args)?))
288296
}
289297
}

0 commit comments

Comments
 (0)