Skip to content

Commit 2089997

Browse files
authored
feat(rs): add utilities to streamline tests for flow functions (#723)
* feat(rs): add utility functions for testing flow functions * test: add tests for all flow functions * refactor(tests): simplify test util using structured Spec instead of JSON * feat(test-util): construct field-based arg schemas
1 parent 410c2c5 commit 2089997

File tree

6 files changed

+329
-0
lines changed

6 files changed

+329
-0
lines changed

src/ops/functions/embed_text.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,58 @@ impl SimpleFunctionFactoryBase for Factory {
9595
/// Registers this module's `Factory` with the given executor factory registry,
/// returning whatever `Factory.register` returns.
pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
9696
Factory.register(registry)
9797
}
98+
99+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};

    /// Number of elements this test expects in the embedding returned for
    /// the `text-embedding-ada-002` model configured below.
    const EXPECTED_EMBEDDING_DIMENSION: usize = 1536;

    /// Drives the embedding function through `test_flow_function` with a single
    /// text argument and checks that the result is a vector of `Float32` values
    /// with the expected dimension.
    ///
    /// Ignored by default: it needs network access and LLM credentials.
    #[tokio::test]
    #[ignore = "This test requires an OpenAI API key or a configured local LLM and may make network calls."]
    async fn test_embed_text() {
        let spec = Spec {
            api_type: LlmApiType::OpenAi,
            model: "text-embedding-ada-002".to_string(),
            address: None,
            output_dimension: None,
            task_type: None,
        };

        let factory = Arc::new(Factory);
        let text_content = "CocoIndex is a performant data transformation framework for AI.";
        let input_args_values = vec![text_content.to_string().into()];
        let input_arg_schemas = vec![build_arg_schema("text", BasicValueType::Str)];

        // Fail once, with the underlying error attached, instead of logging the
        // error and then asserting on it separately.
        let value = test_flow_function(factory, spec, input_arg_schemas, input_args_values)
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "test_flow_function failed. NOTE: This test may require network access/API keys for OpenAI. Error: {:?}",
                    err
                )
            });

        match value {
            Value::Basic(BasicValue::Vector(arc_vec)) => {
                assert_eq!(
                    arc_vec.len(),
                    EXPECTED_EMBEDDING_DIMENSION,
                    "Embedding vector dimension mismatch"
                );
                for item in arc_vec.iter() {
                    assert!(
                        matches!(item, BasicValue::Float32(_)),
                        "Embedding vector element is not Float32: {:?}",
                        item
                    );
                }
            }
            other => panic!("Expected Value::Basic(BasicValue::Vector), got {:?}", other),
        }
    }
}

src/ops/functions/extract_by_llm.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,97 @@ impl SimpleFunctionFactoryBase for Factory {
155155
Ok(Box::new(Executor::new(spec, resolved_input_schema).await?))
156156
}
157157
}
158+
159+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};

    /// Drives the LLM extraction function through `test_flow_function` and checks
    /// that the returned struct has one value per schema field, each of the
    /// declared basic type (`Str` or `Int64`).
    ///
    /// Only the shape of the output is checked, not the extracted values.
    /// Ignored by default: it needs network access and LLM credentials.
    #[tokio::test]
    #[ignore = "This test requires an OpenAI API key or a configured local LLM and may make network calls."]
    async fn test_extract_by_llm() {
        // Define the expected output structure.
        let target_output_schema = StructSchema {
            fields: Arc::new(vec![
                FieldSchema::new(
                    "extracted_field_name",
                    make_output_type(BasicValueType::Str),
                ),
                FieldSchema::new(
                    "extracted_field_value",
                    make_output_type(BasicValueType::Int64),
                ),
            ]),
            description: Some("A test structure for extraction".into()),
        };

        // The schema is cloned because it is consulted again when validating the output.
        let output_type_spec = EnrichedValueType {
            typ: ValueType::Struct(target_output_schema.clone()),
            nullable: false,
            attrs: Arc::new(BTreeMap::new()),
        };

        let spec = Spec {
            llm_spec: LlmSpec {
                api_type: crate::llm::LlmApiType::OpenAi,
                model: "gpt-4o".to_string(),
                address: None,
            },
            output_type: output_type_spec,
            instruction: Some("Extract the name and value from the text. The name is a string, the value is an integer.".to_string()),
        };

        let factory = Arc::new(Factory);
        let text_content = "The item is called 'CocoIndex Test' and its value is 42.";
        let input_args_values = vec![text_content.to_string().into()];
        let input_arg_schemas = vec![build_arg_schema("text", BasicValueType::Str)];

        // Fail once, with the underlying error attached, instead of logging the
        // error and then asserting on it separately.
        let value = test_flow_function(factory, spec, input_arg_schemas, input_args_values)
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "test_flow_function failed. NOTE: This test may require network access/API keys for OpenAI. Error: {:?}",
                    err
                )
            });

        match value {
            Value::Struct(field_values) => {
                assert_eq!(
                    field_values.fields.len(),
                    target_output_schema.fields.len(),
                    "Mismatched number of fields in output struct"
                );
                // Lengths were just asserted equal, so zipping pairs every value
                // with its schema entry.
                let value_schema_pairs = field_values
                    .fields
                    .iter()
                    .zip(target_output_schema.fields.iter());
                for (field_value, field_schema) in value_schema_pairs {
                    match (field_value, &field_schema.value_type.typ) {
                        (
                            Value::Basic(BasicValue::Str(_)),
                            ValueType::Basic(BasicValueType::Str),
                        )
                        | (
                            Value::Basic(BasicValue::Int64(_)),
                            ValueType::Basic(BasicValueType::Int64),
                        ) => {}
                        (val, expected_type) => panic!(
                            "Field '{}' type mismatch. Got {:?}, expected type compatible with {:?}",
                            field_schema.name,
                            val.kind(),
                            expected_type
                        ),
                    }
                }
            }
            other => panic!("Expected Value::Struct, got {:?}", other),
        }
    }
}

src/ops/functions/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@ pub mod embed_text;
22
pub mod extract_by_llm;
33
pub mod parse_json;
44
pub mod split_recursively;
5+
6+
#[cfg(test)]
7+
mod test_utils;

src/ops/functions/parse_json.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,46 @@ impl SimpleFunctionFactoryBase for Factory {
102102
Ok(Box::new(Executor { args }))
103103
}
104104
}
105+
106+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
    use serde_json::json;

    /// Drives the JSON parsing function through `test_flow_function` with a text
    /// argument and an explicit `"json"` language argument, and checks that the
    /// parsed value equals the expected JSON document.
    #[tokio::test]
    async fn test_parse_json() {
        let spec = EmptySpec {};

        let factory = Arc::new(Factory);
        let json_string_content = r#"{"city": "Magdeburg"}"#;
        let lang_value: Value = "json".to_string().into();

        // `lang_value` is not used again, so it is moved rather than cloned.
        let input_args_values = vec![json_string_content.to_string().into(), lang_value];

        let input_arg_schemas = vec![
            build_arg_schema("text", BasicValueType::Str),
            build_arg_schema("language", BasicValueType::Str),
        ];

        let value = test_flow_function(factory, spec, input_arg_schemas, input_args_values)
            .await
            .unwrap_or_else(|err| panic!("test_flow_function failed: {:?}", err));

        match value {
            Value::Basic(BasicValue::Json(arc_json_value)) => {
                let expected_json = json!({"city": "Magdeburg"});
                assert_eq!(
                    *arc_json_value, expected_json,
                    "Parsed JSON value mismatch with specified language"
                );
            }
            other => panic!("Expected Value::Basic(BasicValue::Json), got {:?}", other),
        }
    }
}

src/ops/functions/split_recursively.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,7 @@ pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
10311031
#[cfg(test)]
10321032
mod tests {
10331033
use super::*;
1034+
use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
10341035

10351036
// Helper function to assert chunk text and its consistency with the range within the original text.
10361037
fn assert_chunk_text_consistency(
@@ -1072,6 +1073,64 @@ mod tests {
10721073
}
10731074
}
10741075

1076+
#[tokio::test]
1077+
async fn test_split_recursively() {
1078+
let spec = Spec {
1079+
custom_languages: vec![],
1080+
};
1081+
let factory = Arc::new(Factory);
1082+
let text_content = "Linea 1.\nLinea 2.\n\nLinea 3.";
1083+
1084+
let input_args_values = vec![
1085+
text_content.to_string().into(),
1086+
(15i64).into(),
1087+
(5i64).into(),
1088+
(0i64).into(),
1089+
Value::Null,
1090+
];
1091+
1092+
let input_arg_schemas = vec![
1093+
build_arg_schema("text", BasicValueType::Str),
1094+
build_arg_schema("chunk_size", BasicValueType::Int64),
1095+
build_arg_schema("min_chunk_size", BasicValueType::Int64),
1096+
build_arg_schema("chunk_overlap", BasicValueType::Int64),
1097+
build_arg_schema("language", BasicValueType::Str),
1098+
];
1099+
1100+
let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
1101+
1102+
assert!(
1103+
result.is_ok(),
1104+
"test_flow_function failed: {:?}",
1105+
result.err()
1106+
);
1107+
let value = result.unwrap();
1108+
1109+
match value {
1110+
Value::KTable(table) => {
1111+
let expected_chunks = vec![
1112+
(RangeValue::new(0, 8), "Linea 1."),
1113+
(RangeValue::new(9, 17), "Linea 2."),
1114+
(RangeValue::new(19, 27), "Linea 3."),
1115+
];
1116+
1117+
for (range, expected_text) in expected_chunks {
1118+
let key: KeyValue = range.into();
1119+
match table.get(&key) {
1120+
Some(scope_value_ref) => {
1121+
let chunk_text = scope_value_ref.0.fields[0]
1122+
.as_str()
1123+
.expect(&format!("Chunk text not a string for key {:?}", key));
1124+
assert_eq!(**chunk_text, *expected_text);
1125+
}
1126+
None => panic!("Expected row value for key {:?}, not found", key),
1127+
}
1128+
}
1129+
}
1130+
other => panic!("Expected Value::KTable, got {:?}", other),
1131+
}
1132+
}
1133+
10751134
#[test]
10761135
fn test_translate_bytes_to_chars_simple() {
10771136
let text = "abc😄def";
@@ -1179,6 +1238,7 @@ mod tests {
11791238
assert_chunk_text_consistency(text2, &chunks2[0], "A very very long", "Test 2, Chunk 0");
11801239
assert!(chunks2[0].text.len() <= 20);
11811240
}
1241+
11821242
#[test]
11831243
fn test_basic_split_with_overlap() {
11841244
let text = "This is a test text that is a bit longer to see how the overlap works.";
@@ -1198,6 +1258,7 @@ mod tests {
11981258
assert!(chunks[0].text.len() <= 25);
11991259
}
12001260
}
1261+
12011262
#[test]
12021263
fn test_split_trims_whitespace() {
12031264
let text = " \n First chunk. \n\n Second chunk with spaces at the end. \n";

src/ops/functions/test_utils.rs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
use crate::builder::plan::{
2+
AnalyzedFieldReference, AnalyzedLocalFieldReference, AnalyzedValueMapping,
3+
};
4+
use crate::ops::sdk::{
5+
AuthRegistry, BasicValueType, EnrichedValueType, FlowInstanceContext, OpArgSchema,
6+
OpArgsResolver, SimpleFunctionExecutor, SimpleFunctionFactoryBase, Value, make_output_type,
7+
};
8+
use anyhow::Result;
9+
use serde::de::DeserializeOwned;
10+
use std::sync::Arc;
11+
12+
// This function builds an argument schema for a flow function.
13+
pub fn build_arg_schema(
14+
name: &str,
15+
value_type: BasicValueType,
16+
) -> (Option<&str>, EnrichedValueType) {
17+
(Some(name), make_output_type(value_type))
18+
}
19+
20+
// This function tests a flow function by providing a spec, input argument schemas, and values.
21+
pub async fn test_flow_function<S, R, F>(
22+
factory: Arc<F>,
23+
spec: S,
24+
input_arg_schemas: Vec<(Option<&str>, EnrichedValueType)>,
25+
input_arg_values: Vec<Value>,
26+
) -> Result<Value>
27+
where
28+
S: DeserializeOwned + Send + Sync + 'static,
29+
R: Send + Sync + 'static,
30+
F: SimpleFunctionFactoryBase<Spec = S, ResolvedArgs = R> + ?Sized,
31+
{
32+
// 1. Construct OpArgSchema
33+
let op_arg_schemas: Vec<OpArgSchema> = input_arg_schemas
34+
.into_iter()
35+
.enumerate()
36+
.map(|(idx, (name, value_type))| OpArgSchema {
37+
name: name.map_or(crate::base::spec::OpArgName(None), |n| {
38+
crate::base::spec::OpArgName(Some(n.to_string()))
39+
}),
40+
value_type,
41+
analyzed_value: AnalyzedValueMapping::Field(AnalyzedFieldReference {
42+
local: AnalyzedLocalFieldReference {
43+
fields_idx: vec![idx as u32],
44+
},
45+
scope_up_level: 0,
46+
}),
47+
})
48+
.collect();
49+
50+
// 2. Resolve Schema & Args
51+
let mut args_resolver = OpArgsResolver::new(&op_arg_schemas)?;
52+
let context = Arc::new(FlowInstanceContext {
53+
flow_instance_name: "test_flow_function".to_string(),
54+
auth_registry: Arc::new(AuthRegistry::default()),
55+
py_exec_ctx: None,
56+
});
57+
58+
let (resolved_args_from_schema, _output_schema): (R, EnrichedValueType) = factory
59+
.resolve_schema(&spec, &mut args_resolver, &context)
60+
.await?;
61+
62+
args_resolver.done()?;
63+
64+
// 3. Build Executor
65+
let executor: Box<dyn SimpleFunctionExecutor> = factory
66+
.build_executor(spec, resolved_args_from_schema, Arc::clone(&context))
67+
.await?;
68+
69+
// 4. Evaluate
70+
let result = executor.evaluate(input_arg_values).await?;
71+
72+
Ok(result)
73+
}

0 commit comments

Comments
 (0)