Skip to content

Commit 880c5d3

Browse files
committed
Merge branch 'master' into feat/pgwire-033
2 parents b704db0 + 3ddbd66 commit 880c5d3

File tree

102 files changed

+922
-348
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

102 files changed

+922
-348
lines changed

Cargo.lock

Lines changed: 380 additions & 189 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
[workspace]
22
resolver = "2"
3-
members = ["datafusion-postgres", "datafusion-postgres-cli", "arrow-pg"]
3+
members = ["datafusion-postgres", "datafusion-postgres-cli", "arrow-pg", "datafusion-pg-catalog"]
44

55
[workspace.package]
6-
version = "0.10.2"
6+
version = "0.11.0"
77
edition = "2021"
88
license = "Apache-2.0"
99
rust-version = "1.86.0"

arrow-pg/Cargo.toml

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
name = "arrow-pg"
33
description = "Arrow data mapping and encoding/decoding for Postgres"
4-
version = "0.6.0"
4+
version = "0.7.0"
55
edition.workspace = true
66
license.workspace = true
77
authors.workspace = true
@@ -16,13 +16,26 @@ rust-version.workspace = true
1616
default = ["arrow"]
1717
arrow = ["dep:arrow"]
1818
datafusion = ["dep:datafusion"]
19+
# for testing
20+
_duckdb = []
21+
_bundled = ["duckdb/bundled"]
22+
1923

2024
[dependencies]
2125
arrow = { workspace = true, optional = true }
2226
bytes.workspace = true
2327
chrono.workspace = true
2428
datafusion = { workspace = true, optional = true }
2529
futures.workspace = true
26-
pgwire = { version = ">=0.33", default-features = false, features = ["server-api"] }
30+
pgwire = { workspace = true, default-features = false, features = ["server-api"] }
2731
postgres-types.workspace = true
2832
rust_decimal.workspace = true
33+
34+
[dev-dependencies]
35+
async-trait = "0.1"
36+
duckdb = { version = "~1.4" }
37+
tokio = { version = "1.19", features = ["full"]}
38+
39+
[[example]]
40+
name = "duckdb"
41+
required-features = ["_duckdb"]

arrow-pg/examples/duckdb.rs

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
use std::sync::{Arc, Mutex};
2+
3+
use arrow_pg::datatypes::arrow_schema_to_pg_fields;
4+
use arrow_pg::datatypes::encode_recordbatch;
5+
use arrow_pg::datatypes::into_pg_type;
6+
use async_trait::async_trait;
7+
use duckdb::{params, Connection, Statement, ToSql};
8+
use futures::stream;
9+
use pgwire::api::auth::md5pass::{hash_md5_password, Md5PasswordAuthStartupHandler};
10+
use pgwire::api::auth::{
11+
AuthSource, DefaultServerParameterProvider, LoginInfo, Password, StartupHandler,
12+
};
13+
use pgwire::api::portal::{Format, Portal};
14+
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
15+
use pgwire::api::results::{
16+
DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag,
17+
};
18+
use pgwire::api::stmt::{NoopQueryParser, StoredStatement};
19+
use pgwire::api::{ClientInfo, PgWireServerHandlers, Type};
20+
use pgwire::error::{PgWireError, PgWireResult};
21+
use pgwire::tokio::process_socket;
22+
use tokio::net::TcpListener;
23+
24+
pub struct DuckDBBackend {
25+
conn: Arc<Mutex<Connection>>,
26+
query_parser: Arc<NoopQueryParser>,
27+
}
28+
29+
/// Demo-only auth source: its `get_password` answers every login with the same fixed password.
struct DummyAuthSource;
30+
31+
#[async_trait]
32+
impl AuthSource for DummyAuthSource {
33+
async fn get_password(&self, login_info: &LoginInfo) -> PgWireResult<Password> {
34+
println!("login info: {:?}", login_info);
35+
36+
let salt = vec![0, 0, 0, 0];
37+
let password = "pencil";
38+
39+
let hash_password =
40+
hash_md5_password(login_info.user().as_ref().unwrap(), password, salt.as_ref());
41+
Ok(Password::new(Some(salt), hash_password.as_bytes().to_vec()))
42+
}
43+
}
44+
45+
#[async_trait]
46+
impl SimpleQueryHandler for DuckDBBackend {
47+
async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
48+
where
49+
C: ClientInfo + Unpin + Send + Sync,
50+
{
51+
let conn = self.conn.lock().unwrap();
52+
if query.to_uppercase().starts_with("SELECT") {
53+
let mut stmt = conn
54+
.prepare(query)
55+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
56+
57+
let ret = stmt
58+
.query_arrow(params![])
59+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
60+
let schema = ret.get_schema();
61+
let header = Arc::new(arrow_schema_to_pg_fields(
62+
schema.as_ref(),
63+
&Format::UnifiedText,
64+
)?);
65+
66+
let header_ref = header.clone();
67+
let data = ret
68+
.flat_map(move |rb| encode_recordbatch(header_ref.clone(), rb))
69+
.collect::<Vec<_>>();
70+
Ok(vec![Response::Query(QueryResponse::new(
71+
header,
72+
stream::iter(data.into_iter()),
73+
))])
74+
} else {
75+
conn.execute(query, params![])
76+
.map(|affected_rows| {
77+
vec![Response::Execution(Tag::new("OK").with_rows(affected_rows))]
78+
})
79+
.map_err(|e| PgWireError::ApiError(Box::new(e)))
80+
}
81+
}
82+
}
83+
84+
fn get_params(portal: &Portal<String>) -> Vec<Box<dyn ToSql>> {
85+
let mut results = Vec::with_capacity(portal.parameter_len());
86+
for i in 0..portal.parameter_len() {
87+
let param_type = portal.statement.parameter_types.get(i).unwrap();
88+
// we only support a small amount of types for demo
89+
match param_type {
90+
&Type::BOOL => {
91+
let param = portal.parameter::<bool>(i, param_type).unwrap();
92+
results.push(Box::new(param) as Box<dyn ToSql>);
93+
}
94+
&Type::INT2 => {
95+
let param = portal.parameter::<i16>(i, param_type).unwrap();
96+
results.push(Box::new(param) as Box<dyn ToSql>);
97+
}
98+
&Type::INT4 => {
99+
let param = portal.parameter::<i32>(i, param_type).unwrap();
100+
results.push(Box::new(param) as Box<dyn ToSql>);
101+
}
102+
&Type::INT8 => {
103+
let param = portal.parameter::<i64>(i, param_type).unwrap();
104+
results.push(Box::new(param) as Box<dyn ToSql>);
105+
}
106+
&Type::TEXT | &Type::VARCHAR => {
107+
let param = portal.parameter::<String>(i, param_type).unwrap();
108+
results.push(Box::new(param) as Box<dyn ToSql>);
109+
}
110+
&Type::FLOAT4 => {
111+
let param = portal.parameter::<f32>(i, param_type).unwrap();
112+
results.push(Box::new(param) as Box<dyn ToSql>);
113+
}
114+
&Type::FLOAT8 => {
115+
let param = portal.parameter::<f64>(i, param_type).unwrap();
116+
results.push(Box::new(param) as Box<dyn ToSql>);
117+
}
118+
_ => {
119+
unimplemented!("parameter type not supported")
120+
}
121+
}
122+
}
123+
124+
results
125+
}
126+
127+
fn row_desc_from_stmt(stmt: &Statement, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
128+
let columns = stmt.column_count();
129+
130+
(0..columns)
131+
.map(|idx| {
132+
let datatype = stmt.column_type(idx);
133+
let name = stmt.column_name(idx).unwrap();
134+
135+
Ok(FieldInfo::new(
136+
name.clone(),
137+
None,
138+
None,
139+
into_pg_type(&datatype).unwrap(),
140+
format.format_for(idx),
141+
))
142+
})
143+
.collect()
144+
}
145+
146+
#[async_trait]
147+
impl ExtendedQueryHandler for DuckDBBackend {
148+
type Statement = String;
149+
type QueryParser = NoopQueryParser;
150+
151+
fn query_parser(&self) -> Arc<Self::QueryParser> {
152+
self.query_parser.clone()
153+
}
154+
155+
async fn do_query<'a, C>(
156+
&self,
157+
_client: &mut C,
158+
portal: &Portal<Self::Statement>,
159+
_max_rows: usize,
160+
) -> PgWireResult<Response<'a>>
161+
where
162+
C: ClientInfo + Unpin + Send + Sync,
163+
{
164+
let conn = self.conn.lock().unwrap();
165+
let query = &portal.statement.statement;
166+
let mut stmt = conn
167+
.prepare_cached(query)
168+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
169+
let params = get_params(portal);
170+
let params_ref = params
171+
.iter()
172+
.map(|f| f.as_ref())
173+
.collect::<Vec<&dyn duckdb::ToSql>>();
174+
175+
if query.to_uppercase().starts_with("SELECT") {
176+
let ret = stmt
177+
.query_arrow(params![])
178+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
179+
let schema = ret.get_schema();
180+
let header = Arc::new(arrow_schema_to_pg_fields(
181+
schema.as_ref(),
182+
&Format::UnifiedText,
183+
)?);
184+
185+
let header_ref = header.clone();
186+
let data = ret
187+
.flat_map(move |rb| encode_recordbatch(header_ref.clone(), rb))
188+
.collect::<Vec<_>>();
189+
190+
Ok(Response::Query(QueryResponse::new(
191+
header,
192+
stream::iter(data.into_iter()),
193+
)))
194+
} else {
195+
stmt.execute::<&[&dyn duckdb::ToSql]>(params_ref.as_ref())
196+
.map(|affected_rows| Response::Execution(Tag::new("OK").with_rows(affected_rows)))
197+
.map_err(|e| PgWireError::ApiError(Box::new(e)))
198+
}
199+
}
200+
201+
async fn do_describe_statement<C>(
202+
&self,
203+
_client: &mut C,
204+
stmt: &StoredStatement<Self::Statement>,
205+
) -> PgWireResult<DescribeStatementResponse>
206+
where
207+
C: ClientInfo + Unpin + Send + Sync,
208+
{
209+
let conn = self.conn.lock().unwrap();
210+
let param_types = stmt.parameter_types.clone();
211+
let stmt = conn
212+
.prepare_cached(&stmt.statement)
213+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
214+
row_desc_from_stmt(&stmt, &Format::UnifiedBinary)
215+
.map(|fields| DescribeStatementResponse::new(param_types, fields))
216+
}
217+
218+
async fn do_describe_portal<C>(
219+
&self,
220+
_client: &mut C,
221+
portal: &Portal<Self::Statement>,
222+
) -> PgWireResult<DescribePortalResponse>
223+
where
224+
C: ClientInfo + Unpin + Send + Sync,
225+
{
226+
let conn = self.conn.lock().unwrap();
227+
let stmt = conn
228+
.prepare_cached(&portal.statement.statement)
229+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
230+
row_desc_from_stmt(&stmt, &portal.result_column_format).map(DescribePortalResponse::new)
231+
}
232+
}
233+
234+
impl DuckDBBackend {
235+
fn new() -> DuckDBBackend {
236+
DuckDBBackend {
237+
conn: Arc::new(Mutex::new(Connection::open_in_memory().unwrap())),
238+
query_parser: Arc::new(NoopQueryParser::new()),
239+
}
240+
}
241+
}
242+
243+
struct DuckDBBackendFactory {
244+
handler: Arc<DuckDBBackend>,
245+
}
246+
247+
impl PgWireServerHandlers for DuckDBBackendFactory {
248+
fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
249+
self.handler.clone()
250+
}
251+
252+
fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
253+
self.handler.clone()
254+
}
255+
256+
fn startup_handler(&self) -> Arc<impl StartupHandler> {
257+
Arc::new(Md5PasswordAuthStartupHandler::new(
258+
Arc::new(DummyAuthSource),
259+
Arc::new(DefaultServerParameterProvider::default()),
260+
))
261+
}
262+
}
263+
264+
#[tokio::main]
265+
pub async fn main() {
266+
let factory = Arc::new(DuckDBBackendFactory {
267+
handler: Arc::new(DuckDBBackend::new()),
268+
});
269+
let server_addr = "127.0.0.1:5432";
270+
let listener = TcpListener::bind(server_addr).await.unwrap();
271+
println!(
272+
"Listening to {}, use password `pencil` to connect",
273+
server_addr
274+
);
275+
loop {
276+
let incoming_socket = listener.accept().await.unwrap();
277+
let factory_ref = factory.clone();
278+
279+
tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await });
280+
}
281+
}

datafusion-pg-catalog/Cargo.toml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
[package]
2+
name = "datafusion-pg-catalog"
3+
description = "pg_catalog compatibility for datafusion"
4+
version.workspace = true
5+
edition.workspace = true
6+
license.workspace = true
7+
authors.workspace = true
8+
keywords.workspace = true
9+
homepage.workspace = true
10+
repository.workspace = true
11+
documentation.workspace = true
12+
readme = "../README.md"
13+
rust-version.workspace = true
14+
include = [
15+
"src/**/*",
16+
"pg_catalog_arrow_exports/**/*",
17+
"Cargo.toml"
18+
]
19+
20+
[dependencies]
21+
async-trait = "0.1"
22+
datafusion.workspace = true
23+
futures.workspace = true
24+
log = "0.4"
25+
postgres-types.workspace = true
26+
tokio = { version = "1.47", features = ["sync"] }
27+
28+
[dev-dependencies]
29+
env_logger = "0.11"

0 commit comments

Comments
 (0)