@@ -25,6 +25,8 @@ use common_openai::OpenAI;
2525use common_vector:: cosine_distance;
2626
2727pub fn register ( registry : & mut FunctionRegistry ) {
28+ // cosine_distance
29+ // This function takes two Float32 arrays as input and computes the cosine distance between them.
2830 registry. register_passthrough_nullable_2_arg :: < ArrayType < Float32Type > , ArrayType < Float32Type > , Float32Type , _ , _ > (
2931 "cosine_distance" ,
3032 |_, _| FunctionDomain :: MayThrow ,
@@ -48,6 +50,9 @@ pub fn register(registry: &mut FunctionRegistry) {
4850 ) ,
4951 ) ;
5052
53+ // embedding_vector
54+ // This function takes two strings as input, sends an API request to OpenAI, and returns the Float32 array of embeddings.
55+ // The OpenAI API key is pre-configured during the binder phase, so we rewrite this function and set the API key.
5156 registry. register_passthrough_nullable_2_arg :: < StringType , StringType , ArrayType < Float32Type > , _ , _ > (
5257 "embedding_vector" ,
5358 |_, _| FunctionDomain :: MayThrow ,
@@ -63,11 +68,39 @@ pub fn register(registry: &mut FunctionRegistry) {
6368 output. push ( result. into ( ) ) ;
6469 }
6570 Err ( e) => {
66- ctx. set_error ( output. len ( ) , format ! ( "openai request error:{:?}" , e) ) ;
71+ ctx. set_error ( output. len ( ) , format ! ( "openai embedding request error:{:?}" , e) ) ;
6772 output. push ( vec ! [ F32 :: from( 0.0 ) ] . into ( ) ) ;
6873 }
6974 }
7075 } ,
7176 ) ,
7277 ) ;
78+
79+ // text_completion
80+ // This function takes two strings as input, sends an API request to OpenAI, and returns the AI-generated completion as a string.
81+ // The OpenAI API key is pre-configured during the binder phase, so we rewrite this function and set the API key.
82+ registry. register_passthrough_nullable_2_arg :: < StringType , StringType , StringType , _ , _ > (
83+ "text_completion" ,
84+ |_, _| FunctionDomain :: MayThrow ,
85+ vectorize_with_builder_2_arg :: < StringType , StringType , StringType > (
86+ |data, api_key, output, ctx| {
87+ let data = std:: str:: from_utf8 ( data) . unwrap ( ) ;
88+ let api_key = std:: str:: from_utf8 ( api_key) . unwrap ( ) ;
89+ let openai = OpenAI :: create ( api_key. to_string ( ) , AIModel :: TextDavinci003 ) ;
90+ let result = openai. completion_request ( data. to_string ( ) ) ;
91+ match result {
92+ Ok ( ( resp, _) ) => {
93+ output. put_str ( & resp) ;
94+ }
95+ Err ( e) => {
96+ ctx. set_error (
97+ output. len ( ) ,
98+ format ! ( "openai completion request error:{:?}" , e) ,
99+ ) ;
100+ output. put_str ( "" ) ;
101+ }
102+ }
103+ } ,
104+ ) ,
105+ ) ;
73106}
0 commit comments