Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/ops/functions/embed_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,58 @@ impl SimpleFunctionFactoryBase for Factory {
pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
Factory.register(registry)
}

#[cfg(test)]
mod tests {
use super::*;
use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};

#[tokio::test]
#[ignore = "This test requires OpenAI API key or a configured local LLM and may make network calls."]
async fn test_embed_text() {
let spec = Spec {
api_type: LlmApiType::OpenAi,
model: "text-embedding-ada-002".to_string(),
address: None,
output_dimension: None,
task_type: None,
};

let factory = Arc::new(Factory);
let text_content = "CocoIndex is a performant data transformation framework for AI.";

let input_args_values = vec![text_content.to_string().into()];

let input_arg_schemas = vec![build_arg_schema("text", BasicValueType::Str)];

let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;

if result.is_err() {
eprintln!(
"test_embed_text: test_flow_function returned error (potentially expected for evaluate): {:?}",
result.as_ref().err()
);
}

assert!(
result.is_ok(),
"test_flow_function failed. NOTE: This test may require network access/API keys for OpenAI. Error: {:?}",
result.err()
);

let value = result.unwrap();

match value {
Value::Basic(BasicValue::Vector(arc_vec)) => {
assert_eq!(arc_vec.len(), 1536, "Embedding vector dimension mismatch");
for item in arc_vec.iter() {
match item {
BasicValue::Float32(_) => {}
_ => panic!("Embedding vector element is not Float32: {:?}", item),
}
}
}
_ => panic!("Expected Value::Basic(BasicValue::Vector), got {:?}", value),
}
}
}
94 changes: 94 additions & 0 deletions src/ops/functions/extract_by_llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,97 @@ impl SimpleFunctionFactoryBase for Factory {
Ok(Box::new(Executor::new(spec, resolved_input_schema).await?))
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};

#[tokio::test]
#[ignore = "This test requires an OpenAI API key or a configured local LLM and may make network calls."]
async fn test_extract_by_llm() {
// Define the expected output structure
let target_output_schema = StructSchema {
fields: Arc::new(vec![
FieldSchema::new(
"extracted_field_name",
make_output_type(BasicValueType::Str),
),
FieldSchema::new(
"extracted_field_value",
make_output_type(BasicValueType::Int64),
),
]),
description: Some("A test structure for extraction".into()),
};

let output_type_spec = EnrichedValueType {
typ: ValueType::Struct(target_output_schema.clone()),
nullable: false,
attrs: Arc::new(BTreeMap::new()),
};

let spec = Spec {
llm_spec: LlmSpec {
api_type: crate::llm::LlmApiType::OpenAi,
model: "gpt-4o".to_string(),
address: None,
},
output_type: output_type_spec,
instruction: Some("Extract the name and value from the text. The name is a string, the value is an integer.".to_string()),
};

let factory = Arc::new(Factory);
let text_content = "The item is called 'CocoIndex Test' and its value is 42.";

let input_args_values = vec![text_content.to_string().into()];

let input_arg_schemas = vec![build_arg_schema("text", BasicValueType::Str)];

let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;

if result.is_err() {
eprintln!(
"test_extract_by_llm: test_flow_function returned error (potentially expected for evaluate): {:?}",
result.as_ref().err()
);
}

assert!(
result.is_ok(),
"test_flow_function failed. NOTE: This test may require network access/API keys for OpenAI. Error: {:?}",
result.err()
);

let value = result.unwrap();

match value {
Value::Struct(field_values) => {
assert_eq!(
field_values.fields.len(),
target_output_schema.fields.len(),
"Mismatched number of fields in output struct"
);
for (idx, field_schema) in target_output_schema.fields.iter().enumerate() {
match (&field_values.fields[idx], &field_schema.value_type.typ) {
(
Value::Basic(BasicValue::Str(_)),
ValueType::Basic(BasicValueType::Str),
) => {}
(
Value::Basic(BasicValue::Int64(_)),
ValueType::Basic(BasicValueType::Int64),
) => {}
(val, expected_type) => panic!(
"Field '{}' type mismatch. Got {:?}, expected type compatible with {:?}",
field_schema.name,
val.kind(),
expected_type
),
}
}
}
_ => panic!("Expected Value::Struct, got {:?}", value),
}
}
}
3 changes: 3 additions & 0 deletions src/ops/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@ pub mod embed_text;
pub mod extract_by_llm;
pub mod parse_json;
pub mod split_recursively;

#[cfg(test)]
mod test_utils;
43 changes: 43 additions & 0 deletions src/ops/functions/parse_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,46 @@ impl SimpleFunctionFactoryBase for Factory {
Ok(Box::new(Executor { args }))
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
use serde_json::json;

#[tokio::test]
async fn test_parse_json() {
let spec = EmptySpec {};

let factory = Arc::new(Factory);
let json_string_content = r#"{"city": "Magdeburg"}"#;
let lang_value: Value = "json".to_string().into();

let input_args_values = vec![json_string_content.to_string().into(), lang_value.clone()];

let input_arg_schemas = vec![
build_arg_schema("text", BasicValueType::Str),
build_arg_schema("language", BasicValueType::Str),
];

let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;

assert!(
result.is_ok(),
"test_flow_function failed: {:?}",
result.err()
);
let value = result.unwrap();

match value {
Value::Basic(BasicValue::Json(arc_json_value)) => {
let expected_json = json!({"city": "Magdeburg"});
assert_eq!(
*arc_json_value, expected_json,
"Parsed JSON value mismatch with specified language"
);
}
_ => panic!("Expected Value::Basic(BasicValue::Json), got {:?}", value),
}
}
}
61 changes: 61 additions & 0 deletions src/ops/functions/split_recursively.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,7 @@ pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
#[cfg(test)]
mod tests {
use super::*;
use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};

// Helper function to assert chunk text and its consistency with the range within the original text.
fn assert_chunk_text_consistency(
Expand Down Expand Up @@ -1072,6 +1073,64 @@ mod tests {
}
}

#[tokio::test]
async fn test_split_recursively() {
let spec = Spec {
custom_languages: vec![],
};
let factory = Arc::new(Factory);
let text_content = "Linea 1.\nLinea 2.\n\nLinea 3.";

let input_args_values = vec![
text_content.to_string().into(),
(15i64).into(),
(5i64).into(),
(0i64).into(),
Value::Null,
];

let input_arg_schemas = vec![
build_arg_schema("text", BasicValueType::Str),
build_arg_schema("chunk_size", BasicValueType::Int64),
build_arg_schema("min_chunk_size", BasicValueType::Int64),
build_arg_schema("chunk_overlap", BasicValueType::Int64),
build_arg_schema("language", BasicValueType::Str),
];

let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;

assert!(
result.is_ok(),
"test_flow_function failed: {:?}",
result.err()
);
let value = result.unwrap();

match value {
Value::KTable(table) => {
let expected_chunks = vec![
(RangeValue::new(0, 8), "Linea 1."),
(RangeValue::new(9, 17), "Linea 2."),
(RangeValue::new(19, 27), "Linea 3."),
];

for (range, expected_text) in expected_chunks {
let key: KeyValue = range.into();
match table.get(&key) {
Some(scope_value_ref) => {
let chunk_text = scope_value_ref.0.fields[0]
.as_str()
.expect(&format!("Chunk text not a string for key {:?}", key));
assert_eq!(**chunk_text, *expected_text);
}
None => panic!("Expected row value for key {:?}, not found", key),
}
}
}
other => panic!("Expected Value::KTable, got {:?}", other),
}
}

#[test]
fn test_translate_bytes_to_chars_simple() {
let text = "abc😄def";
Expand Down Expand Up @@ -1179,6 +1238,7 @@ mod tests {
assert_chunk_text_consistency(text2, &chunks2[0], "A very very long", "Test 2, Chunk 0");
assert!(chunks2[0].text.len() <= 20);
}

#[test]
fn test_basic_split_with_overlap() {
let text = "This is a test text that is a bit longer to see how the overlap works.";
Expand All @@ -1198,6 +1258,7 @@ mod tests {
assert!(chunks[0].text.len() <= 25);
}
}

#[test]
fn test_split_trims_whitespace() {
let text = " \n First chunk. \n\n Second chunk with spaces at the end. \n";
Expand Down
73 changes: 73 additions & 0 deletions src/ops/functions/test_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use crate::builder::plan::{
AnalyzedFieldReference, AnalyzedLocalFieldReference, AnalyzedValueMapping,
};
use crate::ops::sdk::{
AuthRegistry, BasicValueType, EnrichedValueType, FlowInstanceContext, OpArgSchema,
OpArgsResolver, SimpleFunctionExecutor, SimpleFunctionFactoryBase, Value, make_output_type,
};
use anyhow::Result;
use serde::de::DeserializeOwned;
use std::sync::Arc;

// This function builds an argument schema for a flow function.
pub fn build_arg_schema(
name: &str,
value_type: BasicValueType,
) -> (Option<&str>, EnrichedValueType) {
(Some(name), make_output_type(value_type))
}

// This function tests a flow function by providing a spec, input argument schemas, and values.
pub async fn test_flow_function<S, R, F>(
factory: Arc<F>,
spec: S,
input_arg_schemas: Vec<(Option<&str>, EnrichedValueType)>,
input_arg_values: Vec<Value>,
) -> Result<Value>
where
S: DeserializeOwned + Send + Sync + 'static,
R: Send + Sync + 'static,
F: SimpleFunctionFactoryBase<Spec = S, ResolvedArgs = R> + ?Sized,
{
// 1. Construct OpArgSchema
let op_arg_schemas: Vec<OpArgSchema> = input_arg_schemas
.into_iter()
.enumerate()
.map(|(idx, (name, value_type))| OpArgSchema {
name: name.map_or(crate::base::spec::OpArgName(None), |n| {
crate::base::spec::OpArgName(Some(n.to_string()))
}),
value_type,
analyzed_value: AnalyzedValueMapping::Field(AnalyzedFieldReference {
local: AnalyzedLocalFieldReference {
fields_idx: vec![idx as u32],
},
scope_up_level: 0,
}),
})
.collect();

// 2. Resolve Schema & Args
let mut args_resolver = OpArgsResolver::new(&op_arg_schemas)?;
let context = Arc::new(FlowInstanceContext {
flow_instance_name: "test_flow_function".to_string(),
auth_registry: Arc::new(AuthRegistry::default()),
py_exec_ctx: None,
});

let (resolved_args_from_schema, _output_schema): (R, EnrichedValueType) = factory
.resolve_schema(&spec, &mut args_resolver, &context)
.await?;

args_resolver.done()?;

// 3. Build Executor
let executor: Box<dyn SimpleFunctionExecutor> = factory
.build_executor(spec, resolved_args_from_schema, Arc::clone(&context))
.await?;

// 4. Evaluate
let result = executor.evaluate(input_arg_values).await?;

Ok(result)
}