@@ -18,6 +18,7 @@ use common_expression::types::Float32Type;
1818use common_expression:: types:: StringType ;
1919use common_expression:: types:: F32 ;
2020use common_expression:: vectorize_with_builder_2_arg;
21+ use common_expression:: vectorize_with_builder_5_arg;
2122use common_expression:: FunctionDomain ;
2223use common_expression:: FunctionRegistry ;
2324use common_openai:: OpenAI ;
@@ -28,7 +29,7 @@ pub fn register(registry: &mut FunctionRegistry) {
2829 // This function takes two Float32 arrays as input and computes the cosine distance between them.
2930 registry. register_passthrough_nullable_2_arg :: < ArrayType < Float32Type > , ArrayType < Float32Type > , Float32Type , _ , _ > (
3031 "cosine_distance" ,
31- |_, _| FunctionDomain :: MayThrow ,
32+ |_, _| FunctionDomain :: MayThrow ,
3233 vectorize_with_builder_2_arg :: < ArrayType < Float32Type > , ArrayType < Float32Type > , Float32Type > (
3334 |lhs, rhs, output, ctx| {
3435 let l_f32=
@@ -52,14 +53,18 @@ pub fn register(registry: &mut FunctionRegistry) {
5253 // embedding_vector
5354 // This function takes two strings as input, sends an API request to OpenAI, and returns the Float32 array of embeddings.
5455 // The OpenAI API key is pre-configured during the binder phase, so we rewrite this function and set the API key.
55- registry. register_passthrough_nullable_2_arg :: < StringType , StringType , ArrayType < Float32Type > , _ , _ > (
56+ registry. register_passthrough_nullable_5_arg :: < StringType , StringType , StringType , StringType , StringType , ArrayType < Float32Type > , _ , _ > (
5657 "embedding_vector" ,
57- |_, _| FunctionDomain :: MayThrow ,
58- vectorize_with_builder_2_arg :: < StringType , StringType , ArrayType < Float32Type > > (
59- |data, api_key, output, ctx| {
58+ |_, _, _ , _ , _ | FunctionDomain :: MayThrow ,
59+ vectorize_with_builder_5_arg :: < StringType , StringType , StringType , StringType , StringType , ArrayType < Float32Type > > (
60+ |data, api_base , api_key, embedding_model , completion_model , output, ctx| {
6061 let data = std:: str:: from_utf8 ( data) . unwrap ( ) ;
61- let api_key = std:: str:: from_utf8 ( api_key) . unwrap ( ) ;
62- let openai = OpenAI :: create ( api_key. to_string ( ) ) ;
62+
63+ let api_base = std:: str:: from_utf8 ( api_base) . unwrap ( ) . to_string ( ) ;
64+ let api_key = std:: str:: from_utf8 ( api_key) . unwrap ( ) . to_string ( ) ;
65+ let embedding_model = std:: str:: from_utf8 ( embedding_model) . unwrap ( ) . to_string ( ) ;
66+ let completion_model= std:: str:: from_utf8 ( completion_model) . unwrap ( ) . to_string ( ) ;
67+ let openai = OpenAI :: create ( api_base, api_key, embedding_model, completion_model) ;
6368 let result = openai. embedding_request ( & [ data. to_string ( ) ] ) ;
6469 match result {
6570 Ok ( ( embeddings, _) ) => {
@@ -78,14 +83,18 @@ pub fn register(registry: &mut FunctionRegistry) {
7883 // text_completion
7984 // This function takes two strings as input, sends an API request to OpenAI, and returns the AI-generated completion as a string.
8085 // The OpenAI API key is pre-configured during the binder phase, so we rewrite this function and set the API key.
81- registry. register_passthrough_nullable_2_arg :: < StringType , StringType , StringType , _ , _ > (
86+ registry. register_passthrough_nullable_5_arg :: < StringType , StringType , StringType , StringType , StringType , StringType , _ , _ > (
8287 "text_completion" ,
83- |_, _| FunctionDomain :: MayThrow ,
84- vectorize_with_builder_2_arg :: < StringType , StringType , StringType > (
85- |data, api_key, output, ctx| {
88+ |_, _, _ , _ , _ | FunctionDomain :: MayThrow ,
89+ vectorize_with_builder_5_arg :: < StringType , StringType , StringType , StringType , StringType , StringType > (
90+ |data, api_base , api_key, embedding_model , completion_model , output, ctx| {
8691 let data = std:: str:: from_utf8 ( data) . unwrap ( ) ;
87- let api_key = std:: str:: from_utf8 ( api_key) . unwrap ( ) ;
88- let openai = OpenAI :: create ( api_key. to_string ( ) ) ;
92+
93+ let api_base = std:: str:: from_utf8 ( api_base) . unwrap ( ) . to_string ( ) ;
94+ let api_key = std:: str:: from_utf8 ( api_key) . unwrap ( ) . to_string ( ) ;
95+ let embedding_model = std:: str:: from_utf8 ( embedding_model) . unwrap ( ) . to_string ( ) ;
96+ let completion_model= std:: str:: from_utf8 ( completion_model) . unwrap ( ) . to_string ( ) ;
97+ let openai = OpenAI :: create ( api_base, api_key, embedding_model, completion_model) ;
8998 let result = openai. completion_text_request ( data. to_string ( ) ) ;
9099 match result {
91100 Ok ( ( resp, _) ) => {
0 commit comments