Skip to content

Commit b407f40

Browse files
committed
Merge branch 'master' into feat/geo
2 parents aabd489 + 125bba1 commit b407f40

File tree

145 files changed

+2581
-1279
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

145 files changed

+2581
-1279
lines changed

Cargo.lock

Lines changed: 511 additions & 406 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
[workspace]
22
resolver = "2"
3-
members = ["datafusion-postgres", "datafusion-postgres-cli", "arrow-pg"]
3+
members = ["datafusion-postgres", "datafusion-postgres-cli", "arrow-pg", "datafusion-pg-catalog"]
44

55
[workspace.package]
6-
version = "0.10.2"
6+
version = "0.12.0"
77
edition = "2021"
88
license = "Apache-2.0"
99
rust-version = "1.86.0"
@@ -19,9 +19,9 @@ bytes = "1.10.1"
1919
chrono = { version = "0.4", features = ["std"] }
2020
datafusion = { version = "50", default-features = false }
2121
futures = "0.3"
22-
pgwire = { version = "0.32", default-features = false }
22+
pgwire = { version = "0.34", default-features = false }
2323
postgres-types = "0.2"
24-
rust_decimal = { version = "1.38", features = ["db-postgres"] }
24+
rust_decimal = { version = "1.39", features = ["db-postgres"] }
2525
tokio = { version = "1", default-features = false }
2626

2727
[profile.release]

arrow-pg/Cargo.toml

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
name = "arrow-pg"
33
description = "Arrow data mapping and encoding/decoding for Postgres"
4-
version = "0.6.0"
4+
version = "0.8.0"
55
edition.workspace = true
66
license.workspace = true
77
authors.workspace = true
@@ -16,13 +16,26 @@ rust-version.workspace = true
1616
default = ["arrow"]
1717
arrow = ["dep:arrow"]
1818
datafusion = ["dep:datafusion"]
19+
# for testing
20+
_duckdb = []
21+
_bundled = ["duckdb/bundled"]
22+
1923

2024
[dependencies]
2125
arrow = { workspace = true, optional = true }
2226
bytes.workspace = true
2327
chrono.workspace = true
2428
datafusion = { workspace = true, optional = true }
2529
futures.workspace = true
26-
pgwire = { version = ">=0.32", default-features = false, features = ["server-api"] }
30+
pgwire = { workspace = true, default-features = false, features = ["server-api", "pg-ext-types"] }
2731
postgres-types.workspace = true
2832
rust_decimal.workspace = true
33+
34+
[dev-dependencies]
35+
async-trait = "0.1"
36+
duckdb = { version = "~1.4" }
37+
tokio = { version = "1.48", features = ["full"]}
38+
39+
[[example]]
40+
name = "duckdb"
41+
required-features = ["_duckdb"]

arrow-pg/examples/duckdb.rs

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
use std::sync::{Arc, Mutex};
2+
3+
use arrow_pg::datatypes::arrow_schema_to_pg_fields;
4+
use arrow_pg::datatypes::encode_recordbatch;
5+
use arrow_pg::datatypes::into_pg_type;
6+
use async_trait::async_trait;
7+
use duckdb::{params, Connection, Statement, ToSql};
8+
use futures::stream;
9+
use pgwire::api::auth::md5pass::{hash_md5_password, Md5PasswordAuthStartupHandler};
10+
use pgwire::api::auth::{
11+
AuthSource, DefaultServerParameterProvider, LoginInfo, Password, StartupHandler,
12+
};
13+
use pgwire::api::portal::{Format, Portal};
14+
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
15+
use pgwire::api::results::{
16+
DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag,
17+
};
18+
use pgwire::api::stmt::{NoopQueryParser, StoredStatement};
19+
use pgwire::api::{ClientInfo, PgWireServerHandlers, Type};
20+
use pgwire::error::{PgWireError, PgWireResult};
21+
use pgwire::tokio::process_socket;
22+
use tokio::net::TcpListener;
23+
24+
pub struct DuckDBBackend {
25+
conn: Arc<Mutex<Connection>>,
26+
query_parser: Arc<NoopQueryParser>,
27+
}
28+
29+
/// Demo auth source: accepts any user name with the fixed password "pencil".
#[derive(Debug)]
struct DummyAuthSource;
31+
32+
#[async_trait]
33+
impl AuthSource for DummyAuthSource {
34+
async fn get_password(&self, login_info: &LoginInfo) -> PgWireResult<Password> {
35+
println!("login info: {:?}", login_info);
36+
37+
let salt = vec![0, 0, 0, 0];
38+
let password = "pencil";
39+
40+
let hash_password =
41+
hash_md5_password(login_info.user().as_ref().unwrap(), password, salt.as_ref());
42+
Ok(Password::new(Some(salt), hash_password.as_bytes().to_vec()))
43+
}
44+
}
45+
46+
#[async_trait]
47+
impl SimpleQueryHandler for DuckDBBackend {
48+
async fn do_query<C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
49+
where
50+
C: ClientInfo + Unpin + Send + Sync,
51+
{
52+
let conn = self.conn.lock().unwrap();
53+
if query.to_uppercase().starts_with("SELECT") {
54+
let mut stmt = conn
55+
.prepare(query)
56+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
57+
58+
let ret = stmt
59+
.query_arrow(params![])
60+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
61+
let schema = ret.get_schema();
62+
let header = Arc::new(arrow_schema_to_pg_fields(
63+
schema.as_ref(),
64+
&Format::UnifiedText,
65+
)?);
66+
67+
let header_ref = header.clone();
68+
let data = ret
69+
.flat_map(move |rb| encode_recordbatch(header_ref.clone(), rb))
70+
.collect::<Vec<_>>();
71+
Ok(vec![Response::Query(QueryResponse::new(
72+
header,
73+
stream::iter(data.into_iter()),
74+
))])
75+
} else {
76+
conn.execute(query, params![])
77+
.map(|affected_rows| {
78+
vec![Response::Execution(Tag::new("OK").with_rows(affected_rows))]
79+
})
80+
.map_err(|e| PgWireError::ApiError(Box::new(e)))
81+
}
82+
}
83+
}
84+
85+
fn get_params(portal: &Portal<String>) -> Vec<Box<dyn ToSql>> {
86+
let mut results = Vec::with_capacity(portal.parameter_len());
87+
for i in 0..portal.parameter_len() {
88+
let param_type = portal.statement.parameter_types.get(i).unwrap();
89+
// we only support a small amount of types for demo
90+
match param_type {
91+
&Type::BOOL => {
92+
let param = portal.parameter::<bool>(i, param_type).unwrap();
93+
results.push(Box::new(param) as Box<dyn ToSql>);
94+
}
95+
&Type::INT2 => {
96+
let param = portal.parameter::<i16>(i, param_type).unwrap();
97+
results.push(Box::new(param) as Box<dyn ToSql>);
98+
}
99+
&Type::INT4 => {
100+
let param = portal.parameter::<i32>(i, param_type).unwrap();
101+
results.push(Box::new(param) as Box<dyn ToSql>);
102+
}
103+
&Type::INT8 => {
104+
let param = portal.parameter::<i64>(i, param_type).unwrap();
105+
results.push(Box::new(param) as Box<dyn ToSql>);
106+
}
107+
&Type::TEXT | &Type::VARCHAR => {
108+
let param = portal.parameter::<String>(i, param_type).unwrap();
109+
results.push(Box::new(param) as Box<dyn ToSql>);
110+
}
111+
&Type::FLOAT4 => {
112+
let param = portal.parameter::<f32>(i, param_type).unwrap();
113+
results.push(Box::new(param) as Box<dyn ToSql>);
114+
}
115+
&Type::FLOAT8 => {
116+
let param = portal.parameter::<f64>(i, param_type).unwrap();
117+
results.push(Box::new(param) as Box<dyn ToSql>);
118+
}
119+
_ => {
120+
unimplemented!("parameter type not supported")
121+
}
122+
}
123+
}
124+
125+
results
126+
}
127+
128+
fn row_desc_from_stmt(stmt: &Statement, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
129+
let columns = stmt.column_count();
130+
131+
(0..columns)
132+
.map(|idx| {
133+
let datatype = stmt.column_type(idx);
134+
let name = stmt.column_name(idx).unwrap();
135+
136+
Ok(FieldInfo::new(
137+
name.clone(),
138+
None,
139+
None,
140+
into_pg_type(&datatype).unwrap(),
141+
format.format_for(idx),
142+
))
143+
})
144+
.collect()
145+
}
146+
147+
#[async_trait]
148+
impl ExtendedQueryHandler for DuckDBBackend {
149+
type Statement = String;
150+
type QueryParser = NoopQueryParser;
151+
152+
fn query_parser(&self) -> Arc<Self::QueryParser> {
153+
self.query_parser.clone()
154+
}
155+
156+
async fn do_query<C>(
157+
&self,
158+
_client: &mut C,
159+
portal: &Portal<Self::Statement>,
160+
_max_rows: usize,
161+
) -> PgWireResult<Response>
162+
where
163+
C: ClientInfo + Unpin + Send + Sync,
164+
{
165+
let conn = self.conn.lock().unwrap();
166+
let query = &portal.statement.statement;
167+
let mut stmt = conn
168+
.prepare_cached(query)
169+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
170+
let params = get_params(portal);
171+
let params_ref = params
172+
.iter()
173+
.map(|f| f.as_ref())
174+
.collect::<Vec<&dyn duckdb::ToSql>>();
175+
176+
if query.to_uppercase().starts_with("SELECT") {
177+
let ret = stmt
178+
.query_arrow(params![])
179+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
180+
let schema = ret.get_schema();
181+
let header = Arc::new(arrow_schema_to_pg_fields(
182+
schema.as_ref(),
183+
&Format::UnifiedText,
184+
)?);
185+
186+
let header_ref = header.clone();
187+
let data = ret
188+
.flat_map(move |rb| encode_recordbatch(header_ref.clone(), rb))
189+
.collect::<Vec<_>>();
190+
191+
Ok(Response::Query(QueryResponse::new(
192+
header,
193+
stream::iter(data.into_iter()),
194+
)))
195+
} else {
196+
stmt.execute::<&[&dyn duckdb::ToSql]>(params_ref.as_ref())
197+
.map(|affected_rows| Response::Execution(Tag::new("OK").with_rows(affected_rows)))
198+
.map_err(|e| PgWireError::ApiError(Box::new(e)))
199+
}
200+
}
201+
202+
async fn do_describe_statement<C>(
203+
&self,
204+
_client: &mut C,
205+
stmt: &StoredStatement<Self::Statement>,
206+
) -> PgWireResult<DescribeStatementResponse>
207+
where
208+
C: ClientInfo + Unpin + Send + Sync,
209+
{
210+
let conn = self.conn.lock().unwrap();
211+
let param_types = stmt.parameter_types.clone();
212+
let stmt = conn
213+
.prepare_cached(&stmt.statement)
214+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
215+
row_desc_from_stmt(&stmt, &Format::UnifiedBinary)
216+
.map(|fields| DescribeStatementResponse::new(param_types, fields))
217+
}
218+
219+
async fn do_describe_portal<C>(
220+
&self,
221+
_client: &mut C,
222+
portal: &Portal<Self::Statement>,
223+
) -> PgWireResult<DescribePortalResponse>
224+
where
225+
C: ClientInfo + Unpin + Send + Sync,
226+
{
227+
let conn = self.conn.lock().unwrap();
228+
let stmt = conn
229+
.prepare_cached(&portal.statement.statement)
230+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
231+
row_desc_from_stmt(&stmt, &portal.result_column_format).map(DescribePortalResponse::new)
232+
}
233+
}
234+
235+
impl DuckDBBackend {
236+
fn new() -> DuckDBBackend {
237+
DuckDBBackend {
238+
conn: Arc::new(Mutex::new(Connection::open_in_memory().unwrap())),
239+
query_parser: Arc::new(NoopQueryParser::new()),
240+
}
241+
}
242+
}
243+
244+
struct DuckDBBackendFactory {
245+
handler: Arc<DuckDBBackend>,
246+
}
247+
248+
impl PgWireServerHandlers for DuckDBBackendFactory {
249+
fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
250+
self.handler.clone()
251+
}
252+
253+
fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
254+
self.handler.clone()
255+
}
256+
257+
fn startup_handler(&self) -> Arc<impl StartupHandler> {
258+
Arc::new(Md5PasswordAuthStartupHandler::new(
259+
Arc::new(DummyAuthSource),
260+
Arc::new(DefaultServerParameterProvider::default()),
261+
))
262+
}
263+
}
264+
265+
#[tokio::main]
266+
pub async fn main() {
267+
let factory = Arc::new(DuckDBBackendFactory {
268+
handler: Arc::new(DuckDBBackend::new()),
269+
});
270+
let server_addr = "127.0.0.1:5432";
271+
let listener = TcpListener::bind(server_addr).await.unwrap();
272+
println!(
273+
"Listening to {}, use password `pencil` to connect",
274+
server_addr
275+
);
276+
loop {
277+
let incoming_socket = listener.accept().await.unwrap();
278+
let factory_ref = factory.clone();
279+
280+
tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await });
281+
}
282+
}

arrow-pg/src/datatypes/df.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@ use rust_decimal::Decimal;
1818

1919
use super::{arrow_schema_to_pg_fields, encode_recordbatch, into_pg_type};
2020

21-
pub async fn encode_dataframe<'a>(
22-
df: DataFrame,
23-
format: &Format,
24-
) -> PgWireResult<QueryResponse<'a>> {
21+
pub async fn encode_dataframe(df: DataFrame, format: &Format) -> PgWireResult<QueryResponse> {
2522
let fields = Arc::new(arrow_schema_to_pg_fields(df.schema().as_arrow(), format)?);
2623

2724
let recordbatch_stream = df

datafusion-pg-catalog/Cargo.toml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
[package]
2+
name = "datafusion-pg-catalog"
3+
description = "pg_catalog compatibility for datafusion"
4+
version.workspace = true
5+
edition.workspace = true
6+
license.workspace = true
7+
authors.workspace = true
8+
keywords.workspace = true
9+
homepage.workspace = true
10+
repository.workspace = true
11+
documentation.workspace = true
12+
readme = "../README.md"
13+
rust-version.workspace = true
14+
include = [
15+
"src/**/*",
16+
"pg_catalog_arrow_exports/**/*",
17+
"Cargo.toml"
18+
]
19+
20+
[dependencies]
21+
async-trait = "0.1"
22+
datafusion.workspace = true
23+
futures.workspace = true
24+
log = "0.4"
25+
postgres-types.workspace = true
26+
tokio = { version = "1.48", features = ["sync"] }
27+
28+
[dev-dependencies]
29+
env_logger = "0.11"

0 commit comments

Comments
 (0)