Skip to content

Commit 922073b

Browse files
authored
Merge branch 'main' into check-view-dep
2 parents e3a881d + 7ad8c21 commit 922073b

File tree

15 files changed

+192
-88
lines changed

15 files changed

+192
-88
lines changed

Cargo.lock

Lines changed: 9 additions & 18 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/meta/app/src/principal/user_stage.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ pub enum StageFileCompression {
106106
None,
107107
}
108108

109+
impl StageFileFormatType {
110+
pub fn has_inner_schema(&self) -> bool {
111+
matches!(self, StageFileFormatType::Parquet)
112+
}
113+
}
114+
109115
impl Default for StageFileCompression {
110116
fn default() -> Self {
111117
Self::None

src/query/service/Cargo.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ storages-common-table-meta = { path = "../storages/common/table-meta" }
9898
# Crates.io dependencies
9999
aho-corasick = { version = "0.7.20" }
100100
async-channel = "1.7.1"
101-
# Wait for https://github.com/64bit/async-openai/pull/60
102-
async-openai = { version = "0.9.5", git = "https://github.com/Xuanwo/async-openai", rev = "22fd9ca0c9c86a2c57462c5fd32aaff407d6b000" }
101+
async-openai = "0.10.0"
103102
async-stream = "0.3.3"
104103
async-trait = { version = "0.1.57", package = "async-trait-fn" }
105104
base64 = "0.21.0"
@@ -135,7 +134,7 @@ scopeguard = "1.1.0"
135134
serde = { workspace = true }
136135
serde_json = { workspace = true }
137136
serde_urlencoded = "0.7.1"
138-
socket2 = "0.5.1"
137+
socket2 = "0.4.7"
139138
strength_reduce = "0.2.4"
140139
tempfile = { version = "3.4.0", optional = true }
141140
time = "0.3.14"

src/query/service/src/interpreters/interpreter_copy.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use common_exception::ErrorCode;
2424
use common_exception::Result;
2525
use common_expression::infer_table_schema;
2626
use common_expression::DataField;
27+
use common_expression::DataSchema;
2728
use common_expression::DataSchemaRef;
2829
use common_expression::DataSchemaRefExt;
2930
use common_meta_app::principal::StageInfo;
@@ -41,6 +42,7 @@ use tracing::info;
4142
use crate::interpreters::common::append2table;
4243
use crate::interpreters::Interpreter;
4344
use crate::interpreters::SelectInterpreter;
45+
use crate::pipelines::processors::transforms::TransformRuntimeCastSchema;
4446
use crate::pipelines::processors::TransformCastSchema;
4547
use crate::pipelines::processors::TransformLimit;
4648
use crate::pipelines::PipelineBuildResult;
@@ -334,6 +336,26 @@ impl CopyInterpreter {
334336
)?;
335337
}
336338

339+
if stage_table_info
340+
.stage_info
341+
.file_format_options
342+
.format
343+
.has_inner_schema()
344+
{
345+
let dst_schema: Arc<DataSchema> = Arc::new(to_table.schema().into());
346+
let func_ctx = self.ctx.get_function_context()?;
347+
build_res.main_pipeline.add_transform(
348+
|transform_input_port, transform_output_port| {
349+
TransformRuntimeCastSchema::try_create(
350+
transform_input_port,
351+
transform_output_port,
352+
dst_schema.clone(),
353+
func_ctx,
354+
)
355+
},
356+
)?;
357+
}
358+
337359
// Build append data pipeline.
338360
to_table.append_data(
339361
ctx.clone(),

src/query/service/src/interpreters/interpreter_insert.rs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -378,20 +378,23 @@ impl Interpreter for InsertInterpreter {
378378
.format
379379
.exec_stream(input_context.clone(), &mut build_res.main_pipeline)?;
380380

381-
if Ok(StageFileFormatType::Parquet) == StageFileFormatType::from_str(format) {
382-
let dest_schema = plan.schema();
383-
let func_ctx = self.ctx.get_function_context()?;
384-
385-
build_res.main_pipeline.add_transform(
386-
|transform_input_port, transform_output_port| {
387-
TransformRuntimeCastSchema::try_create(
388-
transform_input_port,
389-
transform_output_port,
390-
dest_schema.clone(),
391-
func_ctx,
392-
)
393-
},
394-
)?;
381+
match StageFileFormatType::from_str(format) {
382+
Ok(f) if f.has_inner_schema() => {
383+
let dest_schema = plan.schema();
384+
let func_ctx = self.ctx.get_function_context()?;
385+
386+
build_res.main_pipeline.add_transform(
387+
|transform_input_port, transform_output_port| {
388+
TransformRuntimeCastSchema::try_create(
389+
transform_input_port,
390+
transform_output_port,
391+
dest_schema.clone(),
392+
func_ctx,
393+
)
394+
},
395+
)?;
396+
}
397+
_ => {}
395398
}
396399
}
397400
InsertInputSource::StreamingWithFileFormat(format_options, _, input_context) => {
@@ -400,7 +403,7 @@ impl Interpreter for InsertInterpreter {
400403
.format
401404
.exec_stream(input_context.clone(), &mut build_res.main_pipeline)?;
402405

403-
if StageFileFormatType::Parquet == format_options.format {
406+
if format_options.format.has_inner_schema() {
404407
let dest_schema = plan.schema();
405408
let func_ctx = self.ctx.get_function_context()?;
406409

src/query/service/src/servers/mysql/mysql_handler.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use futures::future::AbortRegistration;
2828
use futures::future::Abortable;
2929
use futures::StreamExt;
3030
use opensrv_mysql::*;
31+
use socket2::SockRef;
3132
use socket2::TcpKeepalive;
3233
use tokio_stream::wrappers::TcpListenerStream;
3334
use tracing::error;
@@ -92,7 +93,7 @@ impl MySQLHandler {
9293
sessions: Arc<SessionManager>,
9394
executor: Arc<Runtime>,
9495
socket: TcpStream,
95-
_keepalive: TcpKeepalive,
96+
keepalive: TcpKeepalive,
9697
) {
9798
executor.spawn(async move {
9899
match sessions.create_session(SessionType::MySQL).await {
@@ -103,11 +104,10 @@ impl MySQLHandler {
103104
Ok(session) => {
104105
info!("MySQL connection coming: {:?}", socket.peer_addr());
105106

106-
// FIXME: tokio TcpStream doesn't implement `AsFd` anymore, this call should be refactored.
107-
// if let Err(e) = SockRef::from(&socket).set_tcp_keepalive(&keepalive)
108-
// {
109-
// warn!("failed to set socket option keepalive {}", e);
110-
// }
107+
// socket2 0.5 requires `TcpStream` to implement `AsFd`; stay on socket2 0.4 until https://github.com/tokio-rs/tokio/pull/5514 lands.
108+
if let Err(e) = SockRef::from(&socket).set_tcp_keepalive(&keepalive) {
109+
warn!("failed to set socket option keepalive {}", e);
110+
}
111111

112112
if let Err(error) = MySQLConnection::run_on_stream(session, socket) {
113113
error!("Unexpected error occurred during query: {:?}", error);

src/query/service/src/table_functions/openai/gpt_to_sql.rs

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414

1515
use std::any::Any;
1616
use std::sync::Arc;
17-
use std::time::Duration;
1817

19-
use async_openai::types::CreateCompletionRequestArgs;
20-
use async_openai::Client;
2118
use chrono::NaiveDateTime;
2219
use chrono::TimeZone;
2320
use chrono::Utc;
@@ -50,6 +47,9 @@ use common_storages_fuse::TableContext;
5047
use common_storages_view::view_table::VIEW_ENGINE;
5148
use tracing::info;
5249

50+
use crate::table_functions::openai::AIModel;
51+
use crate::table_functions::openai::OpenAI;
52+
5353
pub struct GPT2SQLTable {
5454
prompt: String,
5555
api_key: String,
@@ -233,37 +233,16 @@ impl AsyncSource for GPT2SQLSource {
233233
template.push("#".to_string());
234234
template.push("SELECT".to_string());
235235

236-
let model = "code-davinci-002";
237-
let timeout = Duration::from_secs(30);
238236
let prompt = template.join("");
239-
let api_key = self.api_key.clone();
240237
info!("openai request prompt: {}", prompt);
241238

242-
// Client
243-
let http_client = reqwest::ClientBuilder::new()
244-
.user_agent("databend")
245-
.timeout(timeout)
246-
.build()
247-
.map_err(|e| ErrorCode::Internal(format!("openai http error: {:?}", e)))?;
248-
let client = Client::new()
249-
.with_api_key(api_key)
250-
.with_http_client(http_client);
251-
252-
// Request
253-
let request = CreateCompletionRequestArgs::default()
254-
.model(model)
255-
.prompt(prompt)
256-
.temperature(0.0)
257-
.max_tokens(150_u16)
258-
.top_p(1.0)
259-
.frequency_penalty(0.0)
260-
.presence_penalty(0.0)
261-
.stop(["#", ";"])
262-
.build()
263-
.map_err(|e| ErrorCode::Internal(format!("openai request error: {:?}", e)))?;
264-
265239
// Response.
266-
let response = client
240+
let api_key = self.api_key.clone();
241+
let openai = OpenAI::create(api_key, AIModel::CodeDavinci002);
242+
let request = openai.completion_request(prompt)?;
243+
244+
let response = openai
245+
.client()?
267246
.completions()
268247
.create(request)
269248
.await

src/query/service/src/table_functions/openai/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,9 @@
1313
// limitations under the License.
1414

1515
mod gpt_to_sql;
16+
#[allow(clippy::module_inception)]
17+
mod openai;
1618

1719
pub use gpt_to_sql::GPT2SQLTable;
20+
pub use openai::AIModel;
21+
pub use openai::OpenAI;
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// Copyright 2023 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::time::Duration;
16+
17+
use async_openai::types::CreateCompletionRequest;
18+
use async_openai::types::CreateCompletionRequestArgs;
19+
use async_openai::Client;
20+
use common_exception::ErrorCode;
21+
use common_exception::Result;
22+
23+
/// Models that can be selected when building an OpenAI request.
///
/// The textual form of a variant (via [`AIModel::as_str`], `Display`, or
/// `to_string()`) is the model identifier string, e.g. `"code-davinci-002"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AIModel {
    CodeDavinci002,
}

impl AIModel {
    /// Returns the model identifier as a static string, without allocating.
    pub fn as_str(&self) -> &'static str {
        match self {
            AIModel::CodeDavinci002 => "code-davinci-002",
        }
    }
}

// https://platform.openai.com/examples
//
// `Display` is implemented instead of `ToString`: the standard library
// provides `ToString` for every `Display` type, so existing `.to_string()`
// callers keep working and the type additionally becomes usable in
// `format!`/`write!`.
impl std::fmt::Display for AIModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
35+
36+
pub struct OpenAI {
37+
api_key: String,
38+
model: AIModel,
39+
}
40+
41+
impl OpenAI {
42+
pub fn create(api_key: String, model: AIModel) -> Self {
43+
OpenAI { api_key, model }
44+
}
45+
46+
pub fn client(&self) -> Result<Client> {
47+
let timeout = Duration::from_secs(30);
48+
// Client
49+
let http_client = reqwest::ClientBuilder::new()
50+
.user_agent("databend")
51+
.timeout(timeout)
52+
.build()
53+
.map_err(|e| ErrorCode::Internal(format!("openai http error: {:?}", e)))?;
54+
55+
Ok(Client::new()
56+
.with_api_key(self.api_key.clone())
57+
.with_http_client(http_client))
58+
}
59+
60+
pub fn completion_request(&self, prompt: String) -> Result<CreateCompletionRequest> {
61+
CreateCompletionRequestArgs::default()
62+
.model(self.model.to_string())
63+
.prompt(prompt)
64+
.temperature(0.0)
65+
.max_tokens(150_u16)
66+
.top_p(1.0)
67+
.frequency_penalty(0.0)
68+
.presence_penalty(0.0)
69+
.stop(["#", ";"])
70+
.build()
71+
.map_err(|e| ErrorCode::Internal(format!("openai completion request error: {:?}", e)))
72+
}
73+
}

tests/suites/1_stateful/00_copy/00_0002_copy_from_fs_location.result

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
199 2020.0 769
77
199 2020.0 769
88
199 2020.0 769
9+
199 2020.0000 769.00

0 commit comments

Comments
 (0)