diff --git a/Cargo.lock b/Cargo.lock index 210502d..8e22539 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -345,13 +345,16 @@ name = "arrow-pg" version = "0.6.1" dependencies = [ "arrow", + "async-trait", "bytes", "chrono", "datafusion", + "duckdb", "futures", "pgwire", "postgres-types", "rust_decimal", + "tokio", ] [[package]] @@ -373,6 +376,7 @@ version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3aa9e59c611ebc291c28582077ef25c97f1975383f1479b12f3b9ffee2ffabe" dependencies = [ + "bitflags 2.9.4", "serde", "serde_json", ] @@ -694,6 +698,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.2.38" @@ -1664,6 +1674,23 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "duckdb" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2207e2cc81719eca29ef65cf904df49505cb0346c460ba6272c99cc5d7eb655" +dependencies = [ + "arrow", + "cast", + "fallible-iterator 0.3.0", + "fallible-streaming-iterator", + "hashlink", + "libduckdb-sys", + "num-integer", + "rust_decimal", + "strum 0.27.2", +] + [[package]] name = "either" version = "1.15.0" @@ -1715,12 +1742,36 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "filetime" +version = "0.2.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc0505cd1b6fa6580283f6bdf70a73fcf4aba1184038c90902b92b3dd0df63ed" +dependencies = [ + "cfg-if", + "libc", + "libredox", + "windows-sys 0.60.2", +] + [[package]] name = "find-msvc-tools" version = "0.1.2" @@ -1976,6 +2027,15 @@ version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" +[[package]] +name = "hashlink" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" +dependencies = [ + "hashbrown 0.15.5", +] + [[package]] name = "heck" version = "0.3.3" @@ -2350,12 +2410,38 @@ version = "0.2.176" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "58f929b4d672ea937a23a1ab494143d968337a5f47e56d0815df1e0890ddf174" +[[package]] +name = "libduckdb-sys" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2d70696fa43ca9ed04ab7d8bfd4f3f1df73cabd3c872d67ddc9f55bb8dd4d1a" +dependencies = [ + "cc", + "flate2", + "pkg-config", + "serde", + "serde_json", + "tar", + "vcpkg", +] + [[package]] name = "libm" version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +[[package]] +name = "libredox" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb" +dependencies = [ + "bitflags 2.9.4", + "libc", + "redox_syscall", +] + [[package]] name = "libz-rs-sys" version = "0.5.2" @@ -2766,7 +2852,7 @@ dependencies = [ "base64", "byteorder", "bytes", - "fallible-iterator", + "fallible-iterator 0.2.0", "hmac", "md-5", "memchr", @@ -2784,7 +2870,7 @@ dependencies = [ "array-init", "bytes", "chrono", - "fallible-iterator", + "fallible-iterator 0.2.0", "postgres-protocol", ] @@ -3476,6 +3562,9 @@ name = "strum" version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +dependencies = [ + "strum_macros 0.27.2", +] [[package]] name = "strum_macros" @@ -3547,6 +3636,17 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tar" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a" +dependencies = [ + "filetime", + "libc", + "xattr", +] + [[package]] name = "tempfile" version = "3.23.0" @@ -3868,6 +3968,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" @@ -4288,6 +4394,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "xattr" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" +dependencies = [ + "libc", + "rustix", +] + [[package]] name = "xz2" version = "0.1.7" diff --git a/arrow-pg/Cargo.toml b/arrow-pg/Cargo.toml index 4499e18..6718106 100644 --- a/arrow-pg/Cargo.toml +++ b/arrow-pg/Cargo.toml @@ -16,6 +16,10 @@ rust-version.workspace = true default = ["arrow"] arrow = ["dep:arrow"] datafusion = ["dep:datafusion"] +# for testing +_duckdb = [] +_bundled = ["duckdb/bundled"] + [dependencies] arrow = { workspace = true, optional = true } @@ -23,6 +27,15 @@ bytes.workspace = true chrono.workspace = true datafusion = { workspace = true, optional = true } futures.workspace = true -pgwire = { version = ">=0.32, <0.33", default-features = false, features = ["server-api"] } +pgwire = { workspace = true, default-features = false, features = ["server-api"] } postgres-types.workspace = true rust_decimal.workspace = true + +[dev-dependencies] +async-trait = "0.1" +duckdb = { version = "~1.4" } +tokio = { version = "1.19", features = ["full"]} + +[[example]] +name = "duckdb" +required-features = ["_duckdb"] diff --git a/arrow-pg/examples/duckdb.rs b/arrow-pg/examples/duckdb.rs new file mode 100644 index 0000000..5fe4892 --- /dev/null +++ b/arrow-pg/examples/duckdb.rs @@ -0,0 +1,281 @@ +use std::sync::{Arc, Mutex}; + +use arrow_pg::datatypes::arrow_schema_to_pg_fields; +use arrow_pg::datatypes::encode_recordbatch; +use arrow_pg::datatypes::into_pg_type; +use async_trait::async_trait; +use duckdb::{params, Connection, Statement, ToSql}; +use futures::stream; +use pgwire::api::auth::md5pass::{hash_md5_password, Md5PasswordAuthStartupHandler}; +use pgwire::api::auth::{ + AuthSource, DefaultServerParameterProvider, LoginInfo, Password, StartupHandler, +}; +use pgwire::api::portal::{Format, Portal}; +use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; +use pgwire::api::results::{ + DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag, +}; +use pgwire::api::stmt::{NoopQueryParser, StoredStatement}; +use pgwire::api::{ClientInfo, PgWireServerHandlers, Type}; +use pgwire::error::{PgWireError, PgWireResult}; +use pgwire::tokio::process_socket; +use tokio::net::TcpListener; + +pub struct DuckDBBackend { + conn: Arc>, + query_parser: Arc, +} + +struct DummyAuthSource; + +#[async_trait] +impl AuthSource for DummyAuthSource { + async fn get_password(&self, login_info: &LoginInfo) -> PgWireResult { + println!("login info: {:?}", login_info); + + let salt = vec![0, 0, 0, 0]; + let password = "pencil"; + + let hash_password = + hash_md5_password(login_info.user().as_ref().unwrap(), password, salt.as_ref()); + Ok(Password::new(Some(salt), hash_password.as_bytes().to_vec())) + } +} + +#[async_trait] +impl SimpleQueryHandler for DuckDBBackend { + async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult>> + where + C: ClientInfo + Unpin + Send + Sync, + { + let conn = self.conn.lock().unwrap(); + if query.to_uppercase().starts_with("SELECT") { + let mut stmt = conn + .prepare(query) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + let ret = stmt + .query_arrow(params![]) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let schema = ret.get_schema(); + let header = Arc::new(arrow_schema_to_pg_fields( + schema.as_ref(), + &Format::UnifiedText, + )?); + + let header_ref = header.clone(); + let data = ret + .flat_map(move |rb| encode_recordbatch(header_ref.clone(), rb)) + .collect::>(); + Ok(vec![Response::Query(QueryResponse::new( + header, + stream::iter(data.into_iter()), + ))]) + } else { + conn.execute(query, params![]) + .map(|affected_rows| { + vec![Response::Execution(Tag::new("OK").with_rows(affected_rows))] + }) + .map_err(|e| PgWireError::ApiError(Box::new(e))) + } + } +} + +fn get_params(portal: &Portal) -> Vec> { + let mut results = Vec::with_capacity(portal.parameter_len()); + for i in 0..portal.parameter_len() { + let param_type = portal.statement.parameter_types.get(i).unwrap(); + // we only support a small amount of types for demo + match param_type { + &Type::BOOL => { + let param = portal.parameter::(i, param_type).unwrap(); + results.push(Box::new(param) as Box); + } + &Type::INT2 => { + let param = portal.parameter::(i, param_type).unwrap(); + results.push(Box::new(param) as Box); + } + &Type::INT4 => { + let param = portal.parameter::(i, param_type).unwrap(); + results.push(Box::new(param) as Box); + } + &Type::INT8 => { + let param = portal.parameter::(i, param_type).unwrap(); + results.push(Box::new(param) as Box); + } + &Type::TEXT | &Type::VARCHAR => { + let param = portal.parameter::(i, param_type).unwrap(); + results.push(Box::new(param) as Box); + } + &Type::FLOAT4 => { + let param = portal.parameter::(i, param_type).unwrap(); + results.push(Box::new(param) as Box); + } + &Type::FLOAT8 => { + let param = portal.parameter::(i, param_type).unwrap(); + results.push(Box::new(param) as Box); + } + _ => { + unimplemented!("parameter type not supported") + } + } + } + + results +} + +fn row_desc_from_stmt(stmt: &Statement, format: &Format) -> PgWireResult> { + let columns = stmt.column_count(); + + (0..columns) + .map(|idx| { + let datatype = stmt.column_type(idx); + let name = stmt.column_name(idx).unwrap(); + + Ok(FieldInfo::new( + name.clone(), + None, + None, + into_pg_type(&datatype).unwrap(), + format.format_for(idx), + )) + }) + .collect() +} + +#[async_trait] +impl ExtendedQueryHandler for DuckDBBackend { + type Statement = String; + type QueryParser = NoopQueryParser; + + fn query_parser(&self) -> Arc { + self.query_parser.clone() + } + + async fn do_query<'a, C>( + &self, + _client: &mut C, + portal: &Portal, + _max_rows: usize, + ) -> PgWireResult> + where + C: ClientInfo + Unpin + Send + Sync, + { + let conn = self.conn.lock().unwrap(); + let query = &portal.statement.statement; + let mut stmt = conn + .prepare_cached(query) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let params = get_params(portal); + let params_ref = params + .iter() + .map(|f| f.as_ref()) + .collect::>(); + + if query.to_uppercase().starts_with("SELECT") { + let ret = stmt + .query_arrow(params![]) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let schema = ret.get_schema(); + let header = Arc::new(arrow_schema_to_pg_fields( + schema.as_ref(), + &Format::UnifiedText, + )?); + + let header_ref = header.clone(); + let data = ret + .flat_map(move |rb| encode_recordbatch(header_ref.clone(), rb)) + .collect::>(); + + Ok(Response::Query(QueryResponse::new( + header, + stream::iter(data.into_iter()), + ))) + } else { + stmt.execute::<&[&dyn duckdb::ToSql]>(params_ref.as_ref()) + .map(|affected_rows| Response::Execution(Tag::new("OK").with_rows(affected_rows))) + .map_err(|e| PgWireError::ApiError(Box::new(e))) + } + } + + async fn do_describe_statement( + &self, + _client: &mut C, + stmt: &StoredStatement, + ) -> PgWireResult + where + C: ClientInfo + Unpin + Send + Sync, + { + let conn = self.conn.lock().unwrap(); + let param_types = stmt.parameter_types.clone(); + let stmt = conn + .prepare_cached(&stmt.statement) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + row_desc_from_stmt(&stmt, &Format::UnifiedBinary) + .map(|fields| DescribeStatementResponse::new(param_types, fields)) + } + + async fn do_describe_portal( + &self, + _client: &mut C, + portal: &Portal, + ) -> PgWireResult + where + C: ClientInfo + Unpin + Send + Sync, + { + let conn = self.conn.lock().unwrap(); + let stmt = conn + .prepare_cached(&portal.statement.statement) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + row_desc_from_stmt(&stmt, &portal.result_column_format).map(DescribePortalResponse::new) + } +} + +impl DuckDBBackend { + fn new() -> DuckDBBackend { + DuckDBBackend { + conn: Arc::new(Mutex::new(Connection::open_in_memory().unwrap())), + query_parser: Arc::new(NoopQueryParser::new()), + } + } +} + +struct DuckDBBackendFactory { + handler: Arc, +} + +impl PgWireServerHandlers for DuckDBBackendFactory { + fn simple_query_handler(&self) -> Arc { + self.handler.clone() + } + + fn extended_query_handler(&self) -> Arc { + self.handler.clone() + } + + fn startup_handler(&self) -> Arc { + Arc::new(Md5PasswordAuthStartupHandler::new( + Arc::new(DummyAuthSource), + Arc::new(DefaultServerParameterProvider::default()), + )) + } +} + +#[tokio::main] +pub async fn main() { + let factory = Arc::new(DuckDBBackendFactory { + handler: Arc::new(DuckDBBackend::new()), + }); + let server_addr = "127.0.0.1:5432"; + let listener = TcpListener::bind(server_addr).await.unwrap(); + println!( + "Listening to {}, use password `pencil` to connect", + server_addr + ); + loop { + let incoming_socket = listener.accept().await.unwrap(); + let factory_ref = factory.clone(); + + tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await }); + } +} diff --git a/flake.nix b/flake.nix index 0136d8b..3b6f26e 100644 --- a/flake.nix +++ b/flake.nix @@ -22,6 +22,8 @@ buildInputs = with pkgs; [ llvmPackages.libclang libpq + duckdb.dev + duckdb.lib ]; in { @@ -50,6 +52,8 @@ shellHook = '' export CC=clang export CXX=clang++ + export DUCKDB_LIB_DIR="${pkgs.duckdb.lib}/lib" + export DUCKDB_INCLUDE_DIR="${pkgs.duckdb.dev}/include" ''; }; });