Skip to content

Commit d347648

Browse files
committed
refactor(tests): simplify test util using structured Spec instead of JSON
1 parent 3b302e3 commit d347648

File tree

5 files changed

+37
-89
lines changed

5 files changed

+37
-89
lines changed

src/ops/functions/embed_text.rs

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -100,22 +100,18 @@ pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
100100
mod tests {
101101
use super::*;
102102
use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
103-
use serde_json::json;
104103

105104
#[tokio::test]
106105
#[ignore = "This test requires OpenAI API key or a configured local LLM and may make network calls."]
107-
async fn test_embed_text_with_util() {
108-
let context = Arc::new(FlowInstanceContext {
109-
flow_instance_name: "test_embed_text_flow".to_string(),
110-
auth_registry: Arc::new(AuthRegistry::default()),
111-
py_exec_ctx: None,
112-
});
113-
106+
async fn test_embed_text() {
114107
// Using OpenAI as an example.
115-
let spec_json = json!({
116-
"api_type": "OpenAi",
117-
"model": "text-embedding-ada-002",
118-
});
108+
let spec = Spec {
109+
api_type: LlmApiType::OpenAi,
110+
model: "text-embedding-ada-002".to_string(),
111+
address: None,
112+
output_dimension: None,
113+
task_type: None,
114+
};
119115

120116
let factory = Arc::new(Factory);
121117
let text_content = "CocoIndex is a performant data transformation framework for AI.";
@@ -128,14 +124,7 @@ mod tests {
128124
BasicValueType::Str,
129125
)];
130126

131-
let result = test_flow_function(
132-
factory,
133-
spec_json,
134-
input_arg_schemas,
135-
input_args_values,
136-
context,
137-
)
138-
.await;
127+
let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
139128

140129
if result.is_err() {
141130
eprintln!(

src/ops/functions/extract_by_llm.rs

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -160,17 +160,10 @@ impl SimpleFunctionFactoryBase for Factory {
160160
mod tests {
161161
use super::*;
162162
use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
163-
use serde_json::json;
164163

165164
#[tokio::test]
166165
#[ignore = "This test requires an OpenAI API key or a configured local LLM and may make network calls."]
167-
async fn test_extract_by_llm_with_util() {
168-
let context = Arc::new(FlowInstanceContext {
169-
flow_instance_name: "test_extract_by_llm_flow".to_string(),
170-
auth_registry: Arc::new(AuthRegistry::default()),
171-
py_exec_ctx: None,
172-
});
173-
166+
async fn test_extract_by_llm() {
174167
// Define the expected output structure
175168
let target_output_schema = StructSchema {
176169
fields: Arc::new(vec![
@@ -193,20 +186,15 @@ mod tests {
193186
};
194187

195188
// Spec using OpenAI as an example.
196-
let spec_json = json!({
197-
"llm_spec": {
198-
"api_type": "OpenAi",
199-
"model": "gpt-4o",
200-
"address": null,
201-
"api_key_auth": null,
202-
"max_tokens": 100,
203-
"temperature": 0.0,
204-
"top_p": null,
205-
"params": {}
189+
let spec = Spec {
190+
llm_spec: LlmSpec {
191+
api_type: crate::llm::LlmApiType::OpenAi,
192+
model: "gpt-4o".to_string(),
193+
address: None,
206194
},
207-
"output_type": output_type_spec,
208-
"instruction": "Extract the name and value from the text. The name is a string, the value is an integer."
209-
});
195+
output_type: output_type_spec,
196+
instruction: Some("Extract the name and value from the text. The name is a string, the value is an integer.".to_string()),
197+
};
210198

211199
let factory = Arc::new(Factory);
212200
let text_content = "The item is called 'CocoIndex Test' and its value is 42.";
@@ -219,14 +207,7 @@ mod tests {
219207
BasicValueType::Str,
220208
)];
221209

222-
let result = test_flow_function(
223-
factory,
224-
spec_json,
225-
input_arg_schemas,
226-
input_args_values,
227-
context,
228-
)
229-
.await;
210+
let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
230211

231212
if result.is_err() {
232213
eprintln!(

src/ops/functions/parse_json.rs

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,8 @@ mod tests {
110110
use serde_json::json;
111111

112112
#[tokio::test]
113-
async fn test_parse_json_util() {
114-
let context = Arc::new(FlowInstanceContext {
115-
flow_instance_name: "test_parse_json_flow".to_string(),
116-
auth_registry: Arc::new(AuthRegistry::default()),
117-
py_exec_ctx: None,
118-
});
119-
120-
let spec_json = json!({});
113+
async fn test_parse_json() {
114+
let spec = EmptySpec {};
121115

122116
let factory = Arc::new(Factory);
123117
let json_string_content = r#"{"city": "Magdeburg"}"#;
@@ -134,14 +128,7 @@ mod tests {
134128
build_arg_schema("language", lang_value, BasicValueType::Str),
135129
];
136130

137-
let result = test_flow_function(
138-
factory,
139-
spec_json,
140-
input_arg_schemas,
141-
input_args_values,
142-
context,
143-
)
144-
.await;
131+
let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
145132

146133
assert!(
147134
result.is_ok(),

src/ops/functions/split_recursively.rs

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,7 +1032,6 @@ pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
10321032
mod tests {
10331033
use super::*;
10341034
use crate::ops::functions::test_utils::{build_arg_schema, test_flow_function};
1035-
use serde_json::json;
10361035

10371036
// Helper function to assert chunk text and its consistency with the range within the original text.
10381037
fn assert_chunk_text_consistency(
@@ -1075,14 +1074,10 @@ mod tests {
10751074
}
10761075

10771076
#[tokio::test]
1078-
async fn test_split_recursively_with_util() {
1079-
let context = Arc::new(FlowInstanceContext {
1080-
flow_instance_name: "test_parse_recursively_flow".to_string(),
1081-
auth_registry: Arc::new(AuthRegistry::default()),
1082-
py_exec_ctx: None,
1083-
});
1084-
1085-
let spec_json = json!({});
1077+
async fn test_split_recursively() {
1078+
let spec = Spec {
1079+
custom_languages: vec![],
1080+
};
10861081
let factory = Arc::new(Factory);
10871082
let text_content = "Linea 1.\nLinea 2.\n\nLinea 3.";
10881083

@@ -1102,14 +1097,7 @@ mod tests {
11021097
build_arg_schema("language", Value::Null, BasicValueType::Str),
11031098
];
11041099

1105-
let result = test_flow_function(
1106-
factory,
1107-
spec_json,
1108-
input_arg_schemas,
1109-
input_args_values,
1110-
context,
1111-
)
1112-
.await;
1100+
let result = test_flow_function(factory, spec, input_arg_schemas, input_args_values).await;
11131101

11141102
assert!(
11151103
result.is_ok(),
@@ -1250,6 +1238,7 @@ mod tests {
12501238
assert_chunk_text_consistency(text2, &chunks2[0], "A very very long", "Test 2, Chunk 0");
12511239
assert!(chunks2[0].text.len() <= 20);
12521240
}
1241+
12531242
#[test]
12541243
fn test_basic_split_with_overlap() {
12551244
let text = "This is a test text that is a bit longer to see how the overlap works.";
@@ -1269,6 +1258,7 @@ mod tests {
12691258
assert!(chunks[0].text.len() <= 25);
12701259
}
12711260
}
1261+
12721262
#[test]
12731263
fn test_split_trims_whitespace() {
12741264
let text = " \n First chunk. \n\n Second chunk with spaces at the end. \n";

src/ops/functions/test_utils.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
use crate::builder::plan::AnalyzedValueMapping;
22
use crate::ops::sdk::{
3-
BasicValueType, EnrichedValueType, FlowInstanceContext, OpArgSchema, OpArgsResolver,
4-
SimpleFunctionExecutor, SimpleFunctionFactoryBase, Value, make_output_type,
3+
AuthRegistry, BasicValueType, EnrichedValueType, FlowInstanceContext, OpArgSchema,
4+
OpArgsResolver, SimpleFunctionExecutor, SimpleFunctionFactoryBase, Value, make_output_type,
55
};
66
use anyhow::Result;
77
use serde::de::DeserializeOwned;
8-
use serde_json::Value as JsonValue;
98
use std::sync::Arc;
109

1110
fn new_literal_op_arg_schema(
@@ -30,18 +29,20 @@ pub fn build_arg_schema(name: &str, value: Value, value_type: BasicValueType) ->
3029
// This function tests a flow function by providing a spec, input argument schemas, and values.
3130
pub async fn test_flow_function<S, R, F>(
3231
factory: Arc<F>,
33-
spec_json: JsonValue,
32+
spec: S,
3433
input_arg_schemas: Vec<OpArgSchema>,
3534
input_arg_values: Vec<Value>,
36-
context: Arc<FlowInstanceContext>,
3735
) -> Result<Value>
3836
where
3937
S: DeserializeOwned + Send + Sync + 'static,
4038
R: Send + Sync + 'static,
4139
F: SimpleFunctionFactoryBase<Spec = S, ResolvedArgs = R> + ?Sized,
4240
{
43-
// 1. Deserialize Spec
44-
let spec: S = serde_json::from_value(spec_json)?;
41+
let context = Arc::new(FlowInstanceContext {
42+
flow_instance_name: "test_flow_function".to_string(),
43+
auth_registry: Arc::new(AuthRegistry::default()),
44+
py_exec_ctx: None,
45+
});
4546

4647
// 2. Resolve Schema & Args
4748
// The caller of test_flow_function will be responsible for creating these schemas.

0 commit comments

Comments
 (0)