@@ -155,3 +155,121 @@ impl SimpleFunctionFactoryBase for Factory {
155155 Ok ( Box :: new ( Executor :: new ( spec, resolved_input_schema) . await ?) )
156156 }
157157}
158+
159+ #[ cfg( test) ]
160+ mod tests {
161+ use super :: * ;
162+ use crate :: ops:: functions:: test_utils:: { build_arg_schema, test_flow_function} ;
163+ use serde_json:: json;
164+
165+ #[ tokio:: test]
166+ #[ ignore = "This test requires an OpenAI API key or a configured local LLM and may make network calls." ]
167+ async fn test_extract_by_llm_with_util ( ) {
168+ let context = Arc :: new ( FlowInstanceContext {
169+ flow_instance_name : "test_extract_by_llm_flow" . to_string ( ) ,
170+ auth_registry : Arc :: new ( AuthRegistry :: default ( ) ) ,
171+ py_exec_ctx : None ,
172+ } ) ;
173+
174+ // Define the expected output structure
175+ let target_output_schema = StructSchema {
176+ fields : Arc :: new ( vec ! [
177+ FieldSchema :: new(
178+ "extracted_field_name" ,
179+ make_output_type( BasicValueType :: Str ) ,
180+ ) ,
181+ FieldSchema :: new(
182+ "extracted_field_value" ,
183+ make_output_type( BasicValueType :: Int64 ) ,
184+ ) ,
185+ ] ) ,
186+ description : Some ( "A test structure for extraction" . into ( ) ) ,
187+ } ;
188+
189+ let output_type_spec = EnrichedValueType {
190+ typ : ValueType :: Struct ( target_output_schema. clone ( ) ) ,
191+ nullable : false ,
192+ attrs : Arc :: new ( BTreeMap :: new ( ) ) ,
193+ } ;
194+
195+ // Spec using OpenAI as an example.
196+ let spec_json = json ! ( {
197+ "llm_spec" : {
198+ "api_type" : "OpenAi" ,
199+ "model" : "gpt-4o" ,
200+ "address" : null,
201+ "api_key_auth" : null,
202+ "max_tokens" : 100 ,
203+ "temperature" : 0.0 ,
204+ "top_p" : null,
205+ "params" : { }
206+ } ,
207+ "output_type" : output_type_spec,
208+ "instruction" : "Extract the name and value from the text. The name is a string, the value is an integer."
209+ } ) ;
210+
211+ let factory = Arc :: new ( Factory ) ;
212+ let text_content = "The item is called 'CocoIndex Test' and its value is 42." ;
213+
214+ let input_args_values = vec ! [ text_content. to_string( ) . into( ) ] ;
215+
216+ let input_arg_schemas = vec ! [ build_arg_schema(
217+ "text" ,
218+ text_content. to_string( ) . into( ) ,
219+ BasicValueType :: Str ,
220+ ) ] ;
221+
222+ let result = test_flow_function (
223+ factory,
224+ spec_json,
225+ input_arg_schemas,
226+ input_args_values,
227+ context,
228+ )
229+ . await ;
230+
231+ if result. is_err ( ) {
232+ eprintln ! (
233+ "test_extract_by_llm_with_util: test_flow_function returned error (potentially expected for evaluate): {:?}" ,
234+ result. as_ref( ) . err( )
235+ ) ;
236+ }
237+
238+ assert ! (
239+ result. is_ok( ) ,
240+ "test_flow_function failed. NOTE: This test may require network access/API keys for OpenAI. Error: {:?}" ,
241+ result. err( )
242+ ) ;
243+
244+ let value = result. unwrap ( ) ;
245+
246+ match value {
247+ Value :: Struct ( field_values) => {
248+ assert_eq ! (
249+ field_values. fields. len( ) ,
250+ target_output_schema. fields. len( ) ,
251+ "Mismatched number of fields in output struct"
252+ ) ;
253+ for ( idx, field_schema) in target_output_schema. fields . iter ( ) . enumerate ( ) {
254+ match ( & field_values. fields [ idx] , & field_schema. value_type . typ ) {
255+ (
256+ Value :: Basic ( BasicValue :: Str ( _) ) ,
257+ ValueType :: Basic ( BasicValueType :: Str ) ,
258+ ) => { }
259+ (
260+ Value :: Basic ( BasicValue :: Int64 ( _) ) ,
261+ ValueType :: Basic ( BasicValueType :: Int64 ) ,
262+ ) => { }
263+ ( val, expected_type) => panic ! (
264+ "Field '{}' type mismatch. Got {:?}, expected type compatible with {:?}" ,
265+ field_schema. name,
266+ val. kind( ) ,
267+ expected_type
268+ ) ,
269+ }
270+ }
271+ }
272+ _ => panic ! ( "Expected Value::Struct, got {:?}" , value) ,
273+ }
274+ }
275+ }
0 commit comments