Skip to content

Commit df7b517

Browse files
authored
chore: change completion model to GPT 3.5 Turbo (#10945)
* chore: change completion model to GPT 3.5 Turbo
* add unit test for openai completion
* add trace to openai api
* fix chat text completion response content
* change the max token from 512 to 1024
1 parent f103481 commit df7b517

File tree

11 files changed

+188 -36 lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/common/openai/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ common-exception = { path = "../exception" }
1919
# GitHub dependencies
2020

2121
# Crates.io dependencies
22+
log = "0.4"
2223
metrics = "0.20.1"
2324
openai_api_rust = { git = "https://github.com/datafuse-extras/openai-api", rev = "5f977a4" }
2425

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,18 @@
1414

1515
use common_exception::ErrorCode;
1616
use common_exception::Result;
17+
use log::trace;
1718
use openai_api_rust::completions::CompletionsApi;
1819
use openai_api_rust::completions::CompletionsBody;
1920
use openai_api_rust::Auth;
2021

2122
use crate::metrics::metrics_completion_count;
2223
use crate::metrics::metrics_completion_token;
24+
use crate::AIModel;
2325
use crate::OpenAI;
2426

25-
pub enum CompletionMode {
26-
// SQL translate:
27-
// max_tokens: 150, stop: ['#', ';']
28-
SQL,
29-
// Text completion:
30-
// max_tokens: 512, stop: none
31-
Text,
32-
}
33-
3427
impl OpenAI {
35-
pub fn completion_request(
36-
&self,
37-
prompt: String,
38-
mode: CompletionMode,
39-
) -> Result<(String, Option<u32>)> {
28+
pub fn completion_sql_request(&self, prompt: String) -> Result<(String, Option<u32>)> {
4029
let openai = openai_api_rust::OpenAI::new(
4130
Auth {
4231
api_key: self.api_key.clone(),
@@ -45,19 +34,16 @@ impl OpenAI {
4534
&self.api_base,
4635
);
4736

48-
let (max_tokens, stop) = match mode {
49-
CompletionMode::SQL => (Some(150), Some(vec!["#".to_string(), ";".to_string()])),
50-
CompletionMode::Text => (Some(512), None),
51-
};
37+
let (max_tokens, stop) = (Some(150), Some(vec!["#".to_string(), ";".to_string()]));
5238

5339
let body = CompletionsBody {
54-
model: self.model.to_string(),
40+
model: AIModel::TextDavinci003.to_string(),
5541
prompt: Some(vec![prompt]),
5642
suffix: None,
5743
max_tokens,
5844
temperature: Some(0_f32),
5945
top_p: Some(1_f32),
60-
n: Some(2),
46+
n: None,
6147
stream: Some(false),
6248
logprobs: None,
6349
echo: None,
@@ -68,10 +54,14 @@ impl OpenAI {
6854
logit_bias: None,
6955
user: None,
7056
};
57+
trace!("openai sql completion request: {:?}", body);
58+
7159
let resp = openai.completion_create(&body).map_err(|e| {
72-
ErrorCode::Internal(format!("openai completion request error: {:?}", e))
60+
ErrorCode::Internal(format!("openai completion request sql error: {:?}", e))
7361
})?;
7462

63+
trace!("openai sql completion response: {:?}", resp);
64+
7565
let usage = resp.usage.total_tokens;
7666
let sql = if resp.choices.is_empty() {
7767
"".to_string()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Copyright 2023 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use common_exception::ErrorCode;
16+
use common_exception::Result;
17+
use log::trace;
18+
use openai_api_rust::chat::ChatApi;
19+
use openai_api_rust::chat::ChatBody;
20+
use openai_api_rust::Auth;
21+
use openai_api_rust::Message;
22+
use openai_api_rust::Role;
23+
24+
use crate::metrics::metrics_completion_count;
25+
use crate::metrics::metrics_completion_token;
26+
use crate::AIModel;
27+
use crate::OpenAI;
28+
29+
impl OpenAI {
30+
pub fn completion_text_request(&self, prompt: String) -> Result<(String, Option<u32>)> {
31+
let openai = openai_api_rust::OpenAI::new(
32+
Auth {
33+
api_key: self.api_key.clone(),
34+
organization: None,
35+
},
36+
&self.api_base,
37+
);
38+
39+
let (max_tokens, stop) = (Some(1024), None);
40+
41+
let body = ChatBody {
42+
model: AIModel::GPT35Turbo.to_string(),
43+
temperature: Some(0_f32),
44+
top_p: Some(1_f32),
45+
n: None,
46+
stream: None,
47+
stop,
48+
max_tokens,
49+
presence_penalty: None,
50+
frequency_penalty: None,
51+
logit_bias: None,
52+
user: None,
53+
messages: vec![Message {
54+
role: Role::User,
55+
content: prompt,
56+
}],
57+
};
58+
59+
trace!("openai text completion request: {:?}", body);
60+
61+
let resp = openai.chat_completion_create(&body).map_err(|e| {
62+
ErrorCode::Internal(format!("openai completion text request error: {:?}", e))
63+
})?;
64+
trace!("openai text completion response: {:?}", resp);
65+
66+
let usage = resp.usage.total_tokens;
67+
let result = if resp.choices.is_empty() {
68+
"".to_string()
69+
} else {
70+
let message = resp
71+
.choices
72+
.get(0)
73+
.and_then(|choice| choice.message.as_ref());
74+
75+
match message {
76+
Some(msg) => msg.content.clone(),
77+
_ => "".to_string(),
78+
}
79+
};
80+
81+
// perf.
82+
{
83+
metrics_completion_count(1);
84+
metrics_completion_token(usage.unwrap_or(0));
85+
}
86+
87+
Ok((result, usage))
88+
}
89+
}

src/common/openai/src/embedding.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use openai_api_rust::Auth;
2020

2121
use crate::metrics::metrics_embedding_count;
2222
use crate::metrics::metrics_embedding_token;
23+
use crate::AIModel;
2324
use crate::OpenAI;
2425

2526
impl OpenAI {
@@ -33,7 +34,7 @@ impl OpenAI {
3334
&self.api_base,
3435
);
3536
let body = EmbeddingsBody {
36-
model: self.model.to_string(),
37+
model: AIModel::TextEmbeddingAda003.to_string(),
3738
input: input.to_vec(),
3839
user: None,
3940
};

src/common/openai/src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
mod completion;
15+
mod completion_sql;
16+
mod completion_text;
1617
mod embedding;
18+
1719
#[allow(clippy::module_inception)]
1820
mod openai;
1921

2022
pub(crate) mod metrics;
2123

22-
pub use completion::CompletionMode;
2324
pub use openai::AIModel;
2425
pub use openai::OpenAI;

src/common/openai/src/openai.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@
1313
// limitations under the License.
1414

1515
pub enum AIModel {
16+
// For SQL completion.
1617
TextDavinci003,
18+
// For embedding.
1719
TextEmbeddingAda003,
20+
// For Text completion.
21+
GPT35Turbo,
1822
}
1923

2024
// https://platform.openai.com/examples
@@ -23,22 +27,21 @@ impl ToString for AIModel {
2327
match self {
2428
AIModel::TextDavinci003 => "text-davinci-003".to_string(),
2529
AIModel::TextEmbeddingAda003 => "text-embedding-ada-002".to_string(),
30+
AIModel::GPT35Turbo => "gpt-3.5-turbo".to_string(),
2631
}
2732
}
2833
}
2934

3035
pub struct OpenAI {
3136
pub(crate) api_key: String,
3237
pub(crate) api_base: String,
33-
pub(crate) model: AIModel,
3438
}
3539

3640
impl OpenAI {
37-
pub fn create(api_key: String, model: AIModel) -> Self {
41+
pub fn create(api_key: String) -> Self {
3842
OpenAI {
3943
api_key,
4044
api_base: "https://api.openai.com/v1/".to_string(),
41-
model,
4245
}
4346
}
4447
}

src/common/openai/tests/it/main.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Copyright 2023 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
mod openai;
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright 2023 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use common_openai::OpenAI;
16+
17+
fn create_openai() -> Option<OpenAI> {
18+
let key = std::env::var("OPENAI_API_KEY").unwrap_or("".to_string());
19+
if !key.is_empty() {
20+
Some(OpenAI::create(key))
21+
} else {
22+
None
23+
}
24+
}
25+
26+
#[test]
27+
fn test_openai_text_completion() {
28+
let openai = create_openai();
29+
if let Some(openai) = openai {
30+
let resp = openai
31+
.completion_text_request("say hello".to_string())
32+
.unwrap();
33+
34+
assert!(resp.0.contains("hello"));
35+
}
36+
}
37+
38+
#[test]
39+
fn test_openai_sql_completion() {
40+
let openai = create_openai();
41+
if let Some(openai) = openai {
42+
let resp = openai
43+
.completion_sql_request("### Postgres SQL tables, with their properties:
44+
#
45+
# Employee(id, name, department_id)
46+
# Department(id, name, address)
47+
# Salary_Payments(id, employee_id, amount, date)
48+
#
49+
### A query to list the names of the departments which employed more than 10 employees in the last 3 months
50+
SELECT".to_string())
51+
.unwrap();
52+
53+
assert!(resp.0.contains("FROM"));
54+
}
55+
}

src/query/functions/src/scalars/vector.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ use common_expression::types::F32;
2020
use common_expression::vectorize_with_builder_2_arg;
2121
use common_expression::FunctionDomain;
2222
use common_expression::FunctionRegistry;
23-
use common_openai::AIModel;
24-
use common_openai::CompletionMode;
2523
use common_openai::OpenAI;
2624
use common_vector::cosine_distance;
2725

@@ -61,7 +59,7 @@ pub fn register(registry: &mut FunctionRegistry) {
6159
|data, api_key, output, ctx| {
6260
let data = std::str::from_utf8(data).unwrap();
6361
let api_key = std::str::from_utf8(api_key).unwrap();
64-
let openai = OpenAI::create(api_key.to_string(), AIModel::TextEmbeddingAda003);
62+
let openai = OpenAI::create(api_key.to_string());
6563
let result = openai.embedding_request(&[data.to_string()]);
6664
match result {
6765
Ok((embeddings, _)) => {
@@ -87,8 +85,8 @@ pub fn register(registry: &mut FunctionRegistry) {
8785
|data, api_key, output, ctx| {
8886
let data = std::str::from_utf8(data).unwrap();
8987
let api_key = std::str::from_utf8(api_key).unwrap();
90-
let openai = OpenAI::create(api_key.to_string(), AIModel::TextDavinci003);
91-
let result = openai.completion_request(data.to_string(), CompletionMode::Text);
88+
let openai = OpenAI::create(api_key.to_string());
89+
let result = openai.completion_text_request(data.to_string());
9290
match result {
9391
Ok((resp, _)) => {
9492
output.put_str(&resp);

0 commit comments

Comments
 (0)