Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions src/ops/functions/embed_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,74 @@ impl SimpleFunctionFactoryBase for Factory {
/// Registers this function's executor factory into the given registry,
/// making it available to flows by name.
pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
Factory.register(registry)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
    use serde_json::json;

    /// End-to-end check of the embed-text function through the shared test
    /// utility. Requires a reachable embedding backend, hence `#[ignore]`.
    #[tokio::test]
    #[ignore = "This test requires OpenAI API key or a configured local LLM and may make network calls."]
    async fn test_embed_text_with_util() {
        // Minimal flow context; no Python execution context is needed here.
        let context = Arc::new(FlowInstanceContext {
            flow_instance_name: "test_embed_text_flow".to_string(),
            auth_registry: Arc::new(AuthRegistry::default()),
            py_exec_ctx: None,
        });

        // Using OpenAI as an example.
        let spec_json = json!({
            "api_type": "OpenAi",
            "model": "text-embedding-ada-002",
        });

        let factory = Arc::new(Factory);
        let text_content = "CocoIndex is a performant data transformation framework for AI.";

        let input_args_values = vec![text_content.to_string().into()];

        // Schema mirrors the single runtime argument above.
        let input_arg_schemas = vec![build_arg_schema(
            "text",
            text_content.to_string().into(),
            BasicValueType::Str,
        )];

        let result = test_flow_function(
            factory,
            spec_json,
            input_arg_schemas,
            input_args_values,
            context,
        )
        .await;

        // Panic with the error once; a prior eprintln + assert + unwrap would
        // report the same error up to three times.
        let value = result.unwrap_or_else(|err| {
            panic!(
                "test_flow_function failed. NOTE: This test may require network access/API keys for OpenAI. Error: {:?}",
                err
            )
        });

        match value {
            Value::Basic(BasicValue::Vector(arc_vec)) => {
                // text-embedding-ada-002 emits 1536-dimensional embeddings.
                assert_eq!(arc_vec.len(), 1536, "Embedding vector dimension mismatch");
                for item in arc_vec.iter() {
                    assert!(
                        matches!(item, BasicValue::Float32(_)),
                        "Embedding vector element is not Float32: {:?}",
                        item
                    );
                }
            }
            other => panic!("Expected Value::Basic(BasicValue::Vector), got {:?}", other),
        }
    }
}
118 changes: 118 additions & 0 deletions src/ops/functions/extract_by_llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,121 @@ impl SimpleFunctionFactoryBase for Factory {
Ok(Box::new(Executor::new(spec, resolved_input_schema).await?))
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
    use serde_json::json;

    /// End-to-end check of structured extraction through the shared test
    /// utility. Requires a reachable LLM backend, hence `#[ignore]`.
    #[tokio::test]
    #[ignore = "This test requires an OpenAI API key or a configured local LLM and may make network calls."]
    async fn test_extract_by_llm_with_util() {
        let context = Arc::new(FlowInstanceContext {
            flow_instance_name: "test_extract_by_llm_flow".to_string(),
            auth_registry: Arc::new(AuthRegistry::default()),
            py_exec_ctx: None,
        });

        // Define the expected output structure.
        let target_output_schema = StructSchema {
            fields: Arc::new(vec![
                FieldSchema::new(
                    "extracted_field_name",
                    make_output_type(BasicValueType::Str),
                ),
                FieldSchema::new(
                    "extracted_field_value",
                    make_output_type(BasicValueType::Int64),
                ),
            ]),
            description: Some("A test structure for extraction".into()),
        };

        let output_type_spec = EnrichedValueType {
            typ: ValueType::Struct(target_output_schema.clone()),
            nullable: false,
            attrs: Arc::new(BTreeMap::new()),
        };

        // Spec using OpenAI as an example.
        let spec_json = json!({
            "llm_spec": {
                "api_type": "OpenAi",
                "model": "gpt-4o",
                "address": null,
                "api_key_auth": null,
                "max_tokens": 100,
                "temperature": 0.0,
                "top_p": null,
                "params": {}
            },
            "output_type": output_type_spec,
            "instruction": "Extract the name and value from the text. The name is a string, the value is an integer."
        });

        let factory = Arc::new(Factory);
        let text_content = "The item is called 'CocoIndex Test' and its value is 42.";

        let input_args_values = vec![text_content.to_string().into()];

        // Schema mirrors the single runtime argument above.
        let input_arg_schemas = vec![build_arg_schema(
            "text",
            text_content.to_string().into(),
            BasicValueType::Str,
        )];

        let result = test_flow_function(
            factory,
            spec_json,
            input_arg_schemas,
            input_args_values,
            context,
        )
        .await;

        // Panic with the error once; a prior eprintln + assert + unwrap would
        // report the same error up to three times.
        let value = result.unwrap_or_else(|err| {
            panic!(
                "test_flow_function failed. NOTE: This test may require network access/API keys for OpenAI. Error: {:?}",
                err
            )
        });

        match value {
            Value::Struct(field_values) => {
                assert_eq!(
                    field_values.fields.len(),
                    target_output_schema.fields.len(),
                    "Mismatched number of fields in output struct"
                );
                // Check each runtime field value against its declared schema type.
                for (idx, field_schema) in target_output_schema.fields.iter().enumerate() {
                    match (&field_values.fields[idx], &field_schema.value_type.typ) {
                        (
                            Value::Basic(BasicValue::Str(_)),
                            ValueType::Basic(BasicValueType::Str),
                        ) => {}
                        (
                            Value::Basic(BasicValue::Int64(_)),
                            ValueType::Basic(BasicValueType::Int64),
                        ) => {}
                        (val, expected_type) => panic!(
                            "Field '{}' type mismatch. Got {:?}, expected type compatible with {:?}",
                            field_schema.name,
                            val.kind(),
                            expected_type
                        ),
                    }
                }
            }
            other => panic!("Expected Value::Struct, got {:?}", other),
        }
    }
}
3 changes: 3 additions & 0 deletions src/ops/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@ pub mod embed_text;
pub mod extract_by_llm;
pub mod parse_json;
pub mod split_recursively;

#[cfg(test)]
mod test_utils;
60 changes: 60 additions & 0 deletions src/ops/functions/parse_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,63 @@ impl SimpleFunctionFactoryBase for Factory {
Ok(Box::new(Executor { args }))
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
    use serde_json::json;

    /// Drives the parse-json function through the shared test utility and
    /// verifies the parsed value round-trips the input document.
    #[tokio::test]
    async fn test_parse_json_util() {
        // Flow context with a default auth registry; no Python context required.
        let context = Arc::new(FlowInstanceContext {
            flow_instance_name: "test_parse_json_flow".to_string(),
            auth_registry: Arc::new(AuthRegistry::default()),
            py_exec_ctx: None,
        });

        let raw_json = r#"{"city": "Magdeburg"}"#;
        let language: Value = "json".to_string().into();

        // Argument schemas mirror the runtime argument values one-to-one.
        let input_arg_schemas = vec![
            build_arg_schema("text", raw_json.to_string().into(), BasicValueType::Str),
            build_arg_schema("language", language.clone(), BasicValueType::Str),
        ];
        let input_args_values = vec![raw_json.to_string().into(), language];

        let result = test_flow_function(
            Arc::new(Factory),
            json!({}),
            input_arg_schemas,
            input_args_values,
            context,
        )
        .await;

        assert!(
            result.is_ok(),
            "test_flow_function failed: {:?}",
            result.err()
        );

        match result.unwrap() {
            Value::Basic(BasicValue::Json(parsed)) => {
                assert_eq!(
                    *parsed,
                    json!({"city": "Magdeburg"}),
                    "Parsed JSON value mismatch with specified language"
                );
            }
            other => panic!("Expected Value::Basic(BasicValue::Json), got {:?}", other),
        }
    }
}
71 changes: 71 additions & 0 deletions src/ops/functions/split_recursively.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,8 @@ pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
#[cfg(test)]
mod tests {
use super::*;
use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
use serde_json::json;

// Helper function to assert chunk text and its consistency with the range within the original text.
fn assert_chunk_text_consistency(
Expand Down Expand Up @@ -1072,6 +1074,75 @@ mod tests {
}
}

/// Splits a small multi-line text via the shared test utility and checks that
/// every expected chunk row is present with the right byte range and text.
#[tokio::test]
async fn test_split_recursively_with_util() {
    let context = Arc::new(FlowInstanceContext {
        flow_instance_name: "test_parse_recursively_flow".to_string(),
        auth_registry: Arc::new(AuthRegistry::default()),
        py_exec_ctx: None,
    });

    let spec_json = json!({});
    let factory = Arc::new(Factory);
    let text_content = "Linea 1.\nLinea 2.\n\nLinea 3.";

    let input_args_values = vec![
        text_content.to_string().into(),
        (15i64).into(),
        (5i64).into(),
        (0i64).into(),
        Value::Null,
    ];

    // Schemas mirror the runtime argument values one-to-one.
    let input_arg_schemas = vec![
        build_arg_schema("text", text_content.to_string().into(), BasicValueType::Str),
        build_arg_schema("chunk_size", (15i64).into(), BasicValueType::Int64),
        build_arg_schema("min_chunk_size", (5i64).into(), BasicValueType::Int64),
        build_arg_schema("chunk_overlap", (0i64).into(), BasicValueType::Int64),
        build_arg_schema("language", Value::Null, BasicValueType::Str),
    ];

    let result = test_flow_function(
        factory,
        spec_json,
        input_arg_schemas,
        input_args_values,
        context,
    )
    .await;

    assert!(
        result.is_ok(),
        "test_flow_function failed: {:?}",
        result.err()
    );
    let value = result.unwrap();

    match value {
        Value::KTable(table) => {
            // (byte range within the original text, expected chunk text)
            let expected_chunks = [
                (RangeValue::new(0, 8), "Linea 1."),
                (RangeValue::new(9, 17), "Linea 2."),
                (RangeValue::new(19, 27), "Linea 3."),
            ];

            for (range, expected_text) in expected_chunks {
                let key: KeyValue = range.into();
                match table.get(&key) {
                    Some(scope_value_ref) => {
                        // unwrap_or_else builds the panic message only on failure;
                        // expect(&format!(..)) allocated it on every success too
                        // (clippy::expect_fun_call).
                        let chunk_text = scope_value_ref.0.fields[0]
                            .as_str()
                            .unwrap_or_else(|_| panic!("Chunk text not a string for key {:?}", key));
                        assert_eq!(**chunk_text, *expected_text);
                    }
                    None => panic!("Expected row value for key {:?}, not found", key),
                }
            }
        }
        other => panic!("Expected Value::KTable, got {:?}", other),
    }
}

#[test]
fn test_translate_bytes_to_chars_simple() {
let text = "abc😄def";
Expand Down
Loading