Skip to content

Commit 3b302e3

Browse files
committed
test: add tests for all flow functions
1 parent 169c50b commit 3b302e3

File tree

5 files changed

+328
-4
lines changed

5 files changed

+328
-4
lines changed

src/ops/functions/embed_text.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,74 @@ impl SimpleFunctionFactoryBase for Factory {
9595
pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
9696
Factory.register(registry)
9797
}
98+
99+
#[cfg(test)]
100+
mod tests {
101+
use super::*;
102+
use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
103+
use serde_json::json;
104+
105+
#[tokio::test]
106+
#[ignore = "This test requires OpenAI API key or a configured local LLM and may make network calls."]
107+
async fn test_embed_text_with_util() {
108+
let context = Arc::new(FlowInstanceContext {
109+
flow_instance_name: "test_embed_text_flow".to_string(),
110+
auth_registry: Arc::new(AuthRegistry::default()),
111+
py_exec_ctx: None,
112+
});
113+
114+
// Using OpenAI as an example.
115+
let spec_json = json!({
116+
"api_type": "OpenAi",
117+
"model": "text-embedding-ada-002",
118+
});
119+
120+
let factory = Arc::new(Factory);
121+
let text_content = "CocoIndex is a performant data transformation framework for AI.";
122+
123+
let input_args_values = vec![text_content.to_string().into()];
124+
125+
let input_arg_schemas = vec![build_arg_schema(
126+
"text",
127+
text_content.to_string().into(),
128+
BasicValueType::Str,
129+
)];
130+
131+
let result = test_flow_function(
132+
factory,
133+
spec_json,
134+
input_arg_schemas,
135+
input_args_values,
136+
context,
137+
)
138+
.await;
139+
140+
if result.is_err() {
141+
eprintln!(
142+
"test_embed_text: test_flow_function returned error (potentially expected for evaluate): {:?}",
143+
result.as_ref().err()
144+
);
145+
}
146+
147+
assert!(
148+
result.is_ok(),
149+
"test_flow_function failed. NOTE: This test may require network access/API keys for OpenAI. Error: {:?}",
150+
result.err()
151+
);
152+
153+
let value = result.unwrap();
154+
155+
match value {
156+
Value::Basic(BasicValue::Vector(arc_vec)) => {
157+
assert_eq!(arc_vec.len(), 1536, "Embedding vector dimension mismatch");
158+
for item in arc_vec.iter() {
159+
match item {
160+
BasicValue::Float32(_) => {}
161+
_ => panic!("Embedding vector element is not Float32: {:?}", item),
162+
}
163+
}
164+
}
165+
_ => panic!("Expected Value::Basic(BasicValue::Vector), got {:?}", value),
166+
}
167+
}
168+
}

src/ops/functions/extract_by_llm.rs

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,121 @@ impl SimpleFunctionFactoryBase for Factory {
155155
Ok(Box::new(Executor::new(spec, resolved_input_schema).await?))
156156
}
157157
}
158+
159+
#[cfg(test)]
160+
mod tests {
161+
use super::*;
162+
use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
163+
use serde_json::json;
164+
165+
#[tokio::test]
166+
#[ignore = "This test requires an OpenAI API key or a configured local LLM and may make network calls."]
167+
async fn test_extract_by_llm_with_util() {
168+
let context = Arc::new(FlowInstanceContext {
169+
flow_instance_name: "test_extract_by_llm_flow".to_string(),
170+
auth_registry: Arc::new(AuthRegistry::default()),
171+
py_exec_ctx: None,
172+
});
173+
174+
// Define the expected output structure
175+
let target_output_schema = StructSchema {
176+
fields: Arc::new(vec![
177+
FieldSchema::new(
178+
"extracted_field_name",
179+
make_output_type(BasicValueType::Str),
180+
),
181+
FieldSchema::new(
182+
"extracted_field_value",
183+
make_output_type(BasicValueType::Int64),
184+
),
185+
]),
186+
description: Some("A test structure for extraction".into()),
187+
};
188+
189+
let output_type_spec = EnrichedValueType {
190+
typ: ValueType::Struct(target_output_schema.clone()),
191+
nullable: false,
192+
attrs: Arc::new(BTreeMap::new()),
193+
};
194+
195+
// Spec using OpenAI as an example.
196+
let spec_json = json!({
197+
"llm_spec": {
198+
"api_type": "OpenAi",
199+
"model": "gpt-4o",
200+
"address": null,
201+
"api_key_auth": null,
202+
"max_tokens": 100,
203+
"temperature": 0.0,
204+
"top_p": null,
205+
"params": {}
206+
},
207+
"output_type": output_type_spec,
208+
"instruction": "Extract the name and value from the text. The name is a string, the value is an integer."
209+
});
210+
211+
let factory = Arc::new(Factory);
212+
let text_content = "The item is called 'CocoIndex Test' and its value is 42.";
213+
214+
let input_args_values = vec![text_content.to_string().into()];
215+
216+
let input_arg_schemas = vec![build_arg_schema(
217+
"text",
218+
text_content.to_string().into(),
219+
BasicValueType::Str,
220+
)];
221+
222+
let result = test_flow_function(
223+
factory,
224+
spec_json,
225+
input_arg_schemas,
226+
input_args_values,
227+
context,
228+
)
229+
.await;
230+
231+
if result.is_err() {
232+
eprintln!(
233+
"test_extract_by_llm_with_util: test_flow_function returned error (potentially expected for evaluate): {:?}",
234+
result.as_ref().err()
235+
);
236+
}
237+
238+
assert!(
239+
result.is_ok(),
240+
"test_flow_function failed. NOTE: This test may require network access/API keys for OpenAI. Error: {:?}",
241+
result.err()
242+
);
243+
244+
let value = result.unwrap();
245+
246+
match value {
247+
Value::Struct(field_values) => {
248+
assert_eq!(
249+
field_values.fields.len(),
250+
target_output_schema.fields.len(),
251+
"Mismatched number of fields in output struct"
252+
);
253+
for (idx, field_schema) in target_output_schema.fields.iter().enumerate() {
254+
match (&field_values.fields[idx], &field_schema.value_type.typ) {
255+
(
256+
Value::Basic(BasicValue::Str(_)),
257+
ValueType::Basic(BasicValueType::Str),
258+
) => {}
259+
(
260+
Value::Basic(BasicValue::Int64(_)),
261+
ValueType::Basic(BasicValueType::Int64),
262+
) => {}
263+
(val, expected_type) => panic!(
264+
"Field '{}' type mismatch. Got {:?}, expected type compatible with {:?}",
265+
field_schema.name,
266+
val.kind(),
267+
expected_type
268+
),
269+
}
270+
}
271+
}
272+
_ => panic!("Expected Value::Struct, got {:?}", value),
273+
}
274+
}
275+
}

src/ops/functions/parse_json.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,63 @@ impl SimpleFunctionFactoryBase for Factory {
102102
Ok(Box::new(Executor { args }))
103103
}
104104
}
105+
106+
#[cfg(test)]
107+
mod tests {
108+
use super::*;
109+
use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
110+
use serde_json::json;
111+
112+
#[tokio::test]
113+
async fn test_parse_json_util() {
114+
let context = Arc::new(FlowInstanceContext {
115+
flow_instance_name: "test_parse_json_flow".to_string(),
116+
auth_registry: Arc::new(AuthRegistry::default()),
117+
py_exec_ctx: None,
118+
});
119+
120+
let spec_json = json!({});
121+
122+
let factory = Arc::new(Factory);
123+
let json_string_content = r#"{"city": "Magdeburg"}"#;
124+
let lang_value: Value = "json".to_string().into();
125+
126+
let input_args_values = vec![json_string_content.to_string().into(), lang_value.clone()];
127+
128+
let input_arg_schemas = vec![
129+
build_arg_schema(
130+
"text",
131+
json_string_content.to_string().into(),
132+
BasicValueType::Str,
133+
),
134+
build_arg_schema("language", lang_value, BasicValueType::Str),
135+
];
136+
137+
let result = test_flow_function(
138+
factory,
139+
spec_json,
140+
input_arg_schemas,
141+
input_args_values,
142+
context,
143+
)
144+
.await;
145+
146+
assert!(
147+
result.is_ok(),
148+
"test_flow_function failed: {:?}",
149+
result.err()
150+
);
151+
let value = result.unwrap();
152+
153+
match value {
154+
Value::Basic(BasicValue::Json(arc_json_value)) => {
155+
let expected_json = json!({"city": "Magdeburg"});
156+
assert_eq!(
157+
*arc_json_value, expected_json,
158+
"Parsed JSON value mismatch with specified language"
159+
);
160+
}
161+
_ => panic!("Expected Value::Basic(BasicValue::Json), got {:?}", value),
162+
}
163+
}
164+
}

src/ops/functions/split_recursively.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,8 @@ pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
10311031
#[cfg(test)]
10321032
mod tests {
10331033
use super::*;
1034+
use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
1035+
use serde_json::json;
10341036

10351037
// Helper function to assert chunk text and its consistency with the range within the original text.
10361038
fn assert_chunk_text_consistency(
@@ -1072,6 +1074,75 @@ mod tests {
10721074
}
10731075
}
10741076

1077+
#[tokio::test]
1078+
async fn test_split_recursively_with_util() {
1079+
let context = Arc::new(FlowInstanceContext {
1080+
flow_instance_name: "test_parse_recursively_flow".to_string(),
1081+
auth_registry: Arc::new(AuthRegistry::default()),
1082+
py_exec_ctx: None,
1083+
});
1084+
1085+
let spec_json = json!({});
1086+
let factory = Arc::new(Factory);
1087+
let text_content = "Linea 1.\nLinea 2.\n\nLinea 3.";
1088+
1089+
let input_args_values = vec![
1090+
text_content.to_string().into(),
1091+
(15i64).into(),
1092+
(5i64).into(),
1093+
(0i64).into(),
1094+
Value::Null,
1095+
];
1096+
1097+
let input_arg_schemas = vec![
1098+
build_arg_schema("text", text_content.to_string().into(), BasicValueType::Str),
1099+
build_arg_schema("chunk_size", (15i64).into(), BasicValueType::Int64),
1100+
build_arg_schema("min_chunk_size", (5i64).into(), BasicValueType::Int64),
1101+
build_arg_schema("chunk_overlap", (0i64).into(), BasicValueType::Int64),
1102+
build_arg_schema("language", Value::Null, BasicValueType::Str),
1103+
];
1104+
1105+
let result = test_flow_function(
1106+
factory,
1107+
spec_json,
1108+
input_arg_schemas,
1109+
input_args_values,
1110+
context,
1111+
)
1112+
.await;
1113+
1114+
assert!(
1115+
result.is_ok(),
1116+
"test_flow_function failed: {:?}",
1117+
result.err()
1118+
);
1119+
let value = result.unwrap();
1120+
1121+
match value {
1122+
Value::KTable(table) => {
1123+
let expected_chunks = vec![
1124+
(RangeValue::new(0, 8), "Linea 1."),
1125+
(RangeValue::new(9, 17), "Linea 2."),
1126+
(RangeValue::new(19, 27), "Linea 3."),
1127+
];
1128+
1129+
for (range, expected_text) in expected_chunks {
1130+
let key: KeyValue = range.into();
1131+
match table.get(&key) {
1132+
Some(scope_value_ref) => {
1133+
let chunk_text = scope_value_ref.0.fields[0]
1134+
.as_str()
1135+
.expect(&format!("Chunk text not a string for key {:?}", key));
1136+
assert_eq!(**chunk_text, *expected_text);
1137+
}
1138+
None => panic!("Expected row value for key {:?}, not found", key),
1139+
}
1140+
}
1141+
}
1142+
other => panic!("Expected Value::KTable, got {:?}", other),
1143+
}
1144+
}
1145+
10751146
#[test]
10761147
fn test_translate_bytes_to_chars_simple() {
10771148
let text = "abc😄def";

src/ops/functions/test_utils.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
use crate::builder::plan::AnalyzedValueMapping;
22
use crate::ops::sdk::{
3-
EnrichedValueType, FlowInstanceContext, OpArgSchema, OpArgsResolver, SimpleFunctionExecutor,
4-
SimpleFunctionFactoryBase, Value,
3+
BasicValueType, EnrichedValueType, FlowInstanceContext, OpArgSchema, OpArgsResolver,
4+
SimpleFunctionExecutor, SimpleFunctionFactoryBase, Value, make_output_type,
55
};
66
use anyhow::Result;
77
use serde::de::DeserializeOwned;
88
use serde_json::Value as JsonValue;
99
use std::sync::Arc;
1010

11-
// This function provides a helper to create OpArgSchema for literal values.
12-
pub fn new_literal_op_arg_schema(
11+
fn new_literal_op_arg_schema(
1312
name: Option<&str>,
1413
value: Value,
1514
value_type: EnrichedValueType,
@@ -23,6 +22,11 @@ pub fn new_literal_op_arg_schema(
2322
}
2423
}
2524

25+
// This function provides a helper to create OpArgSchema for literal values.
26+
pub fn build_arg_schema(name: &str, value: Value, value_type: BasicValueType) -> OpArgSchema {
27+
new_literal_op_arg_schema(Some(name), value, make_output_type(value_type))
28+
}
29+
2630
// This function tests a flow function by providing a spec, input argument schemas, and values.
2731
pub async fn test_flow_function<S, R, F>(
2832
factory: Arc<F>,

0 commit comments

Comments
 (0)