diff --git a/Cargo.lock b/Cargo.lock index fbd513e..0bbe61f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2729,8 +2729,7 @@ dependencies = [ [[package]] name = "pgwire" version = "0.34.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f56a81b4fcc69016028f657a68f9b8e8a2a4b7d07684ca3298f2d3e7ff199ce" +source = "git+https://github.com/sunng87/pgwire?rev=ad1e31a4b3fdc325eeb57d0cfe0dc1798b6687a6#ad1e31a4b3fdc325eeb57d0cfe0dc1798b6687a6" dependencies = [ "async-trait", "base64", diff --git a/Cargo.toml b/Cargo.toml index a8a945f..a235956 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,8 @@ bytes = "1.10.1" chrono = { version = "0.4", features = ["std"] } datafusion = { version = "50", default-features = false } futures = "0.3" -pgwire = { version = "0.34", default-features = false } +#pgwire = { version = "0.34", default-features = false } +pgwire = { git = "https://github.com/sunng87/pgwire", rev = "ad1e31a4b3fdc325eeb57d0cfe0dc1798b6687a6", default-features = false } postgres-types = "0.2" rust_decimal = { version = "1.39", features = ["db-postgres"] } tokio = { version = "1", default-features = false } diff --git a/arrow-pg/examples/duckdb.rs b/arrow-pg/examples/duckdb.rs index c8d2dfa..29faa1e 100644 --- a/arrow-pg/examples/duckdb.rs +++ b/arrow-pg/examples/duckdb.rs @@ -19,6 +19,7 @@ use pgwire::api::stmt::{NoopQueryParser, StoredStatement}; use pgwire::api::{ClientInfo, PgWireServerHandlers, Type}; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::tokio::process_socket; +use pgwire::types::format::FormatOptions; use tokio::net::TcpListener; pub struct DuckDBBackend { @@ -45,7 +46,7 @@ impl AuthSource for DummyAuthSource { #[async_trait] impl SimpleQueryHandler for DuckDBBackend { - async fn do_query(&self, _client: &mut C, query: &str) -> PgWireResult> + async fn do_query(&self, client: &mut C, query: &str) -> PgWireResult> where C: ClientInfo + Unpin + Send + Sync, { @@ -59,9 +60,12 @@ impl SimpleQueryHandler for DuckDBBackend { .query_arrow(params![]) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let schema = ret.get_schema(); + let format_options = FormatOptions::from_client_metadata(client.metadata()); + let header = Arc::new(arrow_schema_to_pg_fields( schema.as_ref(), &Format::UnifiedText, + Some(Arc::new(format_options)), )?); let header_ref = header.clone(); @@ -155,7 +159,7 @@ impl ExtendedQueryHandler for DuckDBBackend { async fn do_query( &self, - _client: &mut C, + client: &mut C, portal: &Portal, _max_rows: usize, ) -> PgWireResult @@ -178,9 +182,11 @@ impl ExtendedQueryHandler for DuckDBBackend { .query_arrow(params![]) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let schema = ret.get_schema(); + let format_options = FormatOptions::from_client_metadata(client.metadata()); let header = Arc::new(arrow_schema_to_pg_fields( schema.as_ref(), &Format::UnifiedText, + Some(Arc::new(format_options)), )?); let header_ref = header.clone(); diff --git a/arrow-pg/src/datatypes.rs b/arrow-pg/src/datatypes.rs index c3c6276..d61e648 100644 --- a/arrow-pg/src/datatypes.rs +++ b/arrow-pg/src/datatypes.rs @@ -10,6 +10,7 @@ use pgwire::api::results::FieldInfo; use pgwire::api::Type; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::data::DataRow; +use pgwire::types::format::FormatOptions; use postgres_types::Kind; use crate::row_encoder::RowEncoder; @@ -111,20 +112,25 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { }) } -pub fn arrow_schema_to_pg_fields(schema: &Schema, format: &Format) -> PgWireResult> { +pub fn arrow_schema_to_pg_fields( + schema: &Schema, + format: &Format, + data_format_options: Option>, +) -> PgWireResult> { + let _ = data_format_options; schema .fields() .iter() .enumerate() .map(|(idx, f)| { let pg_type = into_pg_type(f.data_type())?; - Ok(FieldInfo::new( - f.name().into(), - None, - None, - pg_type, - format.format_for(idx), - )) + let mut field_info = + FieldInfo::new(f.name().into(), None, None, pg_type, format.format_for(idx)); + if let Some(data_format_options) = &data_format_options { + field_info = field_info.with_format_options(data_format_options.clone()); + } + + Ok(field_info) }) .collect::>>() } diff --git a/arrow-pg/src/datatypes/df.rs b/arrow-pg/src/datatypes/df.rs index c81d53a..2959128 100644 --- a/arrow-pg/src/datatypes/df.rs +++ b/arrow-pg/src/datatypes/df.rs @@ -13,13 +13,22 @@ use pgwire::api::results::QueryResponse; use pgwire::api::Type; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::messages::data::DataRow; +use pgwire::types::format::FormatOptions; use rust_decimal::prelude::ToPrimitive; use rust_decimal::Decimal; use super::{arrow_schema_to_pg_fields, encode_recordbatch, into_pg_type}; -pub async fn encode_dataframe(df: DataFrame, format: &Format) -> PgWireResult { - let fields = Arc::new(arrow_schema_to_pg_fields(df.schema().as_arrow(), format)?); +pub async fn encode_dataframe( + df: DataFrame, + format: &Format, + data_format_options: Option>, +) -> PgWireResult { + let fields = Arc::new(arrow_schema_to_pg_fields( + df.schema().as_arrow(), + format, + data_format_options, + )?); let recordbatch_stream = df .execute_stream() diff --git a/arrow-pg/src/encoder.rs b/arrow-pg/src/encoder.rs index 8ac10da..d0eb72b 100644 --- a/arrow-pg/src/encoder.rs +++ b/arrow-pg/src/encoder.rs @@ -10,9 +10,9 @@ use bytes::BytesMut; use chrono::{NaiveDate, NaiveDateTime}; #[cfg(feature = "datafusion")] use datafusion::arrow::{array::*, datatypes::*}; -use pgwire::api::results::DataRowEncoder; -use pgwire::api::results::FieldFormat; +use pgwire::api::results::{DataRowEncoder, FieldInfo}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use pgwire::types::format::FormatOptions; use pgwire::types::ToSqlText; use postgres_types::{ToSql, Type}; use rust_decimal::Decimal; @@ -23,27 +23,22 @@ use crate::list_encoder::encode_list; use crate::struct_encoder::encode_struct; pub trait Encoder { - fn encode_field_with_type_and_format( - &mut self, - value: &T, - data_type: &Type, - format: FieldFormat, - ) -> PgWireResult<()> + fn encode_field(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()> where T: ToSql + ToSqlText + Sized; } impl Encoder for DataRowEncoder { - fn encode_field_with_type_and_format( - &mut self, - value: &T, - data_type: &Type, - format: FieldFormat, - ) -> PgWireResult<()> + fn encode_field(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()> where T: ToSql + ToSqlText + Sized, { - self.encode_field_with_type_and_format(value, data_type, format) + self.encode_field_with_type_and_format( + value, + pg_field.datatype(), + pg_field.format(), + pg_field.format_options(), + ) } } @@ -91,6 +86,7 @@ impl ToSqlText for EncodedValue { &self, _ty: &Type, out: &mut BytesMut, + _format_options: &FormatOptions, ) -> Result> where Self: Sized, @@ -291,121 +287,64 @@ pub fn encode_value( encoder: &mut T, arr: &Arc, idx: usize, - type_: &Type, - format: FieldFormat, + pg_field: &FieldInfo, ) -> PgWireResult<()> { + let type_ = pg_field.datatype(); + match arr.data_type() { - DataType::Null => encoder.encode_field_with_type_and_format(&None::, type_, format)?, - DataType::Boolean => { - encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)? - } - DataType::Int8 => { - encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)? - } - DataType::Int16 => { - encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)? - } - DataType::Int32 => { - encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)? + DataType::Null => encoder.encode_field(&None::, pg_field)?, + DataType::Boolean => encoder.encode_field(&get_bool_value(arr, idx), pg_field)?, + DataType::Int8 => encoder.encode_field(&get_i8_value(arr, idx), pg_field)?, + DataType::Int16 => encoder.encode_field(&get_i16_value(arr, idx), pg_field)?, + DataType::Int32 => encoder.encode_field(&get_i32_value(arr, idx), pg_field)?, + DataType::Int64 => encoder.encode_field(&get_i64_value(arr, idx), pg_field)?, + DataType::UInt8 => { + encoder.encode_field(&(get_u8_value(arr, idx).map(|x| x as i8)), pg_field)? } - DataType::Int64 => { - encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)? + DataType::UInt16 => { + encoder.encode_field(&(get_u16_value(arr, idx).map(|x| x as i16)), pg_field)? } - DataType::UInt8 => encoder.encode_field_with_type_and_format( - &(get_u8_value(arr, idx).map(|x| x as i8)), - type_, - format, - )?, - DataType::UInt16 => encoder.encode_field_with_type_and_format( - &(get_u16_value(arr, idx).map(|x| x as i16)), - type_, - format, - )?, - DataType::UInt32 => { - encoder.encode_field_with_type_and_format(&get_u32_value(arr, idx), type_, format)? + DataType::UInt32 => encoder.encode_field(&get_u32_value(arr, idx), pg_field)?, + DataType::UInt64 => { + encoder.encode_field(&(get_u64_value(arr, idx).map(|x| x as i64)), pg_field)? } - DataType::UInt64 => encoder.encode_field_with_type_and_format( - &(get_u64_value(arr, idx).map(|x| x as i64)), - type_, - format, - )?, - DataType::Float32 => { - encoder.encode_field_with_type_and_format(&get_f32_value(arr, idx), type_, format)? + DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx), pg_field)?, + DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx), pg_field)?, + DataType::Decimal128(_, s) => { + encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32)?, pg_field)? } - DataType::Float64 => { - encoder.encode_field_with_type_and_format(&get_f64_value(arr, idx), type_, format)? - } - DataType::Decimal128(_, s) => encoder.encode_field_with_type_and_format( - &get_numeric_128_value(arr, idx, *s as u32)?, - type_, - format, - )?, - DataType::Utf8 => { - encoder.encode_field_with_type_and_format(&get_utf8_value(arr, idx), type_, format)? - } - DataType::Utf8View => encoder.encode_field_with_type_and_format( - &get_utf8_view_value(arr, idx), - type_, - format, - )?, - DataType::BinaryView => encoder.encode_field_with_type_and_format( - &get_binary_view_value(arr, idx), - type_, - format, - )?, - DataType::LargeUtf8 => encoder.encode_field_with_type_and_format( - &get_large_utf8_value(arr, idx), - type_, - format, - )?, - DataType::Binary => { - encoder.encode_field_with_type_and_format(&get_binary_value(arr, idx), type_, format)? - } - DataType::LargeBinary => encoder.encode_field_with_type_and_format( - &get_large_binary_value(arr, idx), - type_, - format, - )?, - DataType::Date32 => { - encoder.encode_field_with_type_and_format(&get_date32_value(arr, idx), type_, format)? - } - DataType::Date64 => { - encoder.encode_field_with_type_and_format(&get_date64_value(arr, idx), type_, format)? + DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx), pg_field)?, + DataType::Utf8View => encoder.encode_field(&get_utf8_view_value(arr, idx), pg_field)?, + DataType::BinaryView => encoder.encode_field(&get_binary_view_value(arr, idx), pg_field)?, + DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx), pg_field)?, + DataType::Binary => encoder.encode_field(&get_binary_value(arr, idx), pg_field)?, + DataType::LargeBinary => { + encoder.encode_field(&get_large_binary_value(arr, idx), pg_field)? } + DataType::Date32 => encoder.encode_field(&get_date32_value(arr, idx), pg_field)?, + DataType::Date64 => encoder.encode_field(&get_date64_value(arr, idx), pg_field)?, DataType::Time32(unit) => match unit { - TimeUnit::Second => encoder.encode_field_with_type_and_format( - &get_time32_second_value(arr, idx), - type_, - format, - )?, - TimeUnit::Millisecond => encoder.encode_field_with_type_and_format( - &get_time32_millisecond_value(arr, idx), - type_, - format, - )?, + TimeUnit::Second => { + encoder.encode_field(&get_time32_second_value(arr, idx), pg_field)? + } + TimeUnit::Millisecond => { + encoder.encode_field(&get_time32_millisecond_value(arr, idx), pg_field)? + } _ => {} }, DataType::Time64(unit) => match unit { - TimeUnit::Microsecond => encoder.encode_field_with_type_and_format( - &get_time64_microsecond_value(arr, idx), - type_, - format, - )?, - TimeUnit::Nanosecond => encoder.encode_field_with_type_and_format( - &get_time64_nanosecond_value(arr, idx), - type_, - format, - )?, + TimeUnit::Microsecond => { + encoder.encode_field(&get_time64_microsecond_value(arr, idx), pg_field)? + } + TimeUnit::Nanosecond => { + encoder.encode_field(&get_time64_nanosecond_value(arr, idx), pg_field)? + } _ => {} }, DataType::Timestamp(unit, timezone) => match unit { TimeUnit::Second => { if arr.is_null(idx) { - return encoder.encode_field_with_type_and_format( - &None::, - type_, - format, - ); + return encoder.encode_field(&None::, pg_field); } let ts_array = arr.as_any().downcast_ref::().unwrap(); if let Some(tz) = timezone { @@ -413,19 +352,16 @@ pub fn encode_value( let value = ts_array .value_as_datetime_with_tz(idx, tz) .map(|d| d.fixed_offset()); - encoder.encode_field_with_type_and_format(&value, type_, format)?; + + encoder.encode_field(&value, pg_field)?; } else { let value = ts_array.value_as_datetime(idx); - encoder.encode_field_with_type_and_format(&value, type_, format)?; + encoder.encode_field(&value, pg_field)?; } } TimeUnit::Millisecond => { if arr.is_null(idx) { - return encoder.encode_field_with_type_and_format( - &None::, - type_, - format, - ); + return encoder.encode_field(&None::, pg_field); } let ts_array = arr .as_any() @@ -436,19 +372,15 @@ pub fn encode_value( let value = ts_array .value_as_datetime_with_tz(idx, tz) .map(|d| d.fixed_offset()); - encoder.encode_field_with_type_and_format(&value, type_, format)?; + encoder.encode_field(&value, pg_field)?; } else { let value = ts_array.value_as_datetime(idx); - encoder.encode_field_with_type_and_format(&value, type_, format)?; + encoder.encode_field(&value, pg_field)?; } } TimeUnit::Microsecond => { if arr.is_null(idx) { - return encoder.encode_field_with_type_and_format( - &None::, - type_, - format, - ); + return encoder.encode_field(&None::, pg_field); } let ts_array = arr .as_any() @@ -459,19 +391,15 @@ pub fn encode_value( let value = ts_array .value_as_datetime_with_tz(idx, tz) .map(|d| d.fixed_offset()); - encoder.encode_field_with_type_and_format(&value, type_, format)?; + encoder.encode_field(&value, pg_field)?; } else { let value = ts_array.value_as_datetime(idx); - encoder.encode_field_with_type_and_format(&value, type_, format)?; + encoder.encode_field(&value, pg_field)?; } } TimeUnit::Nanosecond => { if arr.is_null(idx) { - return encoder.encode_field_with_type_and_format( - &None::, - type_, - format, - ); + return encoder.encode_field(&None::, pg_field); } let ts_array = arr .as_any() @@ -482,20 +410,20 @@ pub fn encode_value( let value = ts_array .value_as_datetime_with_tz(idx, tz) .map(|d| d.fixed_offset()); - encoder.encode_field_with_type_and_format(&value, type_, format)?; + encoder.encode_field(&value, pg_field)?; } else { let value = ts_array.value_as_datetime(idx); - encoder.encode_field_with_type_and_format(&value, type_, format)?; + encoder.encode_field(&value, pg_field)?; } } }, DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => { if arr.is_null(idx) { - return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format); + return encoder.encode_field(&None::<&[i8]>, pg_field); } let array = arr.as_any().downcast_ref::().unwrap().value(idx); - let value = encode_list(array, type_, format)?; - encoder.encode_field_with_type_and_format(&value, type_, format)? + let value = encode_list(array, pg_field)?; + encoder.encode_field(&value, pg_field)? } DataType::Struct(_) => { let fields = match type_.kind() { @@ -506,12 +434,12 @@ pub fn encode_value( )))); } }; - let value = encode_struct(arr, idx, fields, format)?; - encoder.encode_field_with_type_and_format(&value, type_, format)? + let value = encode_struct(arr, idx, fields, pg_field)?; + encoder.encode_field(&value, pg_field)? } DataType::Dictionary(_, value_type) => { if arr.is_null(idx) { - return encoder.encode_field_with_type_and_format(&None::, type_, format); + return encoder.encode_field(&None::, pg_field); } // Get the dictionary values and the mapped row index macro_rules! get_dict_values_and_index { @@ -537,7 +465,7 @@ pub fn encode_value( )) })?; - encode_value(encoder, values, idx, type_, format)? + encode_value(encoder, values, idx, pg_field)? } _ => { return Err(PgWireError::ApiError(ToSqlError::from(format!( @@ -553,6 +481,8 @@ pub fn encode_value( #[cfg(test)] mod tests { + use pgwire::api::results::FieldFormat; + use super::*; #[test] @@ -563,17 +493,13 @@ mod tests { } impl Encoder for MockEncoder { - fn encode_field_with_type_and_format( - &mut self, - value: &T, - data_type: &Type, - _format: FieldFormat, - ) -> PgWireResult<()> + fn encode_field(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()> where T: ToSql + ToSqlText + Sized, { let mut bytes = BytesMut::new(); - let _sql_text = value.to_sql_text(data_type, &mut bytes); + let _sql_text = + value.to_sql_text(pg_field.datatype(), &mut bytes, &FormatOptions::default()); let string = String::from_utf8(bytes.to_vec()); self.encoded_value = string.unwrap(); Ok(()) @@ -588,7 +514,8 @@ mod tests { let mut encoder = MockEncoder::default(); - let result = encode_value(&mut encoder, &dict_arr, 2, &Type::TEXT, FieldFormat::Text); + let pg_field = FieldInfo::new("x".to_string(), None, None, Type::TEXT, FieldFormat::Text); + let result = encode_value(&mut encoder, &dict_arr, 2, &pg_field); assert!(result.is_ok()); diff --git a/arrow-pg/src/lib.rs b/arrow-pg/src/lib.rs index e33a375..6ee7e2b 100644 --- a/arrow-pg/src/lib.rs +++ b/arrow-pg/src/lib.rs @@ -9,3 +9,8 @@ mod error; pub mod list_encoder; pub mod row_encoder; pub mod struct_encoder; + +#[cfg(feature = "datafusion")] +pub use datatypes::df::encode_dataframe; + +pub use datatypes::encode_recordbatch; diff --git a/arrow-pg/src/list_encoder.rs b/arrow-pg/src/list_encoder.rs index a13c1c7..d7dca3d 100644 --- a/arrow-pg/src/list_encoder.rs +++ b/arrow-pg/src/list_encoder.rs @@ -37,10 +37,10 @@ use datafusion::arrow::{ use bytes::{BufMut, BytesMut}; use chrono::{DateTime, TimeZone, Utc}; -use pgwire::api::results::FieldFormat; +use pgwire::api::results::{FieldFormat, FieldInfo}; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::types::{ToSqlText, QUOTE_ESCAPE}; -use postgres_types::{ToSql, Type}; +use postgres_types::ToSql; use rust_decimal::Decimal; use crate::encoder::EncodedValue; @@ -93,44 +93,44 @@ get_primitive_list_value!(get_u64_list_value, UInt64Type, i64, |val: u64| { get_primitive_list_value!(get_f32_list_value, Float32Type, f32); get_primitive_list_value!(get_f64_list_value, Float64Type, f64); -fn encode_field( - t: &[T], - type_: &Type, - format: FieldFormat, -) -> PgWireResult { +fn encode_field(t: &[T], field: &FieldInfo) -> PgWireResult { let mut bytes = BytesMut::new(); + + let format = field.format(); + let type_ = field.datatype(); match format { - FieldFormat::Text => t.to_sql_text(type_, &mut bytes)?, + FieldFormat::Text => t.to_sql_text(type_, &mut bytes, field.format_options().as_ref())?, FieldFormat::Binary => t.to_sql(type_, &mut bytes)?, }; Ok(EncodedValue { bytes }) } -pub(crate) fn encode_list( - arr: Arc, - type_: &Type, - format: FieldFormat, -) -> PgWireResult { +pub(crate) fn encode_list(arr: Arc, pg_field: &FieldInfo) -> PgWireResult { + let type_ = pg_field.datatype(); + let format = pg_field.format(); + match arr.data_type() { DataType::Null => { let mut bytes = BytesMut::new(); match format { - FieldFormat::Text => None::.to_sql_text(type_, &mut bytes), + FieldFormat::Text => { + None::.to_sql_text(type_, &mut bytes, pg_field.format_options().as_ref()) + } FieldFormat::Binary => None::.to_sql(type_, &mut bytes), }?; Ok(EncodedValue { bytes }) } - DataType::Boolean => encode_field(&get_bool_list_value(&arr), type_, format), - DataType::Int8 => encode_field(&get_i8_list_value(&arr), type_, format), - DataType::Int16 => encode_field(&get_i16_list_value(&arr), type_, format), - DataType::Int32 => encode_field(&get_i32_list_value(&arr), type_, format), - DataType::Int64 => encode_field(&get_i64_list_value(&arr), type_, format), - DataType::UInt8 => encode_field(&get_u8_list_value(&arr), type_, format), - DataType::UInt16 => encode_field(&get_u16_list_value(&arr), type_, format), - DataType::UInt32 => encode_field(&get_u32_list_value(&arr), type_, format), - DataType::UInt64 => encode_field(&get_u64_list_value(&arr), type_, format), - DataType::Float32 => encode_field(&get_f32_list_value(&arr), type_, format), - DataType::Float64 => encode_field(&get_f64_list_value(&arr), type_, format), + DataType::Boolean => encode_field(&get_bool_list_value(&arr), pg_field), + DataType::Int8 => encode_field(&get_i8_list_value(&arr), pg_field), + DataType::Int16 => encode_field(&get_i16_list_value(&arr), pg_field), + DataType::Int32 => encode_field(&get_i32_list_value(&arr), pg_field), + DataType::Int64 => encode_field(&get_i64_list_value(&arr), pg_field), + DataType::UInt8 => encode_field(&get_u8_list_value(&arr), pg_field), + DataType::UInt16 => encode_field(&get_u16_list_value(&arr), pg_field), + DataType::UInt32 => encode_field(&get_u32_list_value(&arr), pg_field), + DataType::UInt64 => encode_field(&get_u64_list_value(&arr), pg_field), + DataType::Float32 => encode_field(&get_f32_list_value(&arr), pg_field), + DataType::Float64 => encode_field(&get_f64_list_value(&arr), pg_field), DataType::Decimal128(_, s) => { let value: Vec<_> = arr .as_any() @@ -139,7 +139,7 @@ pub(crate) fn encode_list( .iter() .map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32))) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::Utf8 => { let value: Vec> = arr @@ -148,7 +148,7 @@ pub(crate) fn encode_list( .unwrap() .iter() .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::Utf8View => { let value: Vec> = arr @@ -157,7 +157,7 @@ pub(crate) fn encode_list( .unwrap() .iter() .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::Binary => { let value: Vec> = arr @@ -166,7 +166,7 @@ pub(crate) fn encode_list( .unwrap() .iter() .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::LargeBinary => { let value: Vec> = arr @@ -175,7 +175,7 @@ pub(crate) fn encode_list( .unwrap() .iter() .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::BinaryView => { let value: Vec> = arr @@ -184,7 +184,7 @@ pub(crate) fn encode_list( .unwrap() .iter() .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::Date32 => { @@ -195,7 +195,7 @@ pub(crate) fn encode_list( .iter() .map(|val| val.and_then(|x| as_date::(x as i64))) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::Date64 => { let value: Vec> = arr @@ -205,7 +205,7 @@ pub(crate) fn encode_list( .iter() .map(|val| val.and_then(as_date::)) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::Time32(unit) => match unit { TimeUnit::Second => { @@ -216,7 +216,7 @@ pub(crate) fn encode_list( .iter() .map(|val| val.and_then(|x| as_time::(x as i64))) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } TimeUnit::Millisecond => { let value: Vec> = arr @@ -226,7 +226,7 @@ pub(crate) fn encode_list( .iter() .map(|val| val.and_then(|x| as_time::(x as i64))) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } _ => { // Time32 only supports Second and Millisecond in Arrow @@ -243,7 +243,7 @@ pub(crate) fn encode_list( .iter() .map(|val| val.and_then(as_time::)) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } TimeUnit::Nanosecond => { let value: Vec> = arr @@ -253,7 +253,7 @@ pub(crate) fn encode_list( .iter() .map(|val| val.and_then(as_time::)) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } _ => { // Time64 only supports Microsecond and Nanosecond in Arrow @@ -283,14 +283,14 @@ pub(crate) fn encode_list( }) }) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } else { let value: Vec<_> = array_iter .map(|i| { i.and_then(|i| DateTime::from_timestamp(i, 0).map(|dt| dt.naive_utc())) }) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } } TimeUnit::Millisecond => { @@ -313,7 +313,7 @@ pub(crate) fn encode_list( }) }) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } else { let value: Vec<_> = array_iter .map(|i| { @@ -322,7 +322,7 @@ pub(crate) fn encode_list( }) }) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } } TimeUnit::Microsecond => { @@ -345,7 +345,7 @@ pub(crate) fn encode_list( }) }) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } else { let value: Vec<_> = array_iter .map(|i| { @@ -354,7 +354,7 @@ pub(crate) fn encode_list( }) }) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } } TimeUnit::Nanosecond => { @@ -377,12 +377,12 @@ pub(crate) fn encode_list( }) }) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } else { let value: Vec<_> = array_iter .map(|i| i.map(|i| DateTime::from_timestamp_nanos(i).naive_utc())) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } } }, @@ -406,7 +406,7 @@ pub(crate) fn encode_list( .map_err(ToSqlError::from)?; let values: PgWireResult> = (0..arr.len()) - .map(|row| encode_struct(&arr, row, fields, format)) + .map(|row| encode_struct(&arr, row, fields, pg_field)) .map(|x| { if matches!(format, FieldFormat::Text) { x.map(|opt| { @@ -430,7 +430,7 @@ pub(crate) fn encode_list( } }) .collect(); - encode_field(&values?, type_, format) + encode_field(&values?, pg_field) } DataType::LargeUtf8 => { let value: Vec> = arr @@ -439,7 +439,7 @@ pub(crate) fn encode_list( .unwrap() .iter() .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::Decimal256(_, s) => { // Convert Decimal256 to string representation for now @@ -474,7 +474,7 @@ pub(crate) fn encode_list( } }) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::Duration(_) => { // Convert duration to microseconds for now @@ -484,7 +484,7 @@ pub(crate) fn encode_list( .unwrap() .iter() .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::List(_) => { // Support for nested lists (list of lists) @@ -500,7 +500,7 @@ pub(crate) fn encode_list( } }) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::LargeList(_) => { // Support for large lists @@ -514,7 +514,7 @@ pub(crate) fn encode_list( } }) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::Map(_, _) => { // Support for map types @@ -528,7 +528,7 @@ pub(crate) fn encode_list( } }) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::Union(_, _) => { @@ -542,7 +542,7 @@ pub(crate) fn encode_list( } }) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } DataType::Dictionary(_, _) => { // Support for dictionary types @@ -555,7 +555,7 @@ pub(crate) fn encode_list( } }) .collect(); - encode_field(&value, type_, format) + encode_field(&value, pg_field) } // TODO: add support for more advanced types (fixed size lists, etc.) list_type => Err(PgWireError::ApiError(ToSqlError::from(format!( diff --git a/arrow-pg/src/row_encoder.rs b/arrow-pg/src/row_encoder.rs index 145c9ab..674751b 100644 --- a/arrow-pg/src/row_encoder.rs +++ b/arrow-pg/src/row_encoder.rs @@ -37,9 +37,8 @@ impl RowEncoder { for col in 0..self.rb.num_columns() { let array = self.rb.column(col); let field = &self.fields[col]; - let type_ = field.datatype(); - let format = field.format(); - encode_value(&mut encoder, array, self.curr_idx, type_, format).unwrap(); + + encode_value(&mut encoder, array, self.curr_idx, field).unwrap(); } self.curr_idx += 1; Some(encoder.finish()) diff --git a/arrow-pg/src/struct_encoder.rs b/arrow-pg/src/struct_encoder.rs index 49fce1b..ad86b96 100644 --- a/arrow-pg/src/struct_encoder.rs +++ b/arrow-pg/src/struct_encoder.rs @@ -6,10 +6,10 @@ use arrow::array::{Array, StructArray}; use datafusion::arrow::array::{Array, StructArray}; use bytes::{BufMut, BytesMut}; -use pgwire::api::results::FieldFormat; +use pgwire::api::results::{FieldFormat, FieldInfo}; use pgwire::error::PgWireResult; use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE}; -use postgres_types::{Field, IsNull, ToSql, Type}; +use postgres_types::{Field, IsNull, ToSql}; use crate::encoder::{encode_value, EncodedValue, Encoder}; @@ -17,7 +17,7 @@ pub(crate) fn encode_struct( arr: &Arc, idx: usize, fields: &[Field], - format: FieldFormat, + parent_pg_field_info: &FieldInfo, ) -> PgWireResult> { let arr = arr.as_any().downcast_ref::().unwrap(); if arr.is_null(idx) { @@ -27,7 +27,18 @@ pub(crate) fn encode_struct( for (i, arr) in arr.columns().iter().enumerate() { let field = &fields[i]; let type_ = field.type_(); - encode_value(&mut row_encoder, arr, idx, type_, format).unwrap(); + + let mut pg_field = FieldInfo::new( + field.name().to_string(), + None, + None, + type_.clone(), + parent_pg_field_info.format(), + ); + + pg_field = pg_field.with_format_options(parent_pg_field_info.format_options().clone()); + + encode_value(&mut row_encoder, arr, idx, &pg_field).unwrap(); } Ok(Some(EncodedValue { bytes: row_encoder.row_buffer, @@ -51,22 +62,20 @@ impl StructEncoder { } impl Encoder for StructEncoder { - fn encode_field_with_type_and_format( - &mut self, - value: &T, - data_type: &Type, - format: FieldFormat, - ) -> PgWireResult<()> + fn encode_field(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()> where T: ToSql + ToSqlText + Sized, { + let datatype = pg_field.datatype(); + let format = pg_field.format(); + if format == FieldFormat::Text { if self.curr_col == 0 { self.row_buffer.put_slice(b"("); } // encode value in an intermediate buf let mut buf = BytesMut::new(); - value.to_sql_text(data_type, &mut buf)?; + value.to_sql_text(datatype, &mut buf, pg_field.format_options().as_ref())?; let encoded_value_as_str = String::from_utf8_lossy(&buf); if QUOTE_CHECK.is_match(&encoded_value_as_str) { self.row_buffer.put_u8(b'"'); @@ -90,12 +99,12 @@ impl Encoder for StructEncoder { self.row_buffer.put_i32(self.num_cols as i32); } - self.row_buffer.put_u32(data_type.oid()); + self.row_buffer.put_u32(datatype.oid()); // remember the position of the 4-byte length field let prev_index = self.row_buffer.len(); // write value length as -1 ahead of time self.row_buffer.put_i32(-1); - let is_null = value.to_sql(data_type, &mut self.row_buffer)?; + let is_null = value.to_sql(datatype, &mut self.row_buffer)?; if let IsNull::No = is_null { let value_length = self.row_buffer.len() - prev_index - 4; let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)]; diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index bc43b4f..7fadd8e 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -22,6 +22,7 @@ use pgwire::api::stmt::StoredStatement; use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type}; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::messages::response::TransactionStatus; +use pgwire::types::format::FormatOptions; use crate::auth::AuthManager; use crate::client; @@ -355,7 +356,10 @@ impl SimpleQueryHandler for DfSessionService { results.push(resp); } else { // For non-INSERT queries, return a regular Query response - let resp = df::encode_dataframe(df, &Format::UnifiedText).await?; + let format_options = + Arc::new(FormatOptions::from_client_metadata(client.metadata())); + let resp = + df::encode_dataframe(df, &Format::UnifiedText, Some(format_options)).await?; results.push(Response::Query(resp)); } } @@ -382,7 +386,8 @@ impl ExtendedQueryHandler for DfSessionService { { if let (_, Some((_, plan))) = &target.statement { let schema = plan.schema(); - let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?; + let fields = + arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary, None)?; let params = plan .get_parameter_types() .map_err(|e| PgWireError::ApiError(Box::new(e)))?; @@ -415,7 +420,7 @@ impl ExtendedQueryHandler for DfSessionService { if let (_, Some((_, plan))) = &target.statement.statement { let format = &target.result_column_format; let schema = plan.schema(); - let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?; + let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format, None)?; Ok(DescribePortalResponse::new(fields)) } else { @@ -543,7 +548,14 @@ impl ExtendedQueryHandler for DfSessionService { Ok(resp) } else { // For non-INSERT queries, return a regular Query response - let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?; + let format_options = + Arc::new(FormatOptions::from_client_metadata(client.metadata())); + let resp = df::encode_dataframe( + dataframe, + &portal.result_column_format, + Some(format_options), + ) + .await?; Ok(Response::Query(resp)) } } else {