Skip to content

Commit 0130fd1

Browse files
authored
feat: add openai api_base_url, completion_model and embedding_model to query config (#10993)
* feat: add openai api_base_url, completion_model and embedding_model to query config * fix the default value * fix unit test * fix unit test
1 parent a0b4586 commit 0130fd1

File tree

13 files changed

+204
-143
lines changed

13 files changed

+204
-143
lines changed
Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,28 @@ use openai_api_rust::Role;
2323

2424
use crate::metrics::metrics_completion_count;
2525
use crate::metrics::metrics_completion_token;
26-
use crate::AIModel;
2726
use crate::OpenAI;
2827

28+
#[derive(Debug)]
29+
enum CompletionMode {
30+
Sql,
31+
Text,
32+
}
33+
2934
impl OpenAI {
3035
pub fn completion_text_request(&self, prompt: String) -> Result<(String, Option<u32>)> {
36+
self.completion_request(CompletionMode::Text, prompt)
37+
}
38+
39+
pub fn completion_sql_request(&self, prompt: String) -> Result<(String, Option<u32>)> {
40+
self.completion_request(CompletionMode::Sql, prompt)
41+
}
42+
43+
fn completion_request(
44+
&self,
45+
mode: CompletionMode,
46+
prompt: String,
47+
) -> Result<(String, Option<u32>)> {
3148
let openai = openai_api_rust::OpenAI::new(
3249
Auth {
3350
api_key: self.api_key.clone(),
@@ -36,10 +53,13 @@ impl OpenAI {
3653
&self.api_base,
3754
);
3855

39-
let (max_tokens, stop) = (Some(1024), None);
56+
let (max_tokens, stop) = match mode {
57+
CompletionMode::Sql => (Some(150), Some(vec!["#".to_string(), ";".to_string()])),
58+
CompletionMode::Text => (Some(1024), None),
59+
};
4060

4161
let body = ChatBody {
42-
model: AIModel::GPT35Turbo.to_string(),
62+
model: self.completion_model.to_string(),
4363
temperature: Some(0_f32),
4464
top_p: Some(1_f32),
4565
n: None,
@@ -56,12 +76,15 @@ impl OpenAI {
5676
}],
5777
};
5878

59-
trace!("openai text completion request: {:?}", body);
79+
trace!("openai {:?} completion request: {:?}", mode, body);
6080

6181
let resp = openai.chat_completion_create(&body).map_err(|e| {
62-
ErrorCode::Internal(format!("openai completion text request error: {:?}", e))
82+
ErrorCode::Internal(format!(
83+
"openai {:?} completion request error: {:?}",
84+
mode, e
85+
))
6386
})?;
64-
trace!("openai text completion response: {:?}", resp);
87+
trace!("openai {:?} completion response: {:?}", mode, resp);
6588

6689
let usage = resp.usage.total_tokens;
6790
let result = if resp.choices.is_empty() {

src/common/openai/src/completion_sql.rs

Lines changed: 0 additions & 80 deletions
This file was deleted.

src/common/openai/src/embedding.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ use openai_api_rust::Auth;
2020

2121
use crate::metrics::metrics_embedding_count;
2222
use crate::metrics::metrics_embedding_token;
23-
use crate::AIModel;
2423
use crate::OpenAI;
2524

2625
impl OpenAI {
@@ -34,7 +33,7 @@ impl OpenAI {
3433
&self.api_base,
3534
);
3635
let body = EmbeddingsBody {
37-
model: AIModel::TextEmbeddingAda003.to_string(),
36+
model: self.embedding_model.to_string(),
3837
input: input.to_vec(),
3938
user: None,
4039
};

src/common/openai/src/lib.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,12 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
mod completion_sql;
16-
mod completion_text;
15+
mod completion;
1716
mod embedding;
1817

1918
#[allow(clippy::module_inception)]
2019
mod openai;
2120

2221
pub(crate) mod metrics;
2322

24-
pub use openai::AIModel;
2523
pub use openai::OpenAI;

src/common/openai/src/openai.rs

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,44 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
pub enum AIModel {
16-
// For SQL completion.
17-
TextDavinci003,
18-
// For embedding.
19-
TextEmbeddingAda003,
20-
// For Text completion.
21-
GPT35Turbo,
22-
}
23-
24-
// https://platform.openai.com/examples
25-
impl ToString for AIModel {
26-
fn to_string(&self) -> String {
27-
match self {
28-
AIModel::TextDavinci003 => "text-davinci-003".to_string(),
29-
AIModel::TextEmbeddingAda003 => "text-embedding-ada-002".to_string(),
30-
AIModel::GPT35Turbo => "gpt-3.5-turbo".to_string(),
31-
}
32-
}
33-
}
34-
3515
pub struct OpenAI {
3616
pub(crate) api_key: String,
3717
pub(crate) api_base: String,
18+
pub(crate) embedding_model: String,
19+
pub(crate) completion_model: String,
3820
}
3921

4022
impl OpenAI {
41-
pub fn create(api_key: String) -> Self {
23+
pub fn create(
24+
api_base: String,
25+
api_key: String,
26+
embedding_model: String,
27+
completion_model: String,
28+
) -> Self {
29+
// Check and default.
30+
let api_base = if api_base.is_empty() {
31+
"https://api.openai.com/v1/".to_string()
32+
} else {
33+
api_base
34+
};
35+
36+
let embedding_model = if embedding_model.is_empty() {
37+
"text-embedding-ada-002".to_string()
38+
} else {
39+
embedding_model
40+
};
41+
42+
let completion_model = if completion_model.is_empty() {
43+
"gpt-3.5-turbo".to_string()
44+
} else {
45+
completion_model
46+
};
47+
4248
OpenAI {
49+
api_base,
4350
api_key,
44-
api_base: "https://api.openai.com/v1/".to_string(),
51+
embedding_model,
52+
completion_model,
4553
}
4654
}
4755
}

src/common/openai/tests/it/openai.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@ use common_openai::OpenAI;
1717
fn create_openai() -> Option<OpenAI> {
1818
let key = std::env::var("OPENAI_API_KEY").unwrap_or("".to_string());
1919
if !key.is_empty() {
20-
Some(OpenAI::create(key))
20+
Some(OpenAI::create(
21+
"".to_string(),
22+
key,
23+
"".to_string(),
24+
"".to_string(),
25+
))
2126
} else {
2227
None
2328
}

src/query/config/src/config.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,9 +1308,20 @@ pub struct QueryConfig {
13081308
#[clap(long)]
13091309
pub disable_system_table_load: bool,
13101310

1311+
#[clap(long, default_value = "https://api.openai.com/v1/")]
1312+
pub openai_api_base_url: String,
1313+
13111314
// This will not show in system.configs, put it to mask.rs.
13121315
#[clap(long, default_value = "")]
13131316
pub openai_api_key: String,
1317+
1318+
/// https://platform.openai.com/docs/models/embeddings
1319+
#[clap(long, default_value = "text-embedding-ada-002")]
1320+
pub openai_api_embedding_model: String,
1321+
1322+
/// https://platform.openai.com/docs/guides/chat
1323+
#[clap(long, default_value = "gpt-3.5-turbo")]
1324+
pub openai_api_completion_model: String,
13141325
}
13151326

13161327
impl Default for QueryConfig {
@@ -1372,7 +1383,10 @@ impl TryInto<InnerQueryConfig> for QueryConfig {
13721383
internal_enable_sandbox_tenant: self.internal_enable_sandbox_tenant,
13731384
internal_merge_on_read_mutation: self.internal_merge_on_read_mutation,
13741385
disable_system_table_load: self.disable_system_table_load,
1386+
openai_api_base_url: self.openai_api_base_url,
13751387
openai_api_key: self.openai_api_key,
1388+
openai_api_completion_model: self.openai_api_completion_model,
1389+
openai_api_embedding_model: self.openai_api_embedding_model,
13761390
})
13771391
}
13781392
}
@@ -1446,7 +1460,10 @@ impl From<InnerQueryConfig> for QueryConfig {
14461460
table_cache_bloom_index_filter_count: None,
14471461
table_cache_bloom_index_data_bytes: None,
14481462
disable_system_table_load: inner.disable_system_table_load,
1463+
openai_api_base_url: inner.openai_api_base_url,
14491464
openai_api_key: inner.openai_api_key,
1465+
openai_api_completion_model: inner.openai_api_completion_model,
1466+
openai_api_embedding_model: inner.openai_api_embedding_model,
14501467
}
14511468
}
14521469
}

src/query/config/src/inner.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,12 @@ pub struct QueryConfig {
174174
pub internal_merge_on_read_mutation: bool,
175175
/// Disable some system load(For example system.configs) for cloud security.
176176
pub disable_system_table_load: bool,
177+
178+
/// openai
177179
pub openai_api_key: String,
180+
pub openai_api_base_url: String,
181+
pub openai_api_embedding_model: String,
182+
pub openai_api_completion_model: String,
178183
}
179184

180185
impl Default for QueryConfig {
@@ -225,8 +230,11 @@ impl Default for QueryConfig {
225230
internal_enable_sandbox_tenant: false,
226231
internal_merge_on_read_mutation: false,
227232
disable_system_table_load: false,
228-
openai_api_key: "".to_string(),
229233
flight_sql_tls_server_key: "".to_string(),
234+
openai_api_base_url: "https://api.openai.com/v1/".to_string(),
235+
openai_api_key: "".to_string(),
236+
openai_api_completion_model: "gpt-3.5-turbo".to_string(),
237+
openai_api_embedding_model: "text-embedding-ada-002".to_string(),
230238
}
231239
}
232240
}

src/query/functions/src/scalars/vector.rs

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use common_expression::types::Float32Type;
1818
use common_expression::types::StringType;
1919
use common_expression::types::F32;
2020
use common_expression::vectorize_with_builder_2_arg;
21+
use common_expression::vectorize_with_builder_5_arg;
2122
use common_expression::FunctionDomain;
2223
use common_expression::FunctionRegistry;
2324
use common_openai::OpenAI;
@@ -28,7 +29,7 @@ pub fn register(registry: &mut FunctionRegistry) {
2829
// This function takes two Float32 arrays as input and computes the cosine distance between them.
2930
registry.register_passthrough_nullable_2_arg::<ArrayType<Float32Type>, ArrayType<Float32Type>, Float32Type, _, _>(
3031
"cosine_distance",
31-
|_, _| FunctionDomain::MayThrow,
32+
|_, _| FunctionDomain::MayThrow,
3233
vectorize_with_builder_2_arg::<ArrayType<Float32Type>, ArrayType<Float32Type>, Float32Type>(
3334
|lhs, rhs, output, ctx| {
3435
let l_f32=
@@ -52,14 +53,18 @@ pub fn register(registry: &mut FunctionRegistry) {
5253
// embedding_vector
5354
// This function takes two strings as input, sends an API request to OpenAI, and returns the Float32 array of embeddings.
5455
// The OpenAI API key is pre-configured during the binder phase, so we rewrite this function and set the API key.
55-
registry.register_passthrough_nullable_2_arg::<StringType, StringType, ArrayType<Float32Type>, _, _>(
56+
registry.register_passthrough_nullable_5_arg::<StringType, StringType, StringType, StringType, StringType, ArrayType<Float32Type>, _, _>(
5657
"embedding_vector",
57-
|_, _| FunctionDomain::MayThrow,
58-
vectorize_with_builder_2_arg::<StringType, StringType, ArrayType<Float32Type>>(
59-
|data, api_key, output, ctx| {
58+
|_, _, _, _, _| FunctionDomain::MayThrow,
59+
vectorize_with_builder_5_arg::<StringType, StringType, StringType, StringType, StringType, ArrayType<Float32Type>>(
60+
|data, api_base,api_key, embedding_model, completion_model, output, ctx| {
6061
let data = std::str::from_utf8(data).unwrap();
61-
let api_key = std::str::from_utf8(api_key).unwrap();
62-
let openai = OpenAI::create(api_key.to_string());
62+
63+
let api_base = std::str::from_utf8(api_base).unwrap().to_string();
64+
let api_key = std::str::from_utf8(api_key).unwrap().to_string();
65+
let embedding_model = std::str::from_utf8(embedding_model).unwrap().to_string();
66+
let completion_model= std::str::from_utf8(completion_model).unwrap().to_string();
67+
let openai = OpenAI::create(api_base, api_key, embedding_model, completion_model);
6368
let result = openai.embedding_request(&[data.to_string()]);
6469
match result {
6570
Ok((embeddings, _)) => {
@@ -78,14 +83,18 @@ pub fn register(registry: &mut FunctionRegistry) {
7883
// text_completion
7984
// This function takes two strings as input, sends an API request to OpenAI, and returns the AI-generated completion as a string.
8085
// The OpenAI API key is pre-configured during the binder phase, so we rewrite this function and set the API key.
81-
registry.register_passthrough_nullable_2_arg::<StringType, StringType, StringType, _, _>(
86+
registry.register_passthrough_nullable_5_arg::<StringType,StringType,StringType, StringType, StringType, StringType, _, _>(
8287
"text_completion",
83-
|_, _| FunctionDomain::MayThrow,
84-
vectorize_with_builder_2_arg::<StringType, StringType, StringType>(
85-
|data, api_key, output, ctx| {
88+
|_, _, _, _, _| FunctionDomain::MayThrow,
89+
vectorize_with_builder_5_arg::<StringType, StringType, StringType, StringType, StringType, StringType>(
90+
|data, api_base,api_key, embedding_model, completion_model, output, ctx| {
8691
let data = std::str::from_utf8(data).unwrap();
87-
let api_key = std::str::from_utf8(api_key).unwrap();
88-
let openai = OpenAI::create(api_key.to_string());
92+
93+
let api_base = std::str::from_utf8(api_base).unwrap().to_string();
94+
let api_key = std::str::from_utf8(api_key).unwrap().to_string();
95+
let embedding_model = std::str::from_utf8(embedding_model).unwrap().to_string();
96+
let completion_model= std::str::from_utf8(completion_model).unwrap().to_string();
97+
let openai = OpenAI::create(api_base, api_key, embedding_model, completion_model);
8998
let result = openai.completion_text_request(data.to_string());
9099
match result {
91100
Ok((resp, _)) => {

0 commit comments

Comments
 (0)