diff --git a/Cargo.lock b/Cargo.lock
index 79ae693..640cbb7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -351,6 +351,19 @@ dependencies = [
  "arrow-select",
 ]
 
+[[package]]
+name = "arrow-pg"
+version = "0.0.1"
+dependencies = [
+ "arrow",
+ "bytes",
+ "chrono",
+ "futures",
+ "pgwire",
+ "postgres-types",
+ "rust_decimal",
+]
+
 [[package]]
 name = "arrow-row"
 version = "55.1.0"
@@ -1518,6 +1531,7 @@ dependencies = [
 name = "datafusion-postgres"
 version = "0.5.1"
 dependencies = [
+ "arrow-pg",
  "async-trait",
  "bytes",
  "chrono",
diff --git a/Cargo.toml b/Cargo.toml
index 8d85415..cd7b686 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [workspace]
 resolver = "2"
-members = ["datafusion-postgres", "datafusion-postgres-cli"]
+members = ["datafusion-postgres", "datafusion-postgres-cli", "arrow-pg"]
 
 [workspace.package]
 version = "0.5.1"
@@ -14,8 +14,14 @@ repository = "https://github.com/datafusion-contrib/datafusion-postgres/"
 documentation = "https://docs.rs/crate/datafusion-postgres/"
 
 [workspace.dependencies]
-pgwire = "0.30.2"
+arrow = "55"
+bytes = "1.10.1"
+chrono = { version = "0.4", features = ["std"] }
 datafusion = { version = "47", default-features = false }
+futures = "0.3"
+pgwire = "0.30.2"
+postgres-types = "0.2"
+rust_decimal = { version = "1.37", features = ["db-postgres"] }
 tokio = { version = "1", default-features = false }
 
 [profile.release]
diff --git a/arrow-pg/Cargo.toml b/arrow-pg/Cargo.toml
new file mode 100644
index 0000000..0df800d
--- /dev/null
+++ b/arrow-pg/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "arrow-pg"
+description = "Arrow data mapping and encoding/decoding for Postgres"
+version = "0.0.1"
+edition.workspace = true
+license.workspace = true
+authors.workspace = true
+keywords.workspace = true
+homepage.workspace = true
+repository.workspace = true
+documentation.workspace = true
+readme = "../README.md"
+
+[dependencies]
+arrow.workspace = true
+bytes.workspace = true
+chrono.workspace = true
+futures.workspace = true
+pgwire.workspace = true
+postgres-types.workspace = true
+rust_decimal.workspace = true
diff --git a/arrow-pg/src/datatypes.rs b/arrow-pg/src/datatypes.rs
new file mode 100644
index 0000000..06dafe2
--- /dev/null
+++ b/arrow-pg/src/datatypes.rs
@@ -0,0 +1,129 @@
+use std::sync::Arc;
+
+use arrow::datatypes::*;
+use arrow::record_batch::RecordBatch;
+use pgwire::api::portal::Format;
+use pgwire::api::results::FieldInfo;
+use pgwire::api::Type;
+use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
+use pgwire::messages::data::DataRow;
+use postgres_types::Kind;
+
+use crate::row_encoder::RowEncoder;
+
+pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
+    Ok(match arrow_type {
+        DataType::Null => Type::UNKNOWN,
+        DataType::Boolean => Type::BOOL,
+        DataType::Int8 | DataType::UInt8 => Type::CHAR,
+        DataType::Int16 | DataType::UInt16 => Type::INT2,
+        DataType::Int32 | DataType::UInt32 => Type::INT4,
+        DataType::Int64 | DataType::UInt64 => Type::INT8,
+        DataType::Timestamp(_, tz) => {
+            if tz.is_some() {
+                Type::TIMESTAMPTZ
+            } else {
+                Type::TIMESTAMP
+            }
+        }
+        DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
+        DataType::Date32 | DataType::Date64 => Type::DATE,
+        DataType::Interval(_) => Type::INTERVAL,
+        DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary => Type::BYTEA,
+        DataType::Float16 | DataType::Float32 => Type::FLOAT4,
+        DataType::Float64 => Type::FLOAT8,
+        DataType::Decimal128(_, _) => Type::NUMERIC,
+        DataType::Utf8 => Type::VARCHAR,
+        DataType::LargeUtf8 => Type::TEXT,
+        DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
+            match field.data_type() {
+                DataType::Boolean => Type::BOOL_ARRAY,
+                DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
+                DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
+                DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY,
+                DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY,
+                DataType::Timestamp(_, tz) => {
+                    if tz.is_some() {
+                        Type::TIMESTAMPTZ_ARRAY
+                    } else {
+                        Type::TIMESTAMP_ARRAY
+                    }
+                }
+                DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
+                DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
+                DataType::Interval(_) => Type::INTERVAL_ARRAY,
+                DataType::FixedSizeBinary(_) | DataType::Binary => Type::BYTEA_ARRAY,
+                DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
+                DataType::Float64 => Type::FLOAT8_ARRAY,
+                DataType::Utf8 => Type::VARCHAR_ARRAY,
+                DataType::LargeUtf8 => Type::TEXT_ARRAY,
+                struct_type @ DataType::Struct(_) => Type::new(
+                    Type::RECORD_ARRAY.name().into(),
+                    Type::RECORD_ARRAY.oid(),
+                    Kind::Array(into_pg_type(struct_type)?),
+                    Type::RECORD_ARRAY.schema().into(),
+                ),
+                list_type => {
+                    return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
+                        "ERROR".to_owned(),
+                        "XX000".to_owned(),
+                        format!("Unsupported List Datatype {list_type}"),
+                    ))));
+                }
+            }
+        }
+        DataType::Utf8View => Type::TEXT,
+        DataType::Dictionary(_, value_type) => into_pg_type(value_type)?,
+        DataType::Struct(fields) => {
+            let name: String = fields
+                .iter()
+                .map(|x| x.name().clone())
+                .reduce(|a, b| a + ", " + &b)
+                .map(|x| format!("({x})"))
+                .unwrap_or("()".to_string());
+            let kind = Kind::Composite(
+                fields
+                    .iter()
+                    .map(|x| {
+                        into_pg_type(x.data_type())
+                            .map(|_type| postgres_types::Field::new(x.name().clone(), _type))
+                    })
+                    .collect::<Result<Vec<_>, PgWireError>>()?,
+            );
+            Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into())
+        }
+        _ => {
+            return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
+                "ERROR".to_owned(),
+                "XX000".to_owned(),
+                format!("Unsupported Datatype {arrow_type}"),
+            ))));
+        }
+    })
+}
+
+pub fn arrow_schema_to_pg_fields(schema: &Schema, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
+    schema
+        .fields()
+        .iter()
+        .enumerate()
+        .map(|(idx, f)| {
+            let pg_type = into_pg_type(f.data_type())?;
+            Ok(FieldInfo::new(
+                f.name().into(),
+                None,
+                None,
+                pg_type,
+                format.format_for(idx),
+            ))
+        })
+        .collect::<PgWireResult<Vec<FieldInfo>>>()
+}
+
+pub fn encode_recordbatch(
+    fields: Arc<Vec<FieldInfo>>,
+    record_batch: RecordBatch,
+) -> Box<dyn Iterator<Item = PgWireResult<DataRow>> + Send + Sync> {
+    let mut row_stream = RowEncoder::new(record_batch, fields);
+    Box::new(std::iter::from_fn(move || row_stream.next_row()))
+}
diff --git a/datafusion-postgres/src/encoder/mod.rs b/arrow-pg/src/encoder.rs
similarity index 88%
rename from datafusion-postgres/src/encoder/mod.rs
rename to arrow-pg/src/encoder.rs
index 233d24c..cda5ba7 100644
--- a/datafusion-postgres/src/encoder/mod.rs
+++ b/arrow-pg/src/encoder.rs
@@ -1,27 +1,27 @@
+use std::error::Error;
 use std::io::Write;
 use std::str::FromStr;
 use std::sync::Arc;
 
+use arrow::array::*;
+use arrow::datatypes::*;
 use bytes::BufMut;
 use bytes::BytesMut;
 use chrono::{NaiveDate, NaiveDateTime};
-use datafusion::arrow::array::*;
-use datafusion::arrow::datatypes::*;
-use list_encoder::encode_list;
 use pgwire::api::results::DataRowEncoder;
 use pgwire::api::results::FieldFormat;
-use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
+use pgwire::error::PgWireError;
+use pgwire::error::PgWireResult;
 use pgwire::types::ToSqlText;
 use postgres_types::{ToSql, Type};
 use rust_decimal::Decimal;
-use struct_encoder::encode_struct;
 use timezone::Tz;
 
-pub mod list_encoder;
-pub mod row_encoder;
-pub mod struct_encoder;
+use crate::error::ToSqlError;
+use crate::list_encoder::encode_list;
+use crate::struct_encoder::encode_struct;
 
-trait Encoder {
+pub trait Encoder {
     fn encode_field_with_type_and_format<T>(
         &mut self,
         value: &T,
@@ -61,7 +61,7 @@ impl ToSql for EncodedValue {
         &self,
         _ty: &Type,
         out: &mut BytesMut,
-    ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
+    ) -> Result<IsNull, Box<dyn Error + Sync + Send>>
     where
         Self: Sized,
     {
@@ -80,7 +80,7 @@ impl ToSql for EncodedValue {
         &self,
         ty: &Type,
         out: &mut BytesMut,
-    ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
+    ) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
         self.to_sql(ty, out)
     }
 }
@@ -90,7 +90,7 @@ impl ToSqlText for EncodedValue {
         &self,
         _ty: &Type,
         out: &mut BytesMut,
-    ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
+    ) -> Result<IsNull, Box<dyn Error + Sync + Send>>
     where
         Self: Sized,
     {
@@ -261,16 +261,13 @@ fn get_numeric_128_value(
             }
             _ => unreachable!(),
         };
-        PgWireError::UserError(Box::new(ErrorInfo::new(
-            "ERROR".to_owned(),
-            "XX000".to_owned(),
-            message.to_owned(),
-        )))
+        // TODO: add error type in PgWireError
+        PgWireError::ApiError(ToSqlError::from(message))
     })
     .map(Some)
 }
 
-fn encode_value<T: Encoder>(
+pub fn encode_value<T: Encoder>(
    encoder: &mut T,
    arr: &Arc<dyn Array>,
    idx: usize,
@@ -387,8 +384,7 @@ fn encode_value(
            }
            let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
            if let Some(tz) = timezone {
-                let tz = Tz::from_str(tz.as_ref())
-                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
+                let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
                let value = ts_array
                    .value_as_datetime_with_tz(idx, tz)
                    .map(|d| d.fixed_offset());
@@ -411,8 +407,7 @@ fn encode_value(
                .downcast_ref::<TimestampMillisecondArray>()
                .unwrap();
            if let Some(tz) = timezone {
-                let tz = Tz::from_str(tz.as_ref())
-                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
+                let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
                let value = ts_array
                    .value_as_datetime_with_tz(idx, tz)
                    .map(|d| d.fixed_offset());
@@ -435,8 +430,7 @@ fn encode_value(
                .downcast_ref::<TimestampMicrosecondArray>()
                .unwrap();
            if let Some(tz) = timezone {
-                let tz = Tz::from_str(tz.as_ref())
-                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
+                let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
                let value = ts_array
                    .value_as_datetime_with_tz(idx, tz)
                    .map(|d| d.fixed_offset());
@@ -459,8 +453,7 @@ fn encode_value(
                .downcast_ref::<TimestampNanosecondArray>()
                .unwrap();
            if let Some(tz) = timezone {
-                let tz = Tz::from_str(tz.as_ref())
-                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
+                let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
                let value = ts_array
                    .value_as_datetime_with_tz(idx, tz)
                    .map(|d| d.fixed_offset());
@@ -483,11 +476,10 @@ fn encode_value(
            let fields = match type_.kind() {
                postgres_types::Kind::Composite(fields) => fields,
                _ => {
-                    return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
-                        "ERROR".to_owned(),
-                        "XX000".to_owned(),
-                        format!("Failed to unwrap a composite type from type {}", type_),
-                    ))))
+                    return Err(PgWireError::ApiError(ToSqlError::from(format!(
+                        "Failed to unwrap a composite type from type {}",
+                        type_
+                    ))));
                }
            };
            let value = encode_struct(arr, idx, fields, format)?;
@@ -517,14 +509,10 @@ fn encode_value(
                .or_else(|| get_dict_values!(UInt32Type))
                .or_else(|| get_dict_values!(UInt64Type))
                .ok_or_else(|| {
-                    PgWireError::UserError(Box::new(ErrorInfo::new(
-                        "ERROR".to_owned(),
-                        "XX000".to_owned(),
-                        format!(
-                            "Unsupported dictionary key type for value type {}",
-                            value_type
-                        ),
-                    )))
+                    ToSqlError::from(format!(
+                        "Unsupported dictionary key type for value type {}",
+                        value_type
+                    ))
                })?;
 
            // If the dictionary has only one value, treat it as a primitive
@@ -536,15 +524,11 @@ fn encode_value(
            }
        }
        _ => {
-            return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
-                "ERROR".to_owned(),
-                "XX000".to_owned(),
-                format!(
-                    "Unsupported Datatype {} and array {:?}",
-                    arr.data_type(),
-                    &arr
-                ),
-            ))))
+            return Err(PgWireError::ApiError(ToSqlError::from(format!(
+                "Unsupported Datatype {} and array {:?}",
+                arr.data_type(),
+                &arr
+            ))));
        }
    }
 
diff --git a/arrow-pg/src/error.rs b/arrow-pg/src/error.rs
new file mode 100644
index 0000000..9dca31b
--- /dev/null
+++ b/arrow-pg/src/error.rs
@@ -0,0 +1 @@
+pub type ToSqlError = Box<dyn std::error::Error + Send + Sync>;
diff --git a/arrow-pg/src/lib.rs b/arrow-pg/src/lib.rs
new file mode 100644
index 0000000..dd77bce
--- /dev/null
+++ b/arrow-pg/src/lib.rs
@@ -0,0 +1,6 @@
+pub mod datatypes;
+pub mod encoder;
+mod error;
+pub mod list_encoder;
+pub mod row_encoder;
+pub mod struct_encoder;
diff --git a/datafusion-postgres/src/encoder/list_encoder.rs b/arrow-pg/src/list_encoder.rs
similarity index 90%
rename from datafusion-postgres/src/encoder/list_encoder.rs
rename to arrow-pg/src/list_encoder.rs
index a8758ef..766da3f 100644
--- a/datafusion-postgres/src/encoder/list_encoder.rs
+++ b/arrow-pg/src/list_encoder.rs
@@ -1,14 +1,12 @@
-use std::{error::Error, str::FromStr, sync::Arc};
+use std::{str::FromStr, sync::Arc};
 
-use bytes::{BufMut, BytesMut};
-use chrono::{DateTime, TimeZone, Utc};
-use datafusion::arrow::array::{
+use arrow::array::{
     timezone::Tz, Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array,
     LargeBinaryArray, PrimitiveArray, StringArray, Time32MillisecondArray, Time32SecondArray,
     Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray,
     TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
 };
-use datafusion::arrow::{
+use arrow::{
     datatypes::{
         DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
         Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
@@ -16,15 +14,17 @@ use datafusion::arrow::{
     },
     temporal_conversions::{as_date, as_time},
 };
-use pgwire::{
-    api::results::FieldFormat,
-    error::{ErrorInfo, PgWireError},
-    types::{ToSqlText, QUOTE_ESCAPE},
-};
+use bytes::{BufMut, BytesMut};
+use chrono::{DateTime, TimeZone, Utc};
+use pgwire::api::results::FieldFormat;
+use pgwire::error::{PgWireError, PgWireResult};
+use pgwire::types::{ToSqlText, QUOTE_ESCAPE};
 use postgres_types::{ToSql, Type};
 use rust_decimal::Decimal;
 
-use super::{struct_encoder::encode_struct, EncodedValue};
+use crate::encoder::EncodedValue;
+use crate::error::ToSqlError;
+use crate::struct_encoder::encode_struct;
 
 fn get_bool_list_value(arr: &Arc<dyn Array>) -> Vec<Option<bool>> {
     arr.as_any()
@@ -76,7 +76,7 @@ fn encode_field(
     t: &[T],
     type_: &Type,
     format: FieldFormat,
-) -> Result<EncodedValue, Box<dyn Error + Send + Sync>> {
+) -> PgWireResult<EncodedValue> {
     let mut bytes = BytesMut::new();
     match format {
         FieldFormat::Text => t.to_sql_text(type_, &mut bytes)?,
@@ -89,7 +89,7 @@ pub(crate) fn encode_list(
     arr: Arc<dyn Array>,
     type_: &Type,
     format: FieldFormat,
-) -> Result<EncodedValue, Box<dyn Error + Send + Sync>> {
+) -> PgWireResult<EncodedValue> {
     match arr.data_type() {
         DataType::Null => {
             let mut bytes = BytesMut::new();
@@ -228,7 +228,7 @@ pub(crate) fn encode_list(
 
            if let Some(tz) = timezone {
                let tz = Tz::from_str(tz.as_ref())
-                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
+                    .map_err(|e| PgWireError::ApiError(ToSqlError::from(e)))?;
                let value: Vec<_> = array_iter
                    .map(|i| {
                        i.and_then(|i| {
@@ -258,8 +258,7 @@ pub(crate) fn encode_list(
                .iter();
 
            if let Some(tz) = timezone {
-                let tz = Tz::from_str(tz.as_ref())
-                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
+                let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
                let value: Vec<_> = array_iter
                    .map(|i| {
                        i.and_then(|i| {
@@ -291,8 +290,7 @@ pub(crate) fn encode_list(
                .iter();
 
            if let Some(tz) = timezone {
-                let tz = Tz::from_str(tz.as_ref())
-                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
+                let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
                let value: Vec<_> = array_iter
                    .map(|i| {
                        i.and_then(|i| {
@@ -324,8 +322,7 @@ pub(crate) fn encode_list(
                .iter();
 
            if let Some(tz) = timezone {
-                let tz = Tz::from_str(tz.as_ref())
-                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
+                let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
                let value: Vec<_> = array_iter
                    .map(|i| {
                        i.map(|i| {
@@ -363,12 +360,9 @@ pub(crate) fn encode_list(
                    type_.kind()
                )),
            })
-            .map_err(|err| {
-                let err = ErrorInfo::new("ERROR".to_owned(), "XX000".to_owned(), err);
-                Box::new(PgWireError::UserError(Box::new(err)))
-            })?;
+            .map_err(ToSqlError::from)?;
 
-            let values: Result<Vec<_>, _> = (0..arr.len())
+            let values: PgWireResult<Vec<_>> = (0..arr.len())
                .map(|row| encode_struct(&arr, row, fields, format))
                .map(|x| {
                    if matches!(format, FieldFormat::Text) {
@@ -396,17 +390,9 @@ pub(crate) fn encode_list(
            encode_field(&values?, type_, format)
        }
        // TODO: more types
-        list_type => {
-            let err = PgWireError::UserError(Box::new(ErrorInfo::new(
-                "ERROR".to_owned(),
-                "XX000".to_owned(),
-                format!(
-                    "Unsupported List Datatype {} and array {:?}",
-                    list_type, &arr
-                ),
-            )));
-
-            Err(Box::new(err))
-        }
+        list_type => Err(PgWireError::ApiError(ToSqlError::from(format!(
+            "Unsupported List Datatype {} and array {:?}",
+            list_type, &arr
+        )))),
    }
 }
diff --git a/datafusion-postgres/src/encoder/row_encoder.rs b/arrow-pg/src/row_encoder.rs
similarity index 94%
rename from datafusion-postgres/src/encoder/row_encoder.rs
rename to arrow-pg/src/row_encoder.rs
index 9d48145..3eab8c7 100644
--- a/datafusion-postgres/src/encoder/row_encoder.rs
+++ b/arrow-pg/src/row_encoder.rs
@@ -1,13 +1,13 @@
 use std::sync::Arc;
 
-use datafusion::arrow::array::RecordBatch;
+use arrow::array::RecordBatch;
 use pgwire::{
     api::results::{DataRowEncoder, FieldInfo},
     error::PgWireResult,
     messages::data::DataRow,
 };
 
-use super::encode_value;
+use crate::encoder::encode_value;
 
 pub struct RowEncoder {
     rb: RecordBatch,
diff --git a/datafusion-postgres/src/encoder/struct_encoder.rs b/arrow-pg/src/struct_encoder.rs
similarity index 86%
rename from datafusion-postgres/src/encoder/struct_encoder.rs
rename to arrow-pg/src/struct_encoder.rs
index d0b1523..96c9467 100644
--- a/datafusion-postgres/src/encoder/struct_encoder.rs
+++ b/arrow-pg/src/struct_encoder.rs
@@ -1,22 +1,20 @@
-use std::{error::Error, sync::Arc};
+use std::sync::Arc;
 
+use arrow::array::{Array, StructArray};
 use bytes::{BufMut, BytesMut};
-use datafusion::arrow::array::{Array, StructArray};
-use pgwire::{
-    api::results::FieldFormat,
-    error::PgWireResult,
-    types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE},
-};
+use pgwire::api::results::FieldFormat;
+use pgwire::error::PgWireResult;
+use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE};
 use postgres_types::{Field, IsNull, ToSql, Type};
 
-use super::{encode_value, EncodedValue};
+use crate::encoder::{encode_value, EncodedValue, Encoder};
 
-pub fn encode_struct(
+pub(crate) fn encode_struct(
     arr: &Arc<dyn Array>,
     idx: usize,
     fields: &[Field],
     format: FieldFormat,
-) -> Result<Option<EncodedValue>, Box<dyn Error + Send + Sync>> {
+) -> PgWireResult<Option<EncodedValue>> {
     let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
     if arr.is_null(idx) {
         return Ok(None);
@@ -32,14 +30,14 @@ pub fn encode_struct(
     }))
 }
 
-struct StructEncoder {
+pub(crate) struct StructEncoder {
     num_cols: usize,
     curr_col: usize,
     row_buffer: BytesMut,
 }
 
 impl StructEncoder {
-    fn new(num_cols: usize) -> Self {
+    pub(crate) fn new(num_cols: usize) -> Self {
         Self {
             num_cols,
             curr_col: 0,
@@ -48,7 +46,7 @@ impl StructEncoder {
     }
 }
 
-impl super::Encoder for StructEncoder {
+impl Encoder for StructEncoder {
     fn encode_field_with_type_and_format<T>(
         &mut self,
         value: &T,
diff --git a/datafusion-postgres/Cargo.toml b/datafusion-postgres/Cargo.toml
index 49b8a88..e80cb69 100644
--- a/datafusion-postgres/Cargo.toml
+++ b/datafusion-postgres/Cargo.toml
@@ -12,14 +12,15 @@ documentation.workspace = true
 readme = "../README.md"
 
 [dependencies]
+arrow-pg = { path = "../arrow-pg", version = "0.0.1" }
+bytes.workspace = true
 async-trait = "0.1"
-bytes = "1.10.1"
-chrono = { version = "0.4", features = ["std"] }
-datafusion = { workspace = true }
-futures = "0.3"
+chrono.workspace = true
+datafusion.workspace = true
+futures.workspace = true
 getset = "0.1"
 log = "0.4"
-pgwire = { workspace = true }
-postgres-types = "0.2"
-rust_decimal = { version = "1.37", features = ["db-postgres"] }
+pgwire.workspace = true
+postgres-types.workspace = true
+rust_decimal.workspace = true
 tokio = { version = "1.45", features = ["sync", "net"] }
diff --git a/datafusion-postgres/src/datatypes.rs b/datafusion-postgres/src/datatypes.rs
index bf5223b..cbae22b 100644
--- a/datafusion-postgres/src/datatypes.rs
+++ b/datafusion-postgres/src/datatypes.rs
@@ -3,140 +3,27 @@ use std::sync::Arc;
 
 use chrono::{DateTime, FixedOffset};
 use chrono::{NaiveDate, NaiveDateTime};
-use datafusion::arrow::datatypes::*;
+use datafusion::arrow::datatypes::{DataType, Date32Type};
 use datafusion::arrow::record_batch::RecordBatch;
-use datafusion::common::{DFSchema, ParamValues};
+use datafusion::common::ParamValues;
 use datafusion::prelude::*;
 use datafusion::scalar::ScalarValue;
 use futures::{stream, StreamExt};
 use pgwire::api::portal::{Format, Portal};
-use pgwire::api::results::{FieldInfo, QueryResponse};
+use pgwire::api::results::QueryResponse;
 use pgwire::api::Type;
 use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
 use pgwire::messages::data::DataRow;
-use postgres_types::Kind;
 use rust_decimal::prelude::ToPrimitive;
 use rust_decimal::Decimal;
 
-use crate::encoder::row_encoder::RowEncoder;
-
-pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
-    Ok(match df_type {
-        DataType::Null => Type::UNKNOWN,
-        DataType::Boolean => Type::BOOL,
-        DataType::Int8 | DataType::UInt8 => Type::CHAR,
-        DataType::Int16 | DataType::UInt16 => Type::INT2,
-        DataType::Int32 | DataType::UInt32 => Type::INT4,
-        DataType::Int64 | DataType::UInt64 => Type::INT8,
-        DataType::Timestamp(_, tz) => {
-            if tz.is_some() {
-                Type::TIMESTAMPTZ
-            } else {
-                Type::TIMESTAMP
-            }
-        }
-        DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
-        DataType::Date32 | DataType::Date64 => Type::DATE,
-        DataType::Interval(_) => Type::INTERVAL,
-        DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary => Type::BYTEA,
-        DataType::Float16 | DataType::Float32 => Type::FLOAT4,
-        DataType::Float64 => Type::FLOAT8,
-        DataType::Decimal128(_, _) => Type::NUMERIC,
-        DataType::Utf8 => Type::VARCHAR,
-        DataType::LargeUtf8 => Type::TEXT,
-        DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
-            match field.data_type() {
-                DataType::Boolean => Type::BOOL_ARRAY,
-                DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
-                DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
-                DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY,
-                DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY,
-                DataType::Timestamp(_, tz) => {
-                    if tz.is_some() {
-                        Type::TIMESTAMPTZ_ARRAY
-                    } else {
-                        Type::TIMESTAMP_ARRAY
-                    }
-                }
-                DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
-                DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
-                DataType::Interval(_) => Type::INTERVAL_ARRAY,
-                DataType::FixedSizeBinary(_) | DataType::Binary => Type::BYTEA_ARRAY,
-                DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
-                DataType::Float64 => Type::FLOAT8_ARRAY,
-                DataType::Utf8 => Type::VARCHAR_ARRAY,
-                DataType::LargeUtf8 => Type::TEXT_ARRAY,
-                struct_type @ DataType::Struct(_) => Type::new(
-                    Type::RECORD_ARRAY.name().into(),
-                    Type::RECORD_ARRAY.oid(),
-                    Kind::Array(into_pg_type(struct_type)?),
-                    Type::RECORD_ARRAY.schema().into(),
-                ),
-                list_type => {
-                    return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
-                        "ERROR".to_owned(),
-                        "XX000".to_owned(),
-                        format!("Unsupported List Datatype {list_type}"),
-                    ))));
-                }
-            }
-        }
-        DataType::Utf8View => Type::TEXT,
-        DataType::Dictionary(_, value_type) => into_pg_type(value_type)?,
-        DataType::Struct(fields) => {
-            let name: String = fields
-                .iter()
-                .map(|x| x.name().clone())
-                .reduce(|a, b| a + ", " + &b)
-                .map(|x| format!("({x})"))
-                .unwrap_or("()".to_string());
-            let kind = Kind::Composite(
-                fields
-                    .iter()
-                    .map(|x| {
-                        into_pg_type(x.data_type())
-                            .map(|_type| postgres_types::Field::new(x.name().clone(), _type))
-                    })
-                    .collect::<Result<Vec<_>, PgWireError>>()?,
-            );
-            Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into())
-        }
-        _ => {
-            return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
-                "ERROR".to_owned(),
-                "XX000".to_owned(),
-                format!("Unsupported Datatype {df_type}"),
-            ))));
-        }
-    })
-}
-
-pub(crate) fn df_schema_to_pg_fields(
-    schema: &DFSchema,
-    format: &Format,
-) -> PgWireResult<Vec<FieldInfo>> {
-    schema
-        .fields()
-        .iter()
-        .enumerate()
-        .map(|(idx, f)| {
-            let pg_type = into_pg_type(f.data_type())?;
-            Ok(FieldInfo::new(
-                f.name().into(),
-                None,
-                None,
-                pg_type,
-                format.format_for(idx),
-            ))
-        })
-        .collect::<PgWireResult<Vec<FieldInfo>>>()
-}
+use arrow_pg::datatypes::{arrow_schema_to_pg_fields, encode_recordbatch, into_pg_type};
 
 pub(crate) async fn encode_dataframe<'a>(
     df: DataFrame,
     format: &Format,
 ) -> PgWireResult<QueryResponse<'a>> {
-    let fields = Arc::new(df_schema_to_pg_fields(df.schema(), format)?);
+    let fields = Arc::new(arrow_schema_to_pg_fields(df.schema().as_arrow(), format)?);
 
     let recordbatch_stream = df
         .execute_stream()
@@ -148,11 +35,7 @@ pub(crate) async fn encode_dataframe<'a>(
         .map(move |rb: datafusion::error::Result<RecordBatch>| {
             let row_stream: Box<dyn Iterator<Item = PgWireResult<DataRow>> + Send + Sync> =
                 match rb {
-                    Ok(rb) => {
-                        let fields = fields_ref.clone();
-                        let mut row_stream = RowEncoder::new(rb, fields);
-                        Box::new(std::iter::from_fn(move || row_stream.next_row()))
-                    }
+                    Ok(rb) => encode_recordbatch(fields_ref.clone(), rb),
                     Err(e) => Box::new(iter::once(Err(PgWireError::ApiError(e.into())))),
                 };
             stream::iter(row_stream)
diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs
index a77f92a..1b122c1 100644
--- a/datafusion-postgres/src/handlers.rs
+++ b/datafusion-postgres/src/handlers.rs
@@ -16,10 +16,11 @@ use pgwire::api::results::{
 use pgwire::api::stmt::QueryParser;
 use pgwire::api::stmt::StoredStatement;
 use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
+use pgwire::error::{PgWireError, PgWireResult};
 use tokio::sync::Mutex;
 
 use crate::datatypes;
-use pgwire::error::{PgWireError, PgWireResult};
+use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
 
 pub struct HandlerFactory(pub Arc<DfSessionService>);
 
@@ -237,7 +238,7 @@ impl ExtendedQueryHandler for DfSessionService {
     {
         let (_, plan) = &target.statement;
         let schema = plan.schema();
-        let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?;
+        let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
         let params = plan
             .get_parameter_types()
             .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
@@ -246,7 +247,7 @@ impl ExtendedQueryHandler for DfSessionService {
         for param_type in ordered_param_types(&params).iter() {
             // Fixed: Use &params
             if let Some(datatype) = param_type {
-                let pgtype = datatypes::into_pg_type(datatype)?;
+                let pgtype = into_pg_type(datatype)?;
                 param_types.push(pgtype);
             } else {
                 param_types.push(Type::UNKNOWN);
@@ -267,7 +268,7 @@ impl ExtendedQueryHandler for DfSessionService {
         let (_, plan) = &target.statement.statement;
         let format = &target.result_column_format;
         let schema = plan.schema();
-        let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?;
+        let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
 
         Ok(DescribePortalResponse::new(fields))
     }
diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs
index da1af1a..db957c5 100644
--- a/datafusion-postgres/src/lib.rs
+++ b/datafusion-postgres/src/lib.rs
@@ -1,5 +1,4 @@
 mod datatypes;
-mod encoder;
 mod handlers;
 pub mod pg_catalog;
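Not part of the patch: a minimal usage sketch of the relocated arrow-pg API, mirroring what the updated `encode_dataframe` does. It assumes the signatures introduced in `arrow-pg/src/datatypes.rs` above; the single `id` column and the choice of `Format::UnifiedText` are only illustrative.

```rust
use std::sync::Arc;

use arrow::array::{Int32Array, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema};
use pgwire::api::portal::Format;

use arrow_pg::datatypes::{arrow_schema_to_pg_fields, encode_recordbatch};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Illustrative one-column Arrow batch standing in for a query result.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // Map the Arrow schema to Postgres field descriptions, then encode each
    // row of the batch as a pgwire DataRow message.
    let fields = Arc::new(arrow_schema_to_pg_fields(&schema, &Format::UnifiedText)?);
    let rows = encode_recordbatch(fields, batch).collect::<Result<Vec<_>, _>>()?;
    assert_eq!(rows.len(), 3);
    Ok(())
}
```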