Skip to content

Commit 28ec480

Browse files
authored
feat(null-propagation): decouple require and nullability for Rust args (#857)
1 parent c1c2a15 commit 28ec480

File tree

5 files changed

+108
-61
lines changed

5 files changed

+108
-61
lines changed

src/ops/factory_bases.rs

Lines changed: 80 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,37 +6,76 @@ use std::hash::Hash;
66
use super::interface::*;
77
use super::registry::*;
88
use crate::api_bail;
9-
use crate::api_error;
109
use crate::base::schema::*;
1110
use crate::base::spec::*;
1211
use crate::builder::plan::AnalyzedValueMapping;
1312
use crate::setup;
1413
// SourceFactoryBase
14+
pub struct OpArgResolver<'arg> {
15+
name: String,
16+
resolved_op_arg: Option<(usize, EnrichedValueType)>,
17+
nonnull_args_idx: &'arg mut Vec<usize>,
18+
may_nullify_output: &'arg mut bool,
19+
}
20+
21+
impl<'arg> OpArgResolver<'arg> {
22+
pub fn expect_nullable_type(self, expected_type: &ValueType) -> Result<Self> {
23+
let Some((_, typ)) = &self.resolved_op_arg else {
24+
return Ok(self);
25+
};
26+
if &typ.typ != expected_type {
27+
api_bail!(
28+
"Expected argument `{}` to be of type `{}`, got `{}`",
29+
self.name,
30+
expected_type,
31+
typ.typ
32+
);
33+
}
34+
Ok(self)
35+
}
36+
pub fn expect_type(self, expected_type: &ValueType) -> Result<Self> {
37+
let resolver = self.expect_nullable_type(expected_type)?;
38+
resolver.resolved_op_arg.as_ref().map(|(idx, typ)| {
39+
resolver.nonnull_args_idx.push(*idx);
40+
if typ.nullable {
41+
*resolver.may_nullify_output = true;
42+
}
43+
});
44+
Ok(resolver)
45+
}
46+
47+
pub fn optional(self) -> Option<ResolvedOpArg> {
48+
return self.resolved_op_arg.map(|(idx, typ)| ResolvedOpArg {
49+
name: self.name,
50+
typ,
51+
idx,
52+
});
53+
}
54+
55+
pub fn required(self) -> Result<ResolvedOpArg> {
56+
let Some((idx, typ)) = self.resolved_op_arg else {
57+
api_bail!("Required argument `{}` is missing", self.name);
58+
};
59+
Ok(ResolvedOpArg {
60+
name: self.name,
61+
typ,
62+
idx,
63+
})
64+
}
65+
}
66+
1567
pub struct ResolvedOpArg {
1668
pub name: String,
1769
pub typ: EnrichedValueType,
1870
pub idx: usize,
1971
}
2072

2173
pub trait ResolvedOpArgExt: Sized {
22-
fn expect_type(self, expected_type: &ValueType) -> Result<Self>;
2374
fn value<'a>(&self, args: &'a [value::Value]) -> Result<&'a value::Value>;
2475
fn take_value(&self, args: &mut [value::Value]) -> Result<value::Value>;
2576
}
2677

2778
impl ResolvedOpArgExt for ResolvedOpArg {
28-
fn expect_type(self, expected_type: &ValueType) -> Result<Self> {
29-
if &self.typ.typ != expected_type {
30-
api_bail!(
31-
"Expected argument `{}` to be of type `{}`, got `{}`",
32-
self.name,
33-
expected_type,
34-
self.typ.typ
35-
);
36-
}
37-
Ok(self)
38-
}
39-
4079
fn value<'a>(&self, args: &'a [value::Value]) -> Result<&'a value::Value> {
4180
if self.idx >= args.len() {
4281
api_bail!(
@@ -63,10 +102,6 @@ impl ResolvedOpArgExt for ResolvedOpArg {
63102
}
64103

65104
impl ResolvedOpArgExt for Option<ResolvedOpArg> {
66-
fn expect_type(self, expected_type: &ValueType) -> Result<Self> {
67-
self.map(|arg| arg.expect_type(expected_type)).transpose()
68-
}
69-
70105
fn value<'a>(&self, args: &'a [value::Value]) -> Result<&'a value::Value> {
71106
Ok(self
72107
.as_ref()
@@ -89,11 +124,16 @@ pub struct OpArgsResolver<'a> {
89124
num_positional_args: usize,
90125
next_positional_idx: usize,
91126
remaining_kwargs: HashMap<&'a str, usize>,
92-
required_args_idx: &'a mut Vec<usize>,
127+
nonnull_args_idx: &'a mut Vec<usize>,
128+
may_nullify_output: &'a mut bool,
93129
}
94130

95131
impl<'a> OpArgsResolver<'a> {
96-
pub fn new(args: &'a [OpArgSchema], required_args_idx: &'a mut Vec<usize>) -> Result<Self> {
132+
pub fn new(
133+
args: &'a [OpArgSchema],
134+
nonnull_args_idx: &'a mut Vec<usize>,
135+
may_nullify_output: &'a mut bool,
136+
) -> Result<Self> {
97137
let mut num_positional_args = 0;
98138
let mut kwargs = HashMap::new();
99139
for (idx, arg) in args.iter().enumerate() {
@@ -111,11 +151,12 @@ impl<'a> OpArgsResolver<'a> {
111151
num_positional_args,
112152
next_positional_idx: 0,
113153
remaining_kwargs: kwargs,
114-
required_args_idx,
154+
nonnull_args_idx,
155+
may_nullify_output,
115156
})
116157
}
117158

118-
pub fn next_optional_arg(&mut self, name: &str) -> Result<Option<ResolvedOpArg>> {
159+
pub fn next_arg<'arg>(&'arg mut self, name: &str) -> Result<OpArgResolver<'arg>> {
119160
let idx = if let Some(idx) = self.remaining_kwargs.remove(name) {
120161
if self.next_positional_idx < self.num_positional_args {
121162
api_bail!("`{name}` is provided as both positional and keyword arguments");
@@ -129,19 +170,12 @@ impl<'a> OpArgsResolver<'a> {
129170
} else {
130171
None
131172
};
132-
Ok(idx.map(|idx| ResolvedOpArg {
173+
Ok(OpArgResolver {
133174
name: name.to_string(),
134-
typ: self.args[idx].value_type.clone(),
135-
idx,
136-
}))
137-
}
138-
139-
pub fn next_arg(&mut self, name: &str) -> Result<ResolvedOpArg> {
140-
let arg = self
141-
.next_optional_arg(name)?
142-
.ok_or_else(|| api_error!("Required argument `{name}` is missing",))?;
143-
self.required_args_idx.push(arg.idx);
144-
Ok(arg)
175+
resolved_op_arg: idx.map(|idx| (idx, self.args[idx].value_type.clone())),
176+
nonnull_args_idx: self.nonnull_args_idx,
177+
may_nullify_output: self.may_nullify_output,
178+
})
145179
}
146180

147181
pub fn done(self) -> Result<()> {
@@ -252,13 +286,13 @@ pub trait SimpleFunctionFactoryBase: SimpleFunctionFactory + Send + Sync + 'stat
252286

253287
struct FunctionExecutorWrapper<E: SimpleFunctionExecutor> {
254288
executor: E,
255-
required_args_idx: Vec<usize>,
289+
nonnull_args_idx: Vec<usize>,
256290
}
257291

258292
#[async_trait]
259293
impl<E: SimpleFunctionExecutor> SimpleFunctionExecutor for FunctionExecutorWrapper<E> {
260294
async fn evaluate(&self, args: Vec<value::Value>) -> Result<value::Value> {
261-
for idx in &self.required_args_idx {
295+
for idx in &self.nonnull_args_idx {
262296
if args[*idx].is_null() {
263297
return Ok(value::Value::Null);
264298
}
@@ -287,28 +321,29 @@ impl<T: SimpleFunctionFactoryBase> SimpleFunctionFactory for T {
287321
BoxFuture<'static, Result<Box<dyn SimpleFunctionExecutor>>>,
288322
)> {
289323
let spec: T::Spec = serde_json::from_value(spec)?;
290-
let mut required_args_idx = vec![];
291-
let mut args_resolver = OpArgsResolver::new(&input_schema, &mut required_args_idx)?;
324+
let mut nonnull_args_idx = vec![];
325+
let mut may_nullify_output = false;
326+
let mut args_resolver = OpArgsResolver::new(
327+
&input_schema,
328+
&mut nonnull_args_idx,
329+
&mut may_nullify_output,
330+
)?;
292331
let (resolved_input_schema, mut output_schema) = self
293332
.resolve_schema(&spec, &mut args_resolver, &context)
294333
.await?;
334+
args_resolver.done()?;
295335

296336
// If any required argument is nullable, the output schema should be nullable.
297-
if args_resolver
298-
.required_args_idx
299-
.iter()
300-
.any(|idx| input_schema[*idx].value_type.nullable)
301-
{
337+
if may_nullify_output {
302338
output_schema.nullable = true;
303339
}
304340

305-
args_resolver.done()?;
306341
let executor = async move {
307342
Ok(Box::new(FunctionExecutorWrapper {
308343
executor: self
309344
.build_executor(spec, resolved_input_schema, context)
310345
.await?,
311-
required_args_idx,
346+
nonnull_args_idx,
312347
}) as Box<dyn SimpleFunctionExecutor>)
313348
};
314349
Ok((output_schema, Box::pin(executor)))

src/ops/functions/embed_text.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ impl SimpleFunctionFactoryBase for Factory {
6969
args_resolver: &mut OpArgsResolver<'a>,
7070
_context: &FlowInstanceContext,
7171
) -> Result<(Self::ResolvedArgs, EnrichedValueType)> {
72-
let text = args_resolver.next_arg("text")?;
72+
let text = args_resolver
73+
.next_arg("text")?
74+
.expect_type(&ValueType::Basic(BasicValueType::Str))?
75+
.required()?;
7376
let client =
7477
new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_config.clone())
7578
.await?;

src/ops/functions/extract_by_llm.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,13 @@ impl SimpleFunctionFactoryBase for Factory {
138138
) -> Result<(Args, EnrichedValueType)> {
139139
let args = Args {
140140
text: args_resolver
141-
.next_optional_arg("text")?
142-
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
141+
.next_arg("text")?
142+
.expect_nullable_type(&ValueType::Basic(BasicValueType::Str))?
143+
.optional(),
143144
image: args_resolver
144-
.next_optional_arg("image")?
145-
.expect_type(&ValueType::Basic(BasicValueType::Bytes))?,
145+
.next_arg("image")?
146+
.expect_nullable_type(&ValueType::Basic(BasicValueType::Bytes))?
147+
.optional(),
146148
};
147149

148150
if args.text.is_none() && args.image.is_none() {

src/ops/functions/parse_json.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,12 @@ impl SimpleFunctionFactoryBase for Factory {
8383
let args = Args {
8484
text: args_resolver
8585
.next_arg("text")?
86-
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
86+
.expect_type(&ValueType::Basic(BasicValueType::Str))?
87+
.required()?,
8788
language: args_resolver
88-
.next_optional_arg("language")?
89-
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
89+
.next_arg("language")?
90+
.expect_nullable_type(&ValueType::Basic(BasicValueType::Str))?
91+
.optional(),
9092
};
9193

9294
let output_schema = make_output_type(BasicValueType::Json);

src/ops/functions/split_recursively.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -963,19 +963,24 @@ impl SimpleFunctionFactoryBase for Factory {
963963
let args = Args {
964964
text: args_resolver
965965
.next_arg("text")?
966-
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
966+
.expect_type(&ValueType::Basic(BasicValueType::Str))?
967+
.required()?,
967968
chunk_size: args_resolver
968969
.next_arg("chunk_size")?
969-
.expect_type(&ValueType::Basic(BasicValueType::Int64))?,
970+
.expect_type(&ValueType::Basic(BasicValueType::Int64))?
971+
.required()?,
970972
min_chunk_size: args_resolver
971-
.next_optional_arg("min_chunk_size")?
972-
.expect_type(&ValueType::Basic(BasicValueType::Int64))?,
973+
.next_arg("min_chunk_size")?
974+
.expect_nullable_type(&ValueType::Basic(BasicValueType::Int64))?
975+
.optional(),
973976
chunk_overlap: args_resolver
974-
.next_optional_arg("chunk_overlap")?
975-
.expect_type(&ValueType::Basic(BasicValueType::Int64))?,
977+
.next_arg("chunk_overlap")?
978+
.expect_nullable_type(&ValueType::Basic(BasicValueType::Int64))?
979+
.optional(),
976980
language: args_resolver
977-
.next_optional_arg("language")?
978-
.expect_type(&ValueType::Basic(BasicValueType::Str))?,
981+
.next_arg("language")?
982+
.expect_nullable_type(&ValueType::Basic(BasicValueType::Str))?
983+
.optional(),
979984
};
980985

981986
let pos_struct = schema::ValueType::Struct(schema::StructSchema {

0 commit comments

Comments
 (0)