Skip to content

Commit 65b0b3a

Browse files
committed
feat(test-util): construct field-based arg schemas
1 parent d347648 commit 65b0b3a

File tree

5 files changed

+40
-47
lines changed

5 files changed

+40
-47
lines changed

src/ops/functions/embed_text.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ mod tests {
104104
#[tokio::test]
105105
#[ignore = "This test requires OpenAI API key or a configured local LLM and may make network calls."]
106106
async fn test_embed_text() {
107-
// Using OpenAI as an example.
108107
let spec = Spec {
109108
api_type: LlmApiType::OpenAi,
110109
model: "text-embedding-ada-002".to_string(),
@@ -118,11 +117,7 @@ mod tests {
118117

119118
let input_args_values = vec![text_content.to_string().into()];
120119

121-
let input_arg_schemas = vec![build_arg_schema(
122-
"text",
123-
text_content.to_string().into(),
124-
BasicValueType::Str,
125-
)];
120+
let input_arg_schemas = vec![build_arg_schema("text", BasicValueType::Str)];
126121

127122
let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
128123

src/ops/functions/extract_by_llm.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,6 @@ mod tests {
185185
attrs: Arc::new(BTreeMap::new()),
186186
};
187187

188-
// Spec using OpenAI as an example.
189188
let spec = Spec {
190189
llm_spec: LlmSpec {
191190
api_type: crate::llm::LlmApiType::OpenAi,
@@ -201,17 +200,13 @@ mod tests {
201200

202201
let input_args_values = vec![text_content.to_string().into()];
203202

204-
let input_arg_schemas = vec![build_arg_schema(
205-
"text",
206-
text_content.to_string().into(),
207-
BasicValueType::Str,
208-
)];
203+
let input_arg_schemas = vec![build_arg_schema("text", BasicValueType::Str)];
209204

210205
let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
211206

212207
if result.is_err() {
213208
eprintln!(
214-
"test_extract_by_llm_with_util: test_flow_function returned error (potentially expected for evaluate): {:?}",
209+
"test_extract_by_llm: test_flow_function returned error (potentially expected for evaluate): {:?}",
215210
result.as_ref().err()
216211
);
217212
}

src/ops/functions/parse_json.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,8 @@ mod tests {
120120
let input_args_values = vec![json_string_content.to_string().into(), lang_value.clone()];
121121

122122
let input_arg_schemas = vec![
123-
build_arg_schema(
124-
"text",
125-
json_string_content.to_string().into(),
126-
BasicValueType::Str,
127-
),
128-
build_arg_schema("language", lang_value, BasicValueType::Str),
123+
build_arg_schema("text", BasicValueType::Str),
124+
build_arg_schema("language", BasicValueType::Str),
129125
];
130126

131127
let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;

src/ops/functions/split_recursively.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,11 +1090,11 @@ mod tests {
10901090
];
10911091

10921092
let input_arg_schemas = vec![
1093-
build_arg_schema("text", text_content.to_string().into(), BasicValueType::Str),
1094-
build_arg_schema("chunk_size", (15i64).into(), BasicValueType::Int64),
1095-
build_arg_schema("min_chunk_size", (5i64).into(), BasicValueType::Int64),
1096-
build_arg_schema("chunk_overlap", (0i64).into(), BasicValueType::Int64),
1097-
build_arg_schema("language", Value::Null, BasicValueType::Str),
1093+
build_arg_schema("text", BasicValueType::Str),
1094+
build_arg_schema("chunk_size", BasicValueType::Int64),
1095+
build_arg_schema("min_chunk_size", BasicValueType::Int64),
1096+
build_arg_schema("chunk_overlap", BasicValueType::Int64),
1097+
build_arg_schema("language", BasicValueType::Str),
10981098
];
10991099

11001100
let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;

src/ops/functions/test_utils.rs

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use crate::builder::plan::AnalyzedValueMapping;
1+
use crate::builder::plan::{
2+
AnalyzedFieldReference, AnalyzedLocalFieldReference, AnalyzedValueMapping,
3+
};
24
use crate::ops::sdk::{
35
AuthRegistry, BasicValueType, EnrichedValueType, FlowInstanceContext, OpArgSchema,
46
OpArgsResolver, SimpleFunctionExecutor, SimpleFunctionFactoryBase, Value, make_output_type,
@@ -7,47 +9,52 @@ use anyhow::Result;
79
use serde::de::DeserializeOwned;
810
use std::sync::Arc;
911

10-
fn new_literal_op_arg_schema(
11-
name: Option<&str>,
12-
value: Value,
13-
value_type: EnrichedValueType,
14-
) -> OpArgSchema {
15-
OpArgSchema {
16-
name: name.map_or(crate::base::spec::OpArgName(None), |n| {
17-
crate::base::spec::OpArgName(Some(n.to_string()))
18-
}),
19-
value_type,
20-
analyzed_value: AnalyzedValueMapping::Constant { value },
21-
}
22-
}
23-
24-
// This function provides a helper to create OpArgSchema for literal values.
25-
pub fn build_arg_schema(name: &str, value: Value, value_type: BasicValueType) -> OpArgSchema {
26-
new_literal_op_arg_schema(Some(name), value, make_output_type(value_type))
12+
// This function builds an argument schema for a flow function.
13+
pub fn build_arg_schema(
14+
name: &str,
15+
value_type: BasicValueType,
16+
) -> (Option<&str>, EnrichedValueType) {
17+
(Some(name), make_output_type(value_type))
2718
}
2819

2920
// This function tests a flow function by providing a spec, input argument schemas, and values.
3021
pub async fn test_flow_function<S, R, F>(
3122
factory: Arc<F>,
3223
spec: S,
33-
input_arg_schemas: Vec<OpArgSchema>,
24+
input_arg_schemas: Vec<(Option<&str>, EnrichedValueType)>,
3425
input_arg_values: Vec<Value>,
3526
) -> Result<Value>
3627
where
3728
S: DeserializeOwned + Send + Sync + 'static,
3829
R: Send + Sync + 'static,
3930
F: SimpleFunctionFactoryBase<Spec = S, ResolvedArgs = R> + ?Sized,
4031
{
32+
// 1. Construct OpArgSchema
33+
let op_arg_schemas: Vec<OpArgSchema> = input_arg_schemas
34+
.into_iter()
35+
.enumerate()
36+
.map(|(idx, (name, value_type))| OpArgSchema {
37+
name: name.map_or(crate::base::spec::OpArgName(None), |n| {
38+
crate::base::spec::OpArgName(Some(n.to_string()))
39+
}),
40+
value_type,
41+
analyzed_value: AnalyzedValueMapping::Field(AnalyzedFieldReference {
42+
local: AnalyzedLocalFieldReference {
43+
fields_idx: vec![idx as u32],
44+
},
45+
scope_up_level: 0,
46+
}),
47+
})
48+
.collect();
49+
50+
// 2. Resolve Schema & Args
51+
let mut args_resolver = OpArgsResolver::new(&op_arg_schemas)?;
4152
let context = Arc::new(FlowInstanceContext {
4253
flow_instance_name: "test_flow_function".to_string(),
4354
auth_registry: Arc::new(AuthRegistry::default()),
4455
py_exec_ctx: None,
4556
});
4657

47-
// 2. Resolve Schema & Args
48-
// The caller of test_flow_function will be responsible for creating these schemas.
49-
let mut args_resolver = OpArgsResolver::new(&input_arg_schemas)?;
50-
5158
let (resolved_args_from_schema, _output_schema): (R, EnrichedValueType) = factory
5259
.resolve_schema(&spec, &mut args_resolver, &context)
5360
.await?;

0 commit comments

Comments
 (0)