Skip to content

Commit d308b57

Browse files
authored
feat: add ai_text_completion function (#10880)
1 parent 0770a2f commit d308b57

File tree

3 files changed

+59
-1
lines changed

3 files changed

+59
-1
lines changed

src/query/functions/src/scalars/vector.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ use common_openai::OpenAI;
2525
use common_vector::cosine_distance;
2626

2727
pub fn register(registry: &mut FunctionRegistry) {
28+
// cosine_distance
29+
// This function takes two Float32 arrays as input and computes the cosine distance between them.
2830
registry.register_passthrough_nullable_2_arg::<ArrayType<Float32Type>, ArrayType<Float32Type>, Float32Type, _, _>(
2931
"cosine_distance",
3032
|_, _| FunctionDomain::MayThrow,
@@ -48,6 +50,9 @@ pub fn register(registry: &mut FunctionRegistry) {
4850
),
4951
);
5052

53+
// embedding_vector
54+
// This function takes two strings as input, sends an API request to OpenAI, and returns the Float32 array of embeddings.
55+
// The OpenAI API key is pre-configured during the binder phase, so we rewrite this function and set the API key.
5156
registry.register_passthrough_nullable_2_arg::<StringType, StringType, ArrayType<Float32Type>, _, _>(
5257
"embedding_vector",
5358
|_, _| FunctionDomain::MayThrow,
@@ -63,11 +68,39 @@ pub fn register(registry: &mut FunctionRegistry) {
6368
output.push(result.into());
6469
}
6570
Err(e) => {
66-
ctx.set_error(output.len(), format!("openai request error:{:?}", e));
71+
ctx.set_error(output.len(), format!("openai embedding request error:{:?}", e));
6772
output.push(vec![F32::from(0.0)].into());
6873
}
6974
}
7075
},
7176
),
7277
);
78+
79+
// text_completion
80+
// This function takes two strings as input, sends an API request to OpenAI, and returns the AI-generated completion as a string.
81+
// The OpenAI API key is pre-configured during the binder phase, so we rewrite this function and set the API key.
82+
registry.register_passthrough_nullable_2_arg::<StringType, StringType, StringType, _, _>(
83+
"text_completion",
84+
|_, _| FunctionDomain::MayThrow,
85+
vectorize_with_builder_2_arg::<StringType, StringType, StringType>(
86+
|data, api_key, output, ctx| {
87+
let data = std::str::from_utf8(data).unwrap();
88+
let api_key = std::str::from_utf8(api_key).unwrap();
89+
let openai = OpenAI::create(api_key.to_string(), AIModel::TextDavinci003);
90+
let result = openai.completion_request(data.to_string());
91+
match result {
92+
Ok((resp, _)) => {
93+
output.put_str(&resp);
94+
}
95+
Err(e) => {
96+
ctx.set_error(
97+
output.len(),
98+
format!("openai completion request error:{:?}", e),
99+
);
100+
output.put_str("");
101+
}
102+
}
103+
},
104+
),
105+
);
73106
}

src/query/functions/tests/it/scalars/testdata/function_list.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2996,6 +2996,8 @@ Functions overloads:
29962996
3 subtract_years(Timestamp NULL, Int64 NULL) :: Timestamp NULL
29972997
0 tan(Float64) :: Float64
29982998
1 tan(Float64 NULL) :: Float64 NULL
2999+
0 text_completion(String, String) :: String
3000+
1 text_completion(String NULL, String NULL) :: String NULL
29993001
0 time_slot(Timestamp) :: Timestamp
30003002
1 time_slot(Timestamp NULL) :: Timestamp NULL
30013003
0 to_base64(String) :: String

src/query/sql/src/planner/semantic/type_check.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1561,6 +1561,7 @@ impl<'a> TypeChecker<'a> {
15611561
"coalesce",
15621562
"last_query_id",
15631563
"ai_embedding_vector",
1564+
"ai_text_completion",
15641565
]
15651566
}
15661567

@@ -1781,6 +1782,28 @@ impl<'a> TypeChecker<'a> {
17811782
.await,
17821783
)
17831784
}
1785+
("ai_text_completion", args) => {
1786+
// ai_text_completion(prompt) -> text_completion(prompt, api_key)
1787+
if args.len() != 1 {
1788+
return Some(Err(ErrorCode::BadArguments(
1789+
"ai_text_completion(STRING) only accepts one STRING argument",
1790+
)
1791+
.set_span(span)));
1792+
}
1793+
1794+
// Prompt.
1795+
let arg1 = args[0];
1796+
// API key.
1797+
let arg2 = &Expr::Literal {
1798+
span,
1799+
lit: Literal::String(GlobalConfig::instance().query.openai_api_key.clone()),
1800+
};
1801+
1802+
Some(
1803+
self.resolve_function(span, "text_completion", vec![], &[arg1, arg2])
1804+
.await,
1805+
)
1806+
}
17841807
_ => None,
17851808
}
17861809
}

0 commit comments

Comments
 (0)