diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs
index 9e8a49a89..c60b0f2af 100644
--- a/src/ops/functions/embed_text.rs
+++ b/src/ops/functions/embed_text.rs
@@ -95,3 +95,58 @@ impl SimpleFunctionFactoryBase for Factory {
 pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
     Factory.register(registry)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
+
+    #[tokio::test]
+    #[ignore = "This test requires OpenAI API key or a configured local LLM and may make network calls."]
+    async fn test_embed_text() {
+        let spec = Spec {
+            api_type: LlmApiType::OpenAi,
+            model: "text-embedding-ada-002".to_string(),
+            address: None,
+            output_dimension: None,
+            task_type: None,
+        };
+
+        let factory = Arc::new(Factory);
+        let text_content = "CocoIndex is a performant data transformation framework for AI.";
+
+        let input_args_values = vec![text_content.to_string().into()];
+
+        let input_arg_schemas = vec![build_arg_schema("text", BasicValueType::Str)];
+
+        let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
+
+        if result.is_err() {
+            eprintln!(
+                "test_embed_text: test_flow_function returned error (potentially expected for evaluate): {:?}",
+                result.as_ref().err()
+            );
+        }
+
+        assert!(
+            result.is_ok(),
+            "test_flow_function failed. NOTE: This test may require network access/API keys for OpenAI. Error: {:?}",
+            result.err()
+        );
+
+        let value = result.unwrap();
+
+        match value {
+            Value::Basic(BasicValue::Vector(arc_vec)) => {
+                assert_eq!(arc_vec.len(), 1536, "Embedding vector dimension mismatch");
+                for item in arc_vec.iter() {
+                    match item {
+                        BasicValue::Float32(_) => {}
+                        _ => panic!("Embedding vector element is not Float32: {:?}", item),
+                    }
+                }
+            }
+            _ => panic!("Expected Value::Basic(BasicValue::Vector), got {:?}", value),
+        }
+    }
+}
diff --git a/src/ops/functions/extract_by_llm.rs b/src/ops/functions/extract_by_llm.rs
index 5749ba438..5a399f946 100644
--- a/src/ops/functions/extract_by_llm.rs
+++ b/src/ops/functions/extract_by_llm.rs
@@ -155,3 +155,97 @@ impl SimpleFunctionFactoryBase for Factory {
         Ok(Box::new(Executor::new(spec, resolved_input_schema).await?))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
+
+    #[tokio::test]
+    #[ignore = "This test requires an OpenAI API key or a configured local LLM and may make network calls."]
+    async fn test_extract_by_llm() {
+        // Define the expected output structure
+        let target_output_schema = StructSchema {
+            fields: Arc::new(vec![
+                FieldSchema::new(
+                    "extracted_field_name",
+                    make_output_type(BasicValueType::Str),
+                ),
+                FieldSchema::new(
+                    "extracted_field_value",
+                    make_output_type(BasicValueType::Int64),
+                ),
+            ]),
+            description: Some("A test structure for extraction".into()),
+        };
+
+        let output_type_spec = EnrichedValueType {
+            typ: ValueType::Struct(target_output_schema.clone()),
+            nullable: false,
+            attrs: Arc::new(BTreeMap::new()),
+        };
+
+        let spec = Spec {
+            llm_spec: LlmSpec {
+                api_type: crate::llm::LlmApiType::OpenAi,
+                model: "gpt-4o".to_string(),
+                address: None,
+            },
+            output_type: output_type_spec,
+            instruction: Some("Extract the name and value from the text. The name is a string, the value is an integer.".to_string()),
+        };
+
+        let factory = Arc::new(Factory);
+        let text_content = "The item is called 'CocoIndex Test' and its value is 42.";
+
+        let input_args_values = vec![text_content.to_string().into()];
+
+        let input_arg_schemas = vec![build_arg_schema("text", BasicValueType::Str)];
+
+        let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
+
+        if result.is_err() {
+            eprintln!(
+                "test_extract_by_llm: test_flow_function returned error (potentially expected for evaluate): {:?}",
+                result.as_ref().err()
+            );
+        }
+
+        assert!(
+            result.is_ok(),
+            "test_flow_function failed. NOTE: This test may require network access/API keys for OpenAI. Error: {:?}",
+            result.err()
+        );
+
+        let value = result.unwrap();
+
+        match value {
+            Value::Struct(field_values) => {
+                assert_eq!(
+                    field_values.fields.len(),
+                    target_output_schema.fields.len(),
+                    "Mismatched number of fields in output struct"
+                );
+                for (idx, field_schema) in target_output_schema.fields.iter().enumerate() {
+                    match (&field_values.fields[idx], &field_schema.value_type.typ) {
+                        (
+                            Value::Basic(BasicValue::Str(_)),
+                            ValueType::Basic(BasicValueType::Str),
+                        ) => {}
+                        (
+                            Value::Basic(BasicValue::Int64(_)),
+                            ValueType::Basic(BasicValueType::Int64),
+                        ) => {}
+                        (val, expected_type) => panic!(
+                            "Field '{}' type mismatch. Got {:?}, expected type compatible with {:?}",
+                            field_schema.name,
+                            val.kind(),
+                            expected_type
+                        ),
+                    }
+                }
+            }
+            _ => panic!("Expected Value::Struct, got {:?}", value),
+        }
+    }
+}
diff --git a/src/ops/functions/mod.rs b/src/ops/functions/mod.rs
index 7f08d308d..0e135e9f6 100644
--- a/src/ops/functions/mod.rs
+++ b/src/ops/functions/mod.rs
@@ -2,3 +2,6 @@ pub mod embed_text;
 pub mod extract_by_llm;
 pub mod parse_json;
 pub mod split_recursively;
+
+#[cfg(test)]
+mod test_utils;
diff --git a/src/ops/functions/parse_json.rs b/src/ops/functions/parse_json.rs
index 2410a1078..946f64764 100644
--- a/src/ops/functions/parse_json.rs
+++ b/src/ops/functions/parse_json.rs
@@ -102,3 +102,46 @@ impl SimpleFunctionFactoryBase for Factory {
         Ok(Box::new(Executor { args }))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
+    use serde_json::json;
+
+    #[tokio::test]
+    async fn test_parse_json() {
+        let spec = EmptySpec {};
+
+        let factory = Arc::new(Factory);
+        let json_string_content = r#"{"city": "Magdeburg"}"#;
+        let lang_value: Value = "json".to_string().into();
+
+        let input_args_values = vec![json_string_content.to_string().into(), lang_value.clone()];
+
+        let input_arg_schemas = vec![
+            build_arg_schema("text", BasicValueType::Str),
+            build_arg_schema("language", BasicValueType::Str),
+        ];
+
+        let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
+
+        assert!(
+            result.is_ok(),
+            "test_flow_function failed: {:?}",
+            result.err()
+        );
+        let value = result.unwrap();
+
+        match value {
+            Value::Basic(BasicValue::Json(arc_json_value)) => {
+                let expected_json = json!({"city": "Magdeburg"});
+                assert_eq!(
+                    *arc_json_value, expected_json,
+                    "Parsed JSON value mismatch with specified language"
+                );
+            }
+            _ => panic!("Expected Value::Basic(BasicValue::Json), got {:?}", value),
+        }
+    }
+}
diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs
index 21060563b..af62b7196 100644
--- a/src/ops/functions/split_recursively.rs
+++ b/src/ops/functions/split_recursively.rs
@@ -1031,6 +1031,7 @@ pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
 
     // Helper function to assert chunk text and its consistency with the range within the original text.
     fn assert_chunk_text_consistency(
@@ -1072,6 +1073,64 @@ mod tests {
         }
     }
 
+    #[tokio::test]
+    async fn test_split_recursively() {
+        let spec = Spec {
+            custom_languages: vec![],
+        };
+        let factory = Arc::new(Factory);
+        let text_content = "Linea 1.\nLinea 2.\n\nLinea 3.";
+
+        let input_args_values = vec![
+            text_content.to_string().into(),
+            (15i64).into(),
+            (5i64).into(),
+            (0i64).into(),
+            Value::Null,
+        ];
+
+        let input_arg_schemas = vec![
+            build_arg_schema("text", BasicValueType::Str),
+            build_arg_schema("chunk_size", BasicValueType::Int64),
+            build_arg_schema("min_chunk_size", BasicValueType::Int64),
+            build_arg_schema("chunk_overlap", BasicValueType::Int64),
+            build_arg_schema("language", BasicValueType::Str),
+        ];
+
+        let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
+
+        assert!(
+            result.is_ok(),
+            "test_flow_function failed: {:?}",
+            result.err()
+        );
+        let value = result.unwrap();
+
+        match value {
+            Value::KTable(table) => {
+                let expected_chunks = vec![
+                    (RangeValue::new(0, 8), "Linea 1."),
+                    (RangeValue::new(9, 17), "Linea 2."),
+                    (RangeValue::new(19, 27), "Linea 3."),
+                ];
+
+                for (range, expected_text) in expected_chunks {
+                    let key: KeyValue = range.into();
+                    match table.get(&key) {
+                        Some(scope_value_ref) => {
+                            let chunk_text = scope_value_ref.0.fields[0]
+                                .as_str()
+                                .expect(&format!("Chunk text not a string for key {:?}", key));
+                            assert_eq!(**chunk_text, *expected_text);
+                        }
+                        None => panic!("Expected row value for key {:?}, not found", key),
+                    }
+                }
+            }
+            other => panic!("Expected Value::KTable, got {:?}", other),
+        }
+    }
+
     #[test]
     fn test_translate_bytes_to_chars_simple() {
         let text = "abc😄def";
@@ -1179,6 +1238,7 @@ mod tests {
         assert_chunk_text_consistency(text2, &chunks2[0], "A very very long", "Test 2, Chunk 0");
         assert!(chunks2[0].text.len() <= 20);
     }
+
     #[test]
     fn test_basic_split_with_overlap() {
         let text = "This is a test text that is a bit longer to see how the overlap works.";
@@ -1198,6 +1258,7 @@ mod tests {
             assert!(chunks[0].text.len() <= 25);
         }
     }
+
     #[test]
     fn test_split_trims_whitespace() {
         let text = " \n First chunk. \n\n Second chunk with spaces at the end. \n";
diff --git a/src/ops/functions/test_utils.rs b/src/ops/functions/test_utils.rs
new file mode 100644
index 000000000..41801f1a2
--- /dev/null
+++ b/src/ops/functions/test_utils.rs
@@ -0,0 +1,73 @@
+use crate::builder::plan::{
+    AnalyzedFieldReference, AnalyzedLocalFieldReference, AnalyzedValueMapping,
+};
+use crate::ops::sdk::{
+    AuthRegistry, BasicValueType, EnrichedValueType, FlowInstanceContext, OpArgSchema,
+    OpArgsResolver, SimpleFunctionExecutor, SimpleFunctionFactoryBase, Value, make_output_type,
+};
+use anyhow::Result;
+use serde::de::DeserializeOwned;
+use std::sync::Arc;
+
+// This function builds an argument schema for a flow function.
+pub fn build_arg_schema(
+    name: &str,
+    value_type: BasicValueType,
+) -> (Option<&str>, EnrichedValueType) {
+    (Some(name), make_output_type(value_type))
+}
+
+// This function tests a flow function by providing a spec, input argument schemas, and values.
+pub async fn test_flow_function<S, R, F>(
+    factory: Arc<F>,
+    spec: S,
+    input_arg_schemas: Vec<(Option<&str>, EnrichedValueType)>,
+    input_arg_values: Vec<Value>,
+) -> Result<Value>
+where
+    S: DeserializeOwned + Send + Sync + 'static,
+    R: Send + Sync + 'static,
+    F: SimpleFunctionFactoryBase<Spec = S, ResolvedArgs = R> + ?Sized,
+{
+    // 1. Construct OpArgSchema
+    let op_arg_schemas: Vec<OpArgSchema> = input_arg_schemas
+        .into_iter()
+        .enumerate()
+        .map(|(idx, (name, value_type))| OpArgSchema {
+            name: name.map_or(crate::base::spec::OpArgName(None), |n| {
+                crate::base::spec::OpArgName(Some(n.to_string()))
+            }),
+            value_type,
+            analyzed_value: AnalyzedValueMapping::Field(AnalyzedFieldReference {
+                local: AnalyzedLocalFieldReference {
+                    fields_idx: vec![idx as u32],
+                },
+                scope_up_level: 0,
+            }),
+        })
+        .collect();
+
+    // 2. Resolve Schema & Args
+    let mut args_resolver = OpArgsResolver::new(&op_arg_schemas)?;
+    let context = Arc::new(FlowInstanceContext {
+        flow_instance_name: "test_flow_function".to_string(),
+        auth_registry: Arc::new(AuthRegistry::default()),
+        py_exec_ctx: None,
+    });
+
+    let (resolved_args_from_schema, _output_schema): (R, EnrichedValueType) = factory
+        .resolve_schema(&spec, &mut args_resolver, &context)
+        .await?;
+
+    args_resolver.done()?;
+
+    // 3. Build Executor
+    let executor: Box<dyn SimpleFunctionExecutor> = factory
+        .build_executor(spec, resolved_args_from_schema, Arc::clone(&context))
+        .await?;
+
+    // 4. Evaluate
+    let result = executor.evaluate(input_arg_values).await?;
+
+    Ok(result)
+}