diff --git a/Cargo.lock b/Cargo.lock index bf7d015..f1cc3bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -145,12 +145,21 @@ dependencies = [ "snap", "strum 0.27.2", "strum_macros 0.27.2", - "thiserror", + "thiserror 2.0.17", "uuid", "xz2", "zstd", ] +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "array-init" version = "2.1.0" @@ -336,12 +345,14 @@ name = "arrow-pg" version = "0.8.1" dependencies = [ "arrow", + "arrow-schema", "async-trait", "bytes", "chrono", "datafusion", "duckdb", "futures", + "geoarrow-schema", "pgwire", "postgres-types", "rust_decimal", @@ -1921,6 +1932,39 @@ dependencies = [ "version_check", ] +[[package]] +name = "geo-traits" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e7c353d12a704ccfab1ba8bfb1a7fe6cb18b665bf89d37f4f7890edcd260206" +dependencies = [ + "geo-types", +] + +[[package]] +name = "geo-types" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75a4dcd69d35b2c87a7c83bce9af69fd65c9d68d3833a0ded568983928f3fc99" +dependencies = [ + "approx", + "num-traits", + "serde", +] + +[[package]] +name = "geoarrow-schema" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02f1b18b1c9a44ecd72be02e53d6e63bbccfdc8d1765206226af227327e2be6e" +dependencies = [ + "arrow-schema", + "geo-traits", + "serde", + "serde_json", + "thiserror 1.0.69", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -2602,7 +2646,7 @@ dependencies = [ "itertools", "parking_lot", "percent-encoding", - "thiserror", + "thiserror 2.0.17", "tokio", "tracing", "url", @@ -2749,7 +2793,7 @@ dependencies = [ "serde", "serde_json", "stringprep", - "thiserror", + "thiserror 2.0.17", "tokio", "tokio-rustls", "tokio-util", @@ -2835,6 +2879,7 @@ dependencies = [ "bytes", "chrono", "fallible-iterator 0.2.0", + "geo-types", "postgres-protocol", "serde_core", "serde_json", @@ -3629,13 +3674,33 @@ dependencies = [ "unicode-width 0.1.14", ] +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + [[package]] name = "thiserror" version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" dependencies = [ - "thiserror-impl", + "thiserror-impl 2.0.17", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] @@ -4321,7 +4386,7 @@ dependencies = [ "ring", "signature", "spki", - "thiserror", + "thiserror 2.0.17", "zeroize", ] diff --git a/Cargo.toml b/Cargo.toml index 10f37d2..8a6bc00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ documentation = "https://docs.rs/crate/datafusion-postgres/" [workspace.dependencies] arrow = "56" +arrow-schema = "56" bytes = "1.10.1" chrono = { version = "0.4", features = ["std"] } datafusion = { version = "50", default-features = false } diff --git a/arrow-pg/Cargo.toml b/arrow-pg/Cargo.toml index 92c60ca..2be4e76 100644 --- a/arrow-pg/Cargo.toml +++ b/arrow-pg/Cargo.toml @@ -13,9 +13,10 @@ readme = "../README.md" rust-version.workspace = true [features] -default = ["arrow"] +default = ["arrow", "geo"] arrow = ["dep:arrow"] datafusion = ["dep:datafusion"] +geo = ["postgres-types/with-geo-types-0_7", "dep:geoarrow-schema"] # for testing _duckdb = [] _bundled = ["duckdb/bundled"] @@ -23,6 +24,8 @@ _bundled = ["duckdb/bundled"] [dependencies] arrow = { workspace = true, optional = true } +arrow-schema = { workspace = true } +geoarrow-schema = { version = "0.6", optional = true } bytes.workspace = true chrono.workspace = true datafusion = { workspace = true, optional = true } diff --git a/arrow-pg/src/datatypes.rs b/arrow-pg/src/datatypes.rs index c3c6276..9114d33 100644 --- a/arrow-pg/src/datatypes.rs +++ b/arrow-pg/src/datatypes.rs @@ -2,6 +2,7 @@ use std::sync::Arc; #[cfg(not(feature = "datafusion"))] use arrow::{datatypes::*, record_batch::RecordBatch}; +use arrow_schema::extension::ExtensionType; #[cfg(feature = "datafusion")] use datafusion::arrow::{datatypes::*, record_batch::RecordBatch}; @@ -17,34 +18,42 @@ use crate::row_encoder::RowEncoder; #[cfg(feature = "datafusion")] pub mod df; -pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { - Ok(match arrow_type { - DataType::Null => Type::UNKNOWN, - DataType::Boolean => Type::BOOL, - DataType::Int8 | DataType::UInt8 => Type::CHAR, - DataType::Int16 | DataType::UInt16 => Type::INT2, - DataType::Int32 | DataType::UInt32 => Type::INT4, - DataType::Int64 | DataType::UInt64 => Type::INT8, - DataType::Timestamp(_, tz) => { - if tz.is_some() { - Type::TIMESTAMPTZ - } else { - Type::TIMESTAMP +pub fn into_pg_type(field: &Arc) -> PgWireResult { + let arrow_type = field.data_type(); + + match field.extension_type_name() { + #[cfg(feature = "geo")] + Some(geoarrow_schema::PointType::NAME) => Ok(Type::POINT), + _ => Ok(match arrow_type { + DataType::Null => Type::UNKNOWN, + DataType::Boolean => Type::BOOL, + DataType::Int8 | DataType::UInt8 => Type::CHAR, + DataType::Int16 | DataType::UInt16 => Type::INT2, + DataType::Int32 | DataType::UInt32 => Type::INT4, + DataType::Int64 | DataType::UInt64 => Type::INT8, + DataType::Timestamp(_, tz) => { + if tz.is_some() { + Type::TIMESTAMPTZ + } else { + Type::TIMESTAMP + } } - } - DataType::Time32(_) | DataType::Time64(_) => Type::TIME, - DataType::Date32 | DataType::Date64 => Type::DATE, - DataType::Interval(_) => Type::INTERVAL, - DataType::Binary - | DataType::FixedSizeBinary(_) - | DataType::LargeBinary - | DataType::BinaryView => Type::BYTEA, - DataType::Float16 | DataType::Float32 => Type::FLOAT4, - DataType::Float64 => Type::FLOAT8, - DataType::Decimal128(_, _) => Type::NUMERIC, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT, - DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => { - match field.data_type() { + DataType::Time32(_) | DataType::Time64(_) => Type::TIME, + DataType::Date32 | DataType::Date64 => Type::DATE, + DataType::Interval(_) => Type::INTERVAL, + DataType::Binary + | DataType::FixedSizeBinary(_) + | DataType::LargeBinary + | DataType::BinaryView => Type::BYTEA, + DataType::Float16 | DataType::Float32 => Type::FLOAT4, + DataType::Float64 => Type::FLOAT8, + DataType::Decimal128(_, _) => Type::NUMERIC, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT, + DataType::List(field) + | DataType::FixedSizeList(field, _) + | DataType::LargeList(field) + | DataType::ListView(field) + | DataType::LargeListView(field) => match field.data_type() { DataType::Boolean => Type::BOOL_ARRAY, DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY, DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY, @@ -67,10 +76,10 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY, DataType::Float64 => Type::FLOAT8_ARRAY, DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY, - struct_type @ DataType::Struct(_) => Type::new( + DataType::Struct(_) => Type::new( Type::RECORD_ARRAY.name().into(), Type::RECORD_ARRAY.oid(), - Kind::Array(into_pg_type(struct_type)?), + Kind::Array(into_pg_type(field)?), Type::RECORD_ARRAY.schema().into(), ), list_type => { @@ -80,35 +89,42 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { format!("Unsupported List Datatype {list_type}"), )))); } + }, + DataType::Dictionary(_, value_type) => { + let field = Arc::new(Field::new( + Field::LIST_FIELD_DEFAULT_NAME, + *value_type.clone(), + true, + )); + into_pg_type(&field)? } - } - DataType::Dictionary(_, value_type) => into_pg_type(value_type)?, - DataType::Struct(fields) => { - let name: String = fields - .iter() - .map(|x| x.name().clone()) - .reduce(|a, b| a + ", " + &b) - .map(|x| format!("({x})")) - .unwrap_or("()".to_string()); - let kind = Kind::Composite( - fields + DataType::Struct(fields) => { + let name: String = fields .iter() - .map(|x| { - into_pg_type(x.data_type()) - .map(|_type| postgres_types::Field::new(x.name().clone(), _type)) - }) - .collect::, PgWireError>>()?, - ); - Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into()) - } - _ => { - return Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!("Unsupported Datatype {arrow_type}"), - )))); - } - }) + .map(|x| x.name().clone()) + .reduce(|a, b| a + ", " + &b) + .map(|x| format!("({x})")) + .unwrap_or("()".to_string()); + let kind = Kind::Composite( + fields + .iter() + .map(|x| { + into_pg_type(x) + .map(|_type| postgres_types::Field::new(x.name().clone(), _type)) + }) + .collect::, PgWireError>>()?, + ); + Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into()) + } + _ => { + return Err(PgWireError::UserError(Box::new(ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!("Unsupported Datatype {arrow_type}"), + )))); + } + }), + } } pub fn arrow_schema_to_pg_fields(schema: &Schema, format: &Format) -> PgWireResult> { @@ -117,7 +133,7 @@ pub fn arrow_schema_to_pg_fields(schema: &Schema, format: &Format) -> PgWireResu .iter() .enumerate() .map(|(idx, f)| { - let pg_type = into_pg_type(f.data_type())?; + let pg_type = into_pg_type(f)?; Ok(FieldInfo::new( f.name().into(), None, diff --git a/arrow-pg/src/datatypes/df.rs b/arrow-pg/src/datatypes/df.rs index c81d53a..de343b2 100644 --- a/arrow-pg/src/datatypes/df.rs +++ b/arrow-pg/src/datatypes/df.rs @@ -2,7 +2,7 @@ use std::iter; use std::sync::Arc; use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Timelike}; -use datafusion::arrow::datatypes::{DataType, Date32Type}; +use datafusion::arrow::datatypes::{DataType, Date32Type, Field}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::ParamValues; use datafusion::prelude::*; @@ -61,7 +61,7 @@ where if let Some(ty) = pg_type_hint { Ok(ty.clone()) } else if let Some(infer_type) = inferenced_type { - into_pg_type(infer_type) + into_pg_type(&Arc::new(Field::new("item", infer_type.clone(), true))) } else { Ok(Type::UNKNOWN) } diff --git a/arrow-pg/src/encoder.rs b/arrow-pg/src/encoder.rs index 074939c..65d7f51 100644 --- a/arrow-pg/src/encoder.rs +++ b/arrow-pg/src/encoder.rs @@ -12,6 +12,7 @@ use chrono::{NaiveDate, NaiveDateTime}; use datafusion::arrow::{array::*, datatypes::*}; use pgwire::api::results::DataRowEncoder; use pgwire::api::results::FieldFormat; +use pgwire::api::results::FieldInfo; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::types::ToSqlText; use postgres_types::{ToSql, Type}; @@ -288,9 +289,12 @@ pub fn encode_value( encoder: &mut T, arr: &Arc, idx: usize, - type_: &Type, - format: FieldFormat, + _arrow_filed: &Field, + pg_field: &FieldInfo, ) -> PgWireResult<()> { + let type_ = pg_field.datatype(); + let format = pg_field.format(); + match arr.data_type() { DataType::Null => encoder.encode_field_with_type_and_format(&None::, type_, format)?, DataType::Boolean => { @@ -494,7 +498,7 @@ pub fn encode_value( let value = encode_list(array, type_, format)?; encoder.encode_field_with_type_and_format(&value, type_, format)? } - DataType::Struct(_) => { + DataType::Struct(arrow_fields) => { let fields = match type_.kind() { postgres_types::Kind::Composite(fields) => fields, _ => { @@ -503,7 +507,7 @@ pub fn encode_value( )))); } }; - let value = encode_struct(arr, idx, fields, format)?; + let value = encode_struct(arr, idx, arrow_fields, fields, format)?; encoder.encode_field_with_type_and_format(&value, type_, format)? } DataType::Dictionary(_, value_type) => { @@ -534,7 +538,16 @@ pub fn encode_value( )) })?; - encode_value(encoder, values, idx, type_, format)? + let inner_pg_field = FieldInfo::new( + pg_field.name().to_string(), + None, + None, + type_.clone(), + format, + ); + let inner_arrow_field = Field::new(pg_field.name(), *value_type.clone(), true); + + encode_value(encoder, values, idx, &inner_arrow_field, &inner_pg_field)? } _ => { return Err(PgWireError::ApiError(ToSqlError::from(format!( @@ -585,7 +598,10 @@ mod tests { let mut encoder = MockEncoder::default(); - let result = encode_value(&mut encoder, &dict_arr, 2, &Type::TEXT, FieldFormat::Text); + let arrow_field = Field::new("x", DataType::Utf8, true); + let pg_field = FieldInfo::new("x".to_string(), None, None, Type::TEXT, FieldFormat::Text); + + let result = encode_value(&mut encoder, &dict_arr, 2, &arrow_field, &pg_field); assert!(result.is_ok()); diff --git a/arrow-pg/src/list_encoder.rs b/arrow-pg/src/list_encoder.rs index a13c1c7..ae9893b 100644 --- a/arrow-pg/src/list_encoder.rs +++ b/arrow-pg/src/list_encoder.rs @@ -386,7 +386,7 @@ pub(crate) fn encode_list( } } }, - DataType::Struct(_) => { + DataType::Struct(arrow_fields) => { let fields = match type_.kind() { postgres_types::Kind::Array(struct_type_) => Ok(struct_type_), _ => Err(format!( @@ -406,7 +406,7 @@ pub(crate) fn encode_list( .map_err(ToSqlError::from)?; let values: PgWireResult> = (0..arr.len()) - .map(|row| encode_struct(&arr, row, fields, format)) + .map(|row| encode_struct(&arr, row, arrow_fields, fields, format)) .map(|x| { if matches!(format, FieldFormat::Text) { x.map(|opt| { diff --git a/arrow-pg/src/row_encoder.rs b/arrow-pg/src/row_encoder.rs index 145c9ab..c8a73f6 100644 --- a/arrow-pg/src/row_encoder.rs +++ b/arrow-pg/src/row_encoder.rs @@ -33,13 +33,14 @@ impl RowEncoder { if self.curr_idx == self.rb.num_rows() { return None; } + let arrow_schema = self.rb.schema_ref(); let mut encoder = DataRowEncoder::new(self.fields.clone()); for col in 0..self.rb.num_columns() { let array = self.rb.column(col); - let field = &self.fields[col]; - let type_ = field.datatype(); - let format = field.format(); - encode_value(&mut encoder, array, self.curr_idx, type_, format).unwrap(); + let arrow_field = arrow_schema.field(col); + let pg_field = &self.fields[col]; + + encode_value(&mut encoder, array, self.curr_idx, arrow_field, pg_field).unwrap(); } self.curr_idx += 1; Some(encoder.finish()) diff --git a/arrow-pg/src/struct_encoder.rs b/arrow-pg/src/struct_encoder.rs index 49fce1b..1224797 100644 --- a/arrow-pg/src/struct_encoder.rs +++ b/arrow-pg/src/struct_encoder.rs @@ -2,11 +2,12 @@ use std::sync::Arc; #[cfg(not(feature = "datafusion"))] use arrow::array::{Array, StructArray}; +use arrow_schema::Fields; #[cfg(feature = "datafusion")] use datafusion::arrow::array::{Array, StructArray}; use bytes::{BufMut, BytesMut}; -use pgwire::api::results::FieldFormat; +use pgwire::api::results::{FieldFormat, FieldInfo}; use pgwire::error::PgWireResult; use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE}; use postgres_types::{Field, IsNull, ToSql, Type}; @@ -16,6 +17,7 @@ use crate::encoder::{encode_value, EncodedValue, Encoder}; pub(crate) fn encode_struct( arr: &Arc, idx: usize, + arrow_fields: &Fields, fields: &[Field], format: FieldFormat, ) -> PgWireResult> { @@ -27,7 +29,11 @@ pub(crate) fn encode_struct( for (i, arr) in arr.columns().iter().enumerate() { let field = &fields[i]; let type_ = field.type_(); - encode_value(&mut row_encoder, arr, idx, type_, format).unwrap(); + + let arrow_field = &arrow_fields[i]; + let pgwire_field = FieldInfo::new("fields".to_string(), None, None, type_.clone(), format); + + encode_value(&mut row_encoder, arr, idx, arrow_field, &pgwire_field).unwrap(); } Ok(Some(EncodedValue { bytes: row_encoder.row_buffer, diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 55fafd2..b6069f3 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -391,7 +391,8 @@ impl ExtendedQueryHandler for DfSessionService { for param_type in ordered_param_types(¶ms).iter() { // Fixed: Use ¶ms if let Some(datatype) = param_type { - let pgtype = into_pg_type(datatype)?; + let pgtype = + into_pg_type(&Arc::new(Field::new("item", (*datatype).clone(), true)))?; param_types.push(pgtype); } else { param_types.push(Type::UNKNOWN);