diff --git a/.gitignore b/.gitignore index 264a638..a7bba04 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target .direnv .envrc +.vscode \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 0bd57e7..5d216c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1355,11 +1355,13 @@ version = "0.3.0" dependencies = [ "arrow", "async-trait", + "bytes", "chrono", "datafusion", "futures", "log", "pgwire", + "postgres-types", "rust_decimal", "tokio", ] diff --git a/datafusion-postgres/Cargo.toml b/datafusion-postgres/Cargo.toml index 6148312..6a42d83 100644 --- a/datafusion-postgres/Cargo.toml +++ b/datafusion-postgres/Cargo.toml @@ -16,12 +16,14 @@ readme = "../README.md" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -pgwire = { workspace = true } -datafusion = { workspace = true } -tokio = { version = "1.45", features = ["sync"] } arrow = "54.2.0" -futures = "0.3" async-trait = "0.1" -log = "0.4" +bytes = "1.10.1" chrono = { version = "0.4", features = ["std"] } +datafusion = { workspace = true } +futures = "0.3" +log = "0.4" +pgwire = { workspace = true } +postgres-types = "0.2" rust_decimal = { version = "1.37", features = ["db-postgres"] } +tokio = { version = "1.45", features = ["sync"] } diff --git a/datafusion-postgres/src/datatypes.rs b/datafusion-postgres/src/datatypes.rs index 3660d40..bf5223b 100644 --- a/datafusion-postgres/src/datatypes.rs +++ b/datafusion-postgres/src/datatypes.rs @@ -1,10 +1,8 @@ use std::iter; -use std::str::FromStr; use std::sync::Arc; -use chrono::{DateTime, FixedOffset, TimeZone, Utc}; +use chrono::{DateTime, FixedOffset}; use chrono::{NaiveDate, NaiveDateTime}; -use datafusion::arrow::array::*; use datafusion::arrow::datatypes::*; use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::{DFSchema, ParamValues}; @@ -12,12 +10,15 @@ use datafusion::prelude::*; use datafusion::scalar::ScalarValue; use futures::{stream, StreamExt}; use pgwire::api::portal::{Format, Portal}; -use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse}; +use pgwire::api::results::{FieldInfo, QueryResponse}; use pgwire::api::Type; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use pgwire::messages::data::DataRow; +use postgres_types::Kind; use rust_decimal::prelude::ToPrimitive; -use rust_decimal::{Decimal, Error}; -use timezone::Tz; +use rust_decimal::Decimal; + +use crate::encoder::row_encoder::RowEncoder; pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult { Ok(match df_type { @@ -65,6 +66,12 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult { DataType::Float64 => Type::FLOAT8_ARRAY, DataType::Utf8 => Type::VARCHAR_ARRAY, DataType::LargeUtf8 => Type::TEXT_ARRAY, + struct_type @ DataType::Struct(_) => Type::new( + Type::RECORD_ARRAY.name().into(), + Type::RECORD_ARRAY.oid(), + Kind::Array(into_pg_type(struct_type)?), + Type::RECORD_ARRAY.schema().into(), + ), list_type => { return Err(PgWireError::UserError(Box::new(ErrorInfo::new( "ERROR".to_owned(), @@ -76,6 +83,24 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult { } DataType::Utf8View => Type::TEXT, DataType::Dictionary(_, value_type) => into_pg_type(value_type)?, + DataType::Struct(fields) => { + let name: String = fields + .iter() + .map(|x| x.name().clone()) + .reduce(|a, b| a + ", " + &b) + .map(|x| format!("({x})")) + .unwrap_or("()".to_string()); + let kind = Kind::Composite( + fields + .iter() + .map(|x| { + into_pg_type(x.data_type()) + .map(|_type| postgres_types::Field::new(x.name().clone(), _type)) + }) + .collect::, PgWireError>>()?, + ); + Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into()) + } _ => { return Err(PgWireError::UserError(Box::new(ErrorInfo::new( "ERROR".to_owned(), @@ -86,619 +111,6 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult { }) } -fn get_numeric_128_value(arr: &Arc, idx: usize, scale: u32) -> PgWireResult { - let array = arr.as_any().downcast_ref::().unwrap(); - let value = array.value(idx); - Decimal::try_from_i128_with_scale(value, scale).map_err(|e| { - let message = match e { - Error::ExceedsMaximumPossibleValue => "Exceeds maximum possible value", - Error::LessThanMinimumPossibleValue => "Less than minimum possible value", - Error::ScaleExceedsMaximumPrecision(_) => "Scale exceeds maximum precision", - _ => unreachable!(), - }; - PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - message.to_owned(), - ))) - }) -} - -fn get_bool_value(arr: &Arc, idx: usize) -> bool { - arr.as_any() - .downcast_ref::() - .unwrap() - .value(idx) -} - -fn get_bool_list_value(arr: &Arc, idx: usize) -> Vec> { - let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); - list_arr - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect() -} - -macro_rules! get_primitive_value { - ($name:ident, $t:ty, $pt:ty) => { - fn $name(arr: &Arc, idx: usize) -> $pt { - arr.as_any() - .downcast_ref::>() - .unwrap() - .value(idx) - } - }; -} - -get_primitive_value!(get_i8_value, Int8Type, i8); -get_primitive_value!(get_i16_value, Int16Type, i16); -get_primitive_value!(get_i32_value, Int32Type, i32); -get_primitive_value!(get_i64_value, Int64Type, i64); -get_primitive_value!(get_u8_value, UInt8Type, u8); -get_primitive_value!(get_u16_value, UInt16Type, u16); -get_primitive_value!(get_u32_value, UInt32Type, u32); -get_primitive_value!(get_u64_value, UInt64Type, u64); -get_primitive_value!(get_f32_value, Float32Type, f32); -get_primitive_value!(get_f64_value, Float64Type, f64); - -macro_rules! get_primitive_list_value { - ($name:ident, $t:ty, $pt:ty) => { - fn $name(arr: &Arc, idx: usize) -> Vec> { - let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); - list_arr - .as_any() - .downcast_ref::>() - .unwrap() - .iter() - .collect() - } - }; - - ($name:ident, $t:ty, $pt:ty, $f:expr) => { - fn $name(arr: &Arc, idx: usize) -> Vec> { - let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); - list_arr - .as_any() - .downcast_ref::>() - .unwrap() - .iter() - .map(|val| val.map($f)) - .collect() - } - }; -} - -get_primitive_list_value!(get_i8_list_value, Int8Type, i8); -get_primitive_list_value!(get_i16_list_value, Int16Type, i16); -get_primitive_list_value!(get_i32_list_value, Int32Type, i32); -get_primitive_list_value!(get_i64_list_value, Int64Type, i64); -get_primitive_list_value!(get_u8_list_value, UInt8Type, i8, |val: u8| { val as i8 }); -get_primitive_list_value!(get_u16_list_value, UInt16Type, i16, |val: u16| { - val as i16 -}); -get_primitive_list_value!(get_u32_list_value, UInt32Type, u32); -get_primitive_list_value!(get_u64_list_value, UInt64Type, i64, |val: u64| { - val as i64 -}); -get_primitive_list_value!(get_f32_list_value, Float32Type, f32); -get_primitive_list_value!(get_f64_list_value, Float64Type, f64); - -fn get_utf8_view_value(arr: &Arc, idx: usize) -> &str { - arr.as_any() - .downcast_ref::() - .unwrap() - .value(idx) -} - -fn get_utf8_value(arr: &Arc, idx: usize) -> &str { - arr.as_any() - .downcast_ref::() - .unwrap() - .value(idx) -} - -fn get_large_utf8_value(arr: &Arc, idx: usize) -> &str { - arr.as_any() - .downcast_ref::() - .unwrap() - .value(idx) -} - -fn get_binary_value(arr: &Arc, idx: usize) -> &[u8] { - arr.as_any() - .downcast_ref::() - .unwrap() - .value(idx) -} - -fn get_large_binary_value(arr: &Arc, idx: usize) -> &[u8] { - arr.as_any() - .downcast_ref::() - .unwrap() - .value(idx) -} - -fn get_date32_value(arr: &Arc, idx: usize) -> Option { - arr.as_any() - .downcast_ref::() - .unwrap() - .value_as_date(idx) -} - -fn get_date64_value(arr: &Arc, idx: usize) -> Option { - arr.as_any() - .downcast_ref::() - .unwrap() - .value_as_date(idx) -} - -fn get_time32_second_value(arr: &Arc, idx: usize) -> Option { - arr.as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(idx) -} - -fn get_time32_millisecond_value(arr: &Arc, idx: usize) -> Option { - arr.as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(idx) -} - -fn get_time64_microsecond_value(arr: &Arc, idx: usize) -> Option { - arr.as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(idx) -} -fn get_time64_nanosecond_value(arr: &Arc, idx: usize) -> Option { - arr.as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(idx) -} - -fn encode_value( - encoder: &mut DataRowEncoder, - arr: &Arc, - idx: usize, -) -> PgWireResult<()> { - match arr.data_type() { - DataType::Null => encoder.encode_field(&None::)?, - DataType::Boolean => encoder.encode_field(&get_bool_value(arr, idx))?, - DataType::Int8 => encoder.encode_field(&get_i8_value(arr, idx))?, - DataType::Int16 => encoder.encode_field(&get_i16_value(arr, idx))?, - DataType::Int32 => encoder.encode_field(&get_i32_value(arr, idx))?, - DataType::Int64 => encoder.encode_field(&get_i64_value(arr, idx))?, - DataType::UInt8 => encoder.encode_field(&(get_u8_value(arr, idx) as i8))?, - DataType::UInt16 => encoder.encode_field(&(get_u16_value(arr, idx) as i16))?, - DataType::UInt32 => encoder.encode_field(&get_u32_value(arr, idx))?, - DataType::UInt64 => encoder.encode_field(&(get_u64_value(arr, idx) as i64))?, - DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx))?, - DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx))?, - DataType::Decimal128(_, s) => { - encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32)?)? - } - DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx))?, - DataType::Utf8View => encoder.encode_field(&get_utf8_view_value(arr, idx))?, - DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx))?, - DataType::Binary => encoder.encode_field(&get_binary_value(arr, idx))?, - DataType::LargeBinary => encoder.encode_field(&get_large_binary_value(arr, idx))?, - DataType::Date32 => encoder.encode_field(&get_date32_value(arr, idx))?, - DataType::Date64 => encoder.encode_field(&get_date64_value(arr, idx))?, - DataType::Time32(unit) => match unit { - TimeUnit::Second => encoder.encode_field(&get_time32_second_value(arr, idx))?, - TimeUnit::Millisecond => { - encoder.encode_field(&get_time32_millisecond_value(arr, idx))? - } - _ => {} - }, - DataType::Time64(unit) => match unit { - TimeUnit::Microsecond => { - encoder.encode_field(&get_time64_microsecond_value(arr, idx))? - } - TimeUnit::Nanosecond => encoder.encode_field(&get_time64_nanosecond_value(arr, idx))?, - _ => {} - }, - DataType::Timestamp(unit, timezone) => match unit { - TimeUnit::Second => { - let ts_array = arr.as_any().downcast_ref::().unwrap(); - if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let value = ts_array - .value_as_datetime_with_tz(idx, tz) - .map(|d| d.fixed_offset()); - encoder.encode_field(&value)?; - } else { - let value = ts_array.value_as_datetime(idx); - encoder.encode_field(&value)? - } - } - TimeUnit::Millisecond => { - let ts_array = arr - .as_any() - .downcast_ref::() - .unwrap(); - if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let value = ts_array - .value_as_datetime_with_tz(idx, tz) - .map(|d| d.fixed_offset()); - encoder.encode_field(&value)?; - } else { - let value = ts_array.value_as_datetime(idx); - encoder.encode_field(&value)? - } - } - TimeUnit::Microsecond => { - let ts_array = arr - .as_any() - .downcast_ref::() - .unwrap(); - if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let value = ts_array - .value_as_datetime_with_tz(idx, tz) - .map(|d| d.fixed_offset()); - encoder.encode_field(&value)?; - } else { - let value = ts_array.value_as_datetime(idx); - encoder.encode_field(&value)? - } - } - TimeUnit::Nanosecond => { - let ts_array = arr - .as_any() - .downcast_ref::() - .unwrap(); - if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let value = ts_array - .value_as_datetime_with_tz(idx, tz) - .map(|d| d.fixed_offset()); - encoder.encode_field(&value)?; - } else { - let value = ts_array.value_as_datetime(idx); - encoder.encode_field(&value)? - } - } - }, - - DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => { - match field.data_type() { - DataType::Null => encoder.encode_field(&None::)?, - DataType::Boolean => encoder.encode_field(&get_bool_list_value(arr, idx))?, - DataType::Int8 => encoder.encode_field(&get_i8_list_value(arr, idx))?, - DataType::Int16 => encoder.encode_field(&get_i16_list_value(arr, idx))?, - DataType::Int32 => encoder.encode_field(&get_i32_list_value(arr, idx))?, - DataType::Int64 => encoder.encode_field(&get_i64_list_value(arr, idx))?, - DataType::UInt8 => encoder.encode_field(&get_u8_list_value(arr, idx))?, - DataType::UInt16 => encoder.encode_field(&get_u16_list_value(arr, idx))?, - DataType::UInt32 => encoder.encode_field(&get_u32_list_value(arr, idx))?, - DataType::UInt64 => encoder.encode_field(&get_u64_list_value(arr, idx))?, - DataType::Float32 => encoder.encode_field(&get_f32_list_value(arr, idx))?, - DataType::Float64 => encoder.encode_field(&get_f64_list_value(arr, idx))?, - DataType::Decimal128(_, s) => { - let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); - let value: Vec<_> = list_arr - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32))) - .collect(); - encoder.encode_field(&value)? - } - DataType::Utf8 => { - let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); - let value: Vec<_> = list_arr - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect(); - encoder.encode_field(&value)? - } - DataType::Binary => { - let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); - let value: Vec<_> = list_arr - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect(); - encoder.encode_field(&value)? - } - DataType::LargeBinary => { - let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); - let value: Vec<_> = list_arr - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect(); - encoder.encode_field(&value)? - } - - DataType::Date32 => { - let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); - let value: Vec<_> = list_arr - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect(); - encoder.encode_field(&value)? - } - DataType::Date64 => { - let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); - let value: Vec<_> = list_arr - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect(); - encoder.encode_field(&value)? - } - DataType::Time32(unit) => match unit { - TimeUnit::Second => { - let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); - let value: Vec<_> = list_arr - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect(); - encoder.encode_field(&value)? - } - TimeUnit::Millisecond => { - let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); - let value: Vec<_> = list_arr - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect(); - encoder.encode_field(&value)? - } - _ => {} - }, - DataType::Time64(unit) => match unit { - TimeUnit::Microsecond => { - let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); - let value: Vec<_> = list_arr - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect(); - encoder.encode_field(&value)? - } - TimeUnit::Nanosecond => { - let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); - let value: Vec<_> = list_arr - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect(); - encoder.encode_field(&value)? - } - _ => {} - }, - DataType::Timestamp(unit, timezone) => match unit { - TimeUnit::Second => { - let list_array = - arr.as_any().downcast_ref::().unwrap().value(idx); - let array_iter = list_array - .as_any() - .downcast_ref::() - .unwrap() - .iter(); - - if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let value: Vec<_> = array_iter - .map(|i| { - i.and_then(|i| { - DateTime::from_timestamp(i, 0).map(|dt| { - Utc.from_utc_datetime(&dt.naive_utc()) - .with_timezone(&tz) - .fixed_offset() - }) - }) - }) - .collect(); - encoder.encode_field(&value)?; - } else { - let value: Vec<_> = array_iter - .map(|i| { - i.and_then(|i| { - DateTime::from_timestamp(i, 0).map(|dt| dt.naive_utc()) - }) - }) - .collect(); - encoder.encode_field(&value)? - } - } - TimeUnit::Millisecond => { - let list_array = - arr.as_any().downcast_ref::().unwrap().value(idx); - let array_iter = list_array - .as_any() - .downcast_ref::() - .unwrap() - .iter(); - - if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let value: Vec<_> = array_iter - .map(|i| { - i.and_then(|i| { - DateTime::from_timestamp_millis(i).map(|dt| { - Utc.from_utc_datetime(&dt.naive_utc()) - .with_timezone(&tz) - .fixed_offset() - }) - }) - }) - .collect(); - encoder.encode_field(&value)?; - } else { - let value: Vec<_> = array_iter - .map(|i| { - i.and_then(|i| { - DateTime::from_timestamp_millis(i).map(|dt| dt.naive_utc()) - }) - }) - .collect(); - encoder.encode_field(&value)? - } - } - TimeUnit::Microsecond => { - let list_array = - arr.as_any().downcast_ref::().unwrap().value(idx); - let array_iter = list_array - .as_any() - .downcast_ref::() - .unwrap() - .iter(); - - if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let value: Vec<_> = array_iter - .map(|i| { - i.and_then(|i| { - DateTime::from_timestamp_micros(i).map(|dt| { - Utc.from_utc_datetime(&dt.naive_utc()) - .with_timezone(&tz) - .fixed_offset() - }) - }) - }) - .collect(); - encoder.encode_field(&value)?; - } else { - let value: Vec<_> = array_iter - .map(|i| { - i.and_then(|i| { - DateTime::from_timestamp_micros(i).map(|dt| dt.naive_utc()) - }) - }) - .collect(); - encoder.encode_field(&value)? - } - } - TimeUnit::Nanosecond => { - let list_array = - arr.as_any().downcast_ref::().unwrap().value(idx); - let array_iter = list_array - .as_any() - .downcast_ref::() - .unwrap() - .iter(); - - if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let value: Vec<_> = array_iter - .map(|i| { - i.map(|i| { - Utc.from_utc_datetime( - &DateTime::from_timestamp_nanos(i).naive_utc(), - ) - .with_timezone(&tz) - .fixed_offset() - }) - }) - .collect(); - encoder.encode_field(&value)?; - } else { - let value: Vec<_> = array_iter - .map(|i| i.map(|i| DateTime::from_timestamp_nanos(i).naive_utc())) - .collect(); - encoder.encode_field(&value)? - } - } - }, - - // TODO: more types - list_type => { - return Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!( - "Unsupported List Datatype {} and array {:?}", - list_type, &arr - ), - )))) - } - } - } - DataType::Dictionary(_, value_type) => { - // Get the dictionary values, ignoring keys - // We'll use Int32Type as a common key type, but we're only interested in values - macro_rules! get_dict_values { - ($key_type:ty) => { - arr.as_any() - .downcast_ref::>() - .map(|dict| dict.values()) - }; - } - - // Try to extract values using different key types - let values = get_dict_values!(Int8Type) - .or_else(|| get_dict_values!(Int16Type)) - .or_else(|| get_dict_values!(Int32Type)) - .or_else(|| get_dict_values!(Int64Type)) - .or_else(|| get_dict_values!(UInt8Type)) - .or_else(|| get_dict_values!(UInt16Type)) - .or_else(|| get_dict_values!(UInt32Type)) - .or_else(|| get_dict_values!(UInt64Type)) - .ok_or_else(|| { - PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!( - "Unsupported dictionary key type for value type {}", - value_type - ), - ))) - })?; - - // If the dictionary has only one value, treat it as a primitive - if values.len() == 1 { - encode_value(encoder, values, 0)? - } else { - // Otherwise, use value directly indexed by values array - encode_value(encoder, values, idx)? - } - } - _ => { - return Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!( - "Unsupported Datatype {} and array {:?}", - arr.data_type(), - &arr - ), - )))) - } - } - Ok(()) -} - pub(crate) fn df_schema_to_pg_fields( schema: &DFSchema, format: &Format, @@ -734,34 +146,18 @@ pub(crate) async fn encode_dataframe<'a>( let fields_ref = fields.clone(); let pg_row_stream = recordbatch_stream .map(move |rb: datafusion::error::Result| { - let row_stream: Box + Send> = match rb { + let row_stream: Box> + Send + Sync> = match rb + { Ok(rb) => { - let rows = rb.num_rows(); - let cols = rb.num_columns(); - let fields = fields_ref.clone(); - - let row_stream = (0..rows).map(move |row| { - let mut encoder = DataRowEncoder::new(fields.clone()); - for col in 0..cols { - let array = rb.column(col); - if array.is_null(row) { - encoder.encode_field(&None::)?; - } else { - encode_value(&mut encoder, array, row)? - } - } - encoder.finish() - }); - Box::new(row_stream) + let mut row_stream = RowEncoder::new(rb, fields); + Box::new(std::iter::from_fn(move || row_stream.next_row())) } Err(e) => Box::new(iter::once(Err(PgWireError::ApiError(e.into())))), }; - stream::iter(row_stream) }) .flatten(); - Ok(QueryResponse::new(fields, pg_row_stream)) } diff --git a/datafusion-postgres/src/encoder/list_encoder.rs b/datafusion-postgres/src/encoder/list_encoder.rs new file mode 100644 index 0000000..0c9457c --- /dev/null +++ b/datafusion-postgres/src/encoder/list_encoder.rs @@ -0,0 +1,417 @@ +use std::{error::Error, str::FromStr, sync::Arc}; + +use arrow::{ + datatypes::{ + Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, + }, + temporal_conversions::{as_date, as_time}, +}; +use bytes::{BufMut, BytesMut}; +use chrono::{DateTime, TimeZone, Utc}; +use datafusion::arrow::{ + array::{ + timezone::Tz, Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, + LargeBinaryArray, PrimitiveArray, StringArray, Time32MillisecondArray, Time32SecondArray, + Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + }, + datatypes::{ + DataType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, + }, +}; +use pgwire::{ + api::results::FieldFormat, + error::{ErrorInfo, PgWireError}, + types::{ToSqlText, QUOTE_ESCAPE}, +}; +use postgres_types::{ToSql, Type}; +use rust_decimal::Decimal; + +use super::{struct_encoder::encode_struct, EncodedValue}; + +fn get_bool_list_value(arr: &Arc) -> Vec> { + arr.as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect() +} + +macro_rules! get_primitive_list_value { + ($name:ident, $t:ty, $pt:ty) => { + fn $name(arr: &Arc) -> Vec> { + arr.as_any() + .downcast_ref::>() + .unwrap() + .iter() + .collect() + } + }; + + ($name:ident, $t:ty, $pt:ty, $f:expr) => { + fn $name(arr: &Arc) -> Vec> { + arr.as_any() + .downcast_ref::>() + .unwrap() + .iter() + .map(|val| val.map($f)) + .collect() + } + }; +} + +get_primitive_list_value!(get_i8_list_value, Int8Type, i8); +get_primitive_list_value!(get_i16_list_value, Int16Type, i16); +get_primitive_list_value!(get_i32_list_value, Int32Type, i32); +get_primitive_list_value!(get_i64_list_value, Int64Type, i64); +get_primitive_list_value!(get_u8_list_value, UInt8Type, i8, |val: u8| { val as i8 }); +get_primitive_list_value!(get_u16_list_value, UInt16Type, i16, |val: u16| { + val as i16 +}); +get_primitive_list_value!(get_u32_list_value, UInt32Type, u32); +get_primitive_list_value!(get_u64_list_value, UInt64Type, i64, |val: u64| { + val as i64 +}); +get_primitive_list_value!(get_f32_list_value, Float32Type, f32); +get_primitive_list_value!(get_f64_list_value, Float64Type, f64); + +fn encode_field( + t: &[T], + type_: &Type, + format: FieldFormat, +) -> Result> { + let mut bytes = BytesMut::new(); + match format { + FieldFormat::Text => t.to_sql_text(type_, &mut bytes)?, + FieldFormat::Binary => t.to_sql(type_, &mut bytes)?, + }; + Ok(EncodedValue { bytes }) +} + +pub(crate) fn encode_list( + arr: Arc, + type_: &Type, + format: FieldFormat, +) -> Result> { + match arr.data_type() { + DataType::Null => { + let mut bytes = BytesMut::new(); + match format { + FieldFormat::Text => None::.to_sql_text(type_, &mut bytes), + FieldFormat::Binary => None::.to_sql(type_, &mut bytes), + }?; + Ok(EncodedValue { bytes }) + } + DataType::Boolean => encode_field(&get_bool_list_value(&arr), type_, format), + DataType::Int8 => encode_field(&get_i8_list_value(&arr), type_, format), + DataType::Int16 => encode_field(&get_i16_list_value(&arr), type_, format), + DataType::Int32 => encode_field(&get_i32_list_value(&arr), type_, format), + DataType::Int64 => encode_field(&get_i64_list_value(&arr), type_, format), + DataType::UInt8 => encode_field(&get_u8_list_value(&arr), type_, format), + DataType::UInt16 => encode_field(&get_u16_list_value(&arr), type_, format), + DataType::UInt32 => encode_field(&get_u32_list_value(&arr), type_, format), + DataType::UInt64 => encode_field(&get_u64_list_value(&arr), type_, format), + DataType::Float32 => encode_field(&get_f32_list_value(&arr), type_, format), + DataType::Float64 => encode_field(&get_f64_list_value(&arr), type_, format), + DataType::Decimal128(_, s) => { + let value: Vec<_> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32))) + .collect(); + encode_field(&value, type_, format) + } + DataType::Utf8 => { + let value: Vec> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect(); + encode_field(&value, type_, format) + } + DataType::Binary => { + let value: Vec> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect(); + encode_field(&value, type_, format) + } + DataType::LargeBinary => { + let value: Vec> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect(); + encode_field(&value, type_, format) + } + + DataType::Date32 => { + let value: Vec> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|val| val.and_then(|x| as_date::(x as i64))) + .collect(); + encode_field(&value, type_, format) + } + DataType::Date64 => { + let value: Vec> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|val| val.and_then(as_date::)) + .collect(); + encode_field(&value, type_, format) + } + DataType::Time32(unit) => match unit { + TimeUnit::Second => { + let value: Vec> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|val| val.and_then(|x| as_time::(x as i64))) + .collect(); + encode_field(&value, type_, format) + } + TimeUnit::Millisecond => { + let value: Vec> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|val| val.and_then(|x| as_time::(x as i64))) + .collect(); + encode_field(&value, type_, format) + } + _ => { + unimplemented!() + } + }, + DataType::Time64(unit) => match unit { + TimeUnit::Microsecond => { + let value: Vec> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|val| val.and_then(as_time::)) + .collect(); + encode_field(&value, type_, format) + } + TimeUnit::Nanosecond => { + let value: Vec> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|val| val.and_then(as_time::)) + .collect(); + encode_field(&value, type_, format) + } + _ => { + unimplemented!() + } + }, + DataType::Timestamp(unit, timezone) => match unit { + TimeUnit::Second => { + let array_iter = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter(); + + if let Some(tz) = timezone { + let tz = Tz::from_str(tz.as_ref()) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let value: Vec<_> = array_iter + .map(|i| { + i.and_then(|i| { + DateTime::from_timestamp(i, 0).map(|dt| { + Utc.from_utc_datetime(&dt.naive_utc()) + .with_timezone(&tz) + .fixed_offset() + }) + }) + }) + .collect(); + encode_field(&value, type_, format) + } else { + let value: Vec<_> = array_iter + .map(|i| { + i.and_then(|i| DateTime::from_timestamp(i, 0).map(|dt| dt.naive_utc())) + }) + .collect(); + encode_field(&value, type_, format) + } + } + TimeUnit::Millisecond => { + let array_iter = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter(); + + if let Some(tz) = timezone { + let tz = Tz::from_str(tz.as_ref()) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let value: Vec<_> = array_iter + .map(|i| { + i.and_then(|i| { + DateTime::from_timestamp_millis(i).map(|dt| { + Utc.from_utc_datetime(&dt.naive_utc()) + .with_timezone(&tz) + .fixed_offset() + }) + }) + }) + .collect(); + encode_field(&value, type_, format) + } else { + let value: Vec<_> = array_iter + .map(|i| { + i.and_then(|i| { + DateTime::from_timestamp_millis(i).map(|dt| dt.naive_utc()) + }) + }) + .collect(); + encode_field(&value, type_, format) + } + } + TimeUnit::Microsecond => { + let array_iter = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter(); + + if let Some(tz) = timezone { + let tz = Tz::from_str(tz.as_ref()) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let value: Vec<_> = array_iter + .map(|i| { + i.and_then(|i| { + DateTime::from_timestamp_micros(i).map(|dt| { + Utc.from_utc_datetime(&dt.naive_utc()) + .with_timezone(&tz) + .fixed_offset() + }) + }) + }) + .collect(); + encode_field(&value, type_, format) + } else { + let value: Vec<_> = array_iter + .map(|i| { + i.and_then(|i| { + DateTime::from_timestamp_micros(i).map(|dt| dt.naive_utc()) + }) + }) + .collect(); + encode_field(&value, type_, format) + } + } + TimeUnit::Nanosecond => { + let array_iter = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter(); + + if let Some(tz) = timezone { + let tz = Tz::from_str(tz.as_ref()) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let value: Vec<_> = array_iter + .map(|i| { + i.map(|i| { + Utc.from_utc_datetime( + &DateTime::from_timestamp_nanos(i).naive_utc(), + ) + .with_timezone(&tz) + .fixed_offset() + }) + }) + .collect(); + encode_field(&value, type_, format) + } else { + let value: Vec<_> = array_iter + .map(|i| i.map(|i| DateTime::from_timestamp_nanos(i).naive_utc())) + .collect(); + encode_field(&value, type_, format) + } + } + }, + DataType::Struct(_) => { + let fields = match type_.kind() { + postgres_types::Kind::Array(struct_type_) => Ok(struct_type_), + _ => Err(format!( + "Expected list type found type {} of kind {:?}", + type_, + type_.kind() + )), + } + .and_then(|struct_type| match struct_type.kind() { + postgres_types::Kind::Composite(fields) => Ok(fields), + _ => Err(format!( + "Failed to unwrap a composite type inside from type {} kind {:?}", + type_, + type_.kind() + )), + }) + .map_err(|err| { + let err = ErrorInfo::new("ERROR".to_owned(), "XX000".to_owned(), err); + Box::new(PgWireError::UserError(Box::new(err))) + })?; + + let values: Result, _> = (0..arr.len()) + .map(|row| encode_struct(&arr, row, fields, format)) + .map(|x| { + if matches!(format, FieldFormat::Text) { + x.map(|opt| { + opt.map(|value| { + let mut w = BytesMut::new(); + w.put_u8(b'"'); + w.put_slice( + QUOTE_ESCAPE + .replace_all( + &String::from_utf8_lossy(&value.bytes), + r#"\$1"#, + ) + .as_bytes(), + ); + w.put_u8(b'"'); + EncodedValue { bytes: w } + }) + }) + } else { + x + } + }) + .collect(); + encode_field(&values?, type_, format) + } + // TODO: more types + list_type => { + let err = PgWireError::UserError(Box::new(ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!( + "Unsupported List Datatype {} and array {:?}", + list_type, &arr + ), + ))); + + Err(Box::new(err)) + } + } +} diff --git a/datafusion-postgres/src/encoder/mod.rs b/datafusion-postgres/src/encoder/mod.rs new file mode 100644 index 0000000..233d24c --- /dev/null +++ b/datafusion-postgres/src/encoder/mod.rs @@ -0,0 +1,552 @@ +use std::io::Write; +use std::str::FromStr; +use std::sync::Arc; + +use bytes::BufMut; +use bytes::BytesMut; +use chrono::{NaiveDate, NaiveDateTime}; +use datafusion::arrow::array::*; +use datafusion::arrow::datatypes::*; +use list_encoder::encode_list; +use pgwire::api::results::DataRowEncoder; +use pgwire::api::results::FieldFormat; +use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use pgwire::types::ToSqlText; +use postgres_types::{ToSql, Type}; +use rust_decimal::Decimal; +use struct_encoder::encode_struct; +use timezone::Tz; + +pub mod list_encoder; +pub mod row_encoder; +pub mod struct_encoder; + +trait Encoder { + fn encode_field_with_type_and_format( + &mut self, + value: &T, + data_type: &Type, + format: FieldFormat, + ) -> PgWireResult<()> + where + T: ToSql + ToSqlText + Sized; +} + +impl Encoder for DataRowEncoder { + fn encode_field_with_type_and_format( + &mut self, + value: &T, + data_type: &Type, + format: FieldFormat, + ) -> PgWireResult<()> + where + T: ToSql + ToSqlText + Sized, + { + self.encode_field_with_type_and_format(value, data_type, format) + } +} + +pub(crate) struct EncodedValue { + pub(crate) bytes: BytesMut, +} + +impl std::fmt::Debug for EncodedValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EncodedValue").finish() + } +} + +impl ToSql for EncodedValue { + fn to_sql( + &self, + _ty: &Type, + out: &mut BytesMut, + ) -> Result> + where + Self: Sized, + { + out.writer().write_all(&self.bytes)?; + Ok(postgres_types::IsNull::No) + } + + fn accepts(_ty: &Type) -> bool + where + Self: Sized, + { + true + } + + fn to_sql_checked( + &self, + ty: &Type, + out: &mut BytesMut, + ) -> Result> { + self.to_sql(ty, out) + } +} + +impl ToSqlText for EncodedValue { + fn to_sql_text( + &self, + _ty: &Type, + out: &mut BytesMut, + ) -> Result> + where + Self: Sized, + { + out.writer().write_all(&self.bytes)?; + Ok(postgres_types::IsNull::No) + } +} + +fn get_bool_value(arr: &Arc, idx: usize) -> Option { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::() + .unwrap() + .value(idx) + }) +} + +macro_rules! get_primitive_value { + ($name:ident, $t:ty, $pt:ty) => { + fn $name(arr: &Arc, idx: usize) -> Option<$pt> { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::>() + .unwrap() + .value(idx) + }) + } + }; +} + +get_primitive_value!(get_i8_value, Int8Type, i8); +get_primitive_value!(get_i16_value, Int16Type, i16); +get_primitive_value!(get_i32_value, Int32Type, i32); +get_primitive_value!(get_i64_value, Int64Type, i64); +get_primitive_value!(get_u8_value, UInt8Type, u8); +get_primitive_value!(get_u16_value, UInt16Type, u16); +get_primitive_value!(get_u32_value, UInt32Type, u32); +get_primitive_value!(get_u64_value, UInt64Type, u64); +get_primitive_value!(get_f32_value, Float32Type, f32); +get_primitive_value!(get_f64_value, Float64Type, f64); + +fn get_utf8_view_value(arr: &Arc, idx: usize) -> Option<&str> { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::() + .unwrap() + .value(idx) + }) +} + +fn get_utf8_value(arr: &Arc, idx: usize) -> Option<&str> { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::() + .unwrap() + .value(idx) + }) +} + +fn get_large_utf8_value(arr: &Arc, idx: usize) -> Option<&str> { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::() + .unwrap() + .value(idx) + }) +} + +fn get_binary_value(arr: &Arc, idx: usize) -> Option<&[u8]> { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::() + .unwrap() + .value(idx) + }) +} + +fn get_large_binary_value(arr: &Arc, idx: usize) -> Option<&[u8]> { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::() + .unwrap() + .value(idx) + }) +} + +fn get_date32_value(arr: &Arc, idx: usize) -> Option { + if arr.is_null(idx) { + return None; + } + arr.as_any() + .downcast_ref::() + .unwrap() + .value_as_date(idx) +} + +fn get_date64_value(arr: &Arc, idx: usize) -> Option { + if arr.is_null(idx) { + return None; + } + arr.as_any() + .downcast_ref::() + .unwrap() + .value_as_date(idx) +} + +fn get_time32_second_value(arr: &Arc, idx: usize) -> Option { + if arr.is_null(idx) { + return None; + } + arr.as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(idx) +} + +fn get_time32_millisecond_value(arr: &Arc, idx: usize) -> Option { + if arr.is_null(idx) { + return None; + } + arr.as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(idx) +} + +fn get_time64_microsecond_value(arr: &Arc, idx: usize) -> Option { + if arr.is_null(idx) { + return None; + } + arr.as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(idx) +} +fn get_time64_nanosecond_value(arr: &Arc, idx: usize) -> Option { + if arr.is_null(idx) { + return None; + } + arr.as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(idx) +} + +fn get_numeric_128_value( + arr: &Arc, + idx: usize, + scale: u32, +) -> PgWireResult> { + if arr.is_null(idx) { + return Ok(None); + } + + let array = arr.as_any().downcast_ref::().unwrap(); + let value = array.value(idx); + Decimal::try_from_i128_with_scale(value, scale) + .map_err(|e| { + let message = match e { + rust_decimal::Error::ExceedsMaximumPossibleValue => { + "Exceeds maximum possible value" + } + rust_decimal::Error::LessThanMinimumPossibleValue => { + "Less than minimum possible value" + } + rust_decimal::Error::ScaleExceedsMaximumPrecision(_) => { + "Scale exceeds maximum precision" + } + _ => unreachable!(), + }; + PgWireError::UserError(Box::new(ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + message.to_owned(), + ))) + }) + .map(Some) +} + +fn encode_value( + encoder: &mut T, + arr: &Arc, + idx: usize, + type_: &Type, + format: FieldFormat, +) -> PgWireResult<()> { + match arr.data_type() { + DataType::Null => encoder.encode_field_with_type_and_format(&None::, type_, format)?, + DataType::Boolean => { + encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)? + } + DataType::Int8 => { + encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)? + } + DataType::Int16 => { + encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)? + } + DataType::Int32 => { + encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)? + } + DataType::Int64 => { + encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)? + } + DataType::UInt8 => encoder.encode_field_with_type_and_format( + &(get_u8_value(arr, idx).map(|x| x as i8)), + type_, + format, + )?, + DataType::UInt16 => encoder.encode_field_with_type_and_format( + &(get_u16_value(arr, idx).map(|x| x as i16)), + type_, + format, + )?, + DataType::UInt32 => { + encoder.encode_field_with_type_and_format(&get_u32_value(arr, idx), type_, format)? + } + DataType::UInt64 => encoder.encode_field_with_type_and_format( + &(get_u64_value(arr, idx).map(|x| x as i64)), + type_, + format, + )?, + DataType::Float32 => { + encoder.encode_field_with_type_and_format(&get_f32_value(arr, idx), type_, format)? + } + DataType::Float64 => { + encoder.encode_field_with_type_and_format(&get_f64_value(arr, idx), type_, format)? + } + DataType::Decimal128(_, s) => encoder.encode_field_with_type_and_format( + &get_numeric_128_value(arr, idx, *s as u32)?, + type_, + format, + )?, + DataType::Utf8 => { + encoder.encode_field_with_type_and_format(&get_utf8_value(arr, idx), type_, format)? + } + DataType::Utf8View => encoder.encode_field_with_type_and_format( + &get_utf8_view_value(arr, idx), + type_, + format, + )?, + DataType::LargeUtf8 => encoder.encode_field_with_type_and_format( + &get_large_utf8_value(arr, idx), + type_, + format, + )?, + DataType::Binary => { + encoder.encode_field_with_type_and_format(&get_binary_value(arr, idx), type_, format)? + } + DataType::LargeBinary => encoder.encode_field_with_type_and_format( + &get_large_binary_value(arr, idx), + type_, + format, + )?, + DataType::Date32 => { + encoder.encode_field_with_type_and_format(&get_date32_value(arr, idx), type_, format)? + } + DataType::Date64 => { + encoder.encode_field_with_type_and_format(&get_date64_value(arr, idx), type_, format)? + } + DataType::Time32(unit) => match unit { + TimeUnit::Second => encoder.encode_field_with_type_and_format( + &get_time32_second_value(arr, idx), + type_, + format, + )?, + TimeUnit::Millisecond => encoder.encode_field_with_type_and_format( + &get_time32_millisecond_value(arr, idx), + type_, + format, + )?, + _ => {} + }, + DataType::Time64(unit) => match unit { + TimeUnit::Microsecond => encoder.encode_field_with_type_and_format( + &get_time64_microsecond_value(arr, idx), + type_, + format, + )?, + TimeUnit::Nanosecond => encoder.encode_field_with_type_and_format( + &get_time64_nanosecond_value(arr, idx), + type_, + format, + )?, + _ => {} + }, + DataType::Timestamp(unit, timezone) => match unit { + TimeUnit::Second => { + if arr.is_null(idx) { + return encoder.encode_field_with_type_and_format( + &None::, + type_, + format, + ); + } + let ts_array = arr.as_any().downcast_ref::().unwrap(); + if let Some(tz) = timezone { + let tz = Tz::from_str(tz.as_ref()) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let value = ts_array + .value_as_datetime_with_tz(idx, tz) + .map(|d| d.fixed_offset()); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } else { + let value = ts_array.value_as_datetime(idx); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } + } + TimeUnit::Millisecond => { + if arr.is_null(idx) { + return encoder.encode_field_with_type_and_format( + &None::, + type_, + format, + ); + } + let ts_array = arr + .as_any() + .downcast_ref::() + .unwrap(); + if let Some(tz) = timezone { + let tz = Tz::from_str(tz.as_ref()) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let value = ts_array + .value_as_datetime_with_tz(idx, tz) + .map(|d| d.fixed_offset()); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } else { + let value = ts_array.value_as_datetime(idx); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } + } + TimeUnit::Microsecond => { + if arr.is_null(idx) { + return encoder.encode_field_with_type_and_format( + &None::, + type_, + format, + ); + } + let ts_array = arr + .as_any() + .downcast_ref::() + .unwrap(); + if let Some(tz) = timezone { + let tz = Tz::from_str(tz.as_ref()) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let value = ts_array + .value_as_datetime_with_tz(idx, tz) + .map(|d| d.fixed_offset()); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } else { + let value = ts_array.value_as_datetime(idx); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } + } + TimeUnit::Nanosecond => { + if arr.is_null(idx) { + return encoder.encode_field_with_type_and_format( + &None::, + type_, + format, + ); + } + let ts_array = arr + .as_any() + .downcast_ref::() + .unwrap(); + if let Some(tz) = timezone { + let tz = Tz::from_str(tz.as_ref()) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let value = ts_array + .value_as_datetime_with_tz(idx, tz) + .map(|d| d.fixed_offset()); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } else { + let value = ts_array.value_as_datetime(idx); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } + } + }, + DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => { + if arr.is_null(idx) { + return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format); + } + let array = arr.as_any().downcast_ref::().unwrap().value(idx); + let value = encode_list(array, type_, format)?; + encoder.encode_field_with_type_and_format(&value, type_, format)? + } + DataType::Struct(_) => { + let fields = match type_.kind() { + postgres_types::Kind::Composite(fields) => fields, + _ => { + return Err(PgWireError::UserError(Box::new(ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!("Failed to unwrap a composite type from type {}", type_), + )))) + } + }; + let value = encode_struct(arr, idx, fields, format)?; + encoder.encode_field_with_type_and_format(&value, type_, format)? + } + DataType::Dictionary(_, value_type) => { + if arr.is_null(idx) { + return encoder.encode_field_with_type_and_format(&None::, type_, format); + } + // Get the dictionary values, ignoring keys + // We'll use Int32Type as a common key type, but we're only interested in values + macro_rules! get_dict_values { + ($key_type:ty) => { + arr.as_any() + .downcast_ref::>() + .map(|dict| dict.values()) + }; + } + + // Try to extract values using different key types + let values = get_dict_values!(Int8Type) + .or_else(|| get_dict_values!(Int16Type)) + .or_else(|| get_dict_values!(Int32Type)) + .or_else(|| get_dict_values!(Int64Type)) + .or_else(|| get_dict_values!(UInt8Type)) + .or_else(|| get_dict_values!(UInt16Type)) + .or_else(|| get_dict_values!(UInt32Type)) + .or_else(|| get_dict_values!(UInt64Type)) + .ok_or_else(|| { + PgWireError::UserError(Box::new(ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!( + "Unsupported dictionary key type for value type {}", + value_type + ), + ))) + })?; + + // If the dictionary has only one value, treat it as a primitive + if values.len() == 1 { + encode_value(encoder, values, 0, type_, format)? + } else { + // Otherwise, use value directly indexed by values array + encode_value(encoder, values, idx, type_, format)? + } + } + _ => { + return Err(PgWireError::UserError(Box::new(ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!( + "Unsupported Datatype {} and array {:?}", + arr.data_type(), + &arr + ), + )))) + } + } + + Ok(()) +} diff --git a/datafusion-postgres/src/encoder/row_encoder.rs b/datafusion-postgres/src/encoder/row_encoder.rs new file mode 100644 index 0000000..9d48145 --- /dev/null +++ b/datafusion-postgres/src/encoder/row_encoder.rs @@ -0,0 +1,43 @@ +use std::sync::Arc; + +use datafusion::arrow::array::RecordBatch; +use pgwire::{ + api::results::{DataRowEncoder, FieldInfo}, + error::PgWireResult, + messages::data::DataRow, +}; + +use super::encode_value; + +pub struct RowEncoder { + rb: RecordBatch, + curr_idx: usize, + fields: Arc>, +} + +impl RowEncoder { + pub fn new(rb: RecordBatch, fields: Arc>) -> Self { + assert_eq!(rb.num_columns(), fields.len()); + Self { + rb, + fields, + curr_idx: 0, + } + } + + pub fn next_row(&mut self) -> Option> { + if self.curr_idx == self.rb.num_rows() { + return None; + } + let mut encoder = DataRowEncoder::new(self.fields.clone()); + for col in 0..self.rb.num_columns() { + let array = self.rb.column(col); + let field = &self.fields[col]; + let type_ = field.datatype(); + let format = field.format(); + encode_value(&mut encoder, array, self.curr_idx, type_, format).unwrap(); + } + self.curr_idx += 1; + Some(encoder.finish()) + } +} diff --git a/datafusion-postgres/src/encoder/struct_encoder.rs b/datafusion-postgres/src/encoder/struct_encoder.rs new file mode 100644 index 0000000..d0b1523 --- /dev/null +++ b/datafusion-postgres/src/encoder/struct_encoder.rs @@ -0,0 +1,106 @@ +use std::{error::Error, sync::Arc}; + +use bytes::{BufMut, BytesMut}; +use datafusion::arrow::array::{Array, StructArray}; +use pgwire::{ + api::results::FieldFormat, + error::PgWireResult, + types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE}, +}; +use postgres_types::{Field, IsNull, ToSql, Type}; + +use super::{encode_value, EncodedValue}; + +pub fn encode_struct( + arr: &Arc, + idx: usize, + fields: &[Field], + format: FieldFormat, +) -> Result, Box> { + let arr = arr.as_any().downcast_ref::().unwrap(); + if arr.is_null(idx) { + return Ok(None); + } + let mut row_encoder = StructEncoder::new(fields.len()); + for (i, arr) in arr.columns().iter().enumerate() { + let field = &fields[i]; + let type_ = field.type_(); + encode_value(&mut row_encoder, arr, idx, type_, format).unwrap(); + } + Ok(Some(EncodedValue { + bytes: row_encoder.row_buffer, + })) +} + +struct StructEncoder { + num_cols: usize, + curr_col: usize, + row_buffer: BytesMut, +} + +impl StructEncoder { + fn new(num_cols: usize) -> Self { + Self { + num_cols, + curr_col: 0, + row_buffer: BytesMut::new(), + } + } +} + +impl super::Encoder for StructEncoder { + fn encode_field_with_type_and_format( + &mut self, + value: &T, + data_type: &Type, + format: FieldFormat, + ) -> PgWireResult<()> + where + T: ToSql + ToSqlText + Sized, + { + if format == FieldFormat::Text { + if self.curr_col == 0 { + self.row_buffer.put_slice(b"("); + } + // encode value in an intermediate buf + let mut buf = BytesMut::new(); + value.to_sql_text(data_type, &mut buf)?; + let encoded_value_as_str = String::from_utf8_lossy(&buf); + if QUOTE_CHECK.is_match(&encoded_value_as_str) { + self.row_buffer.put_u8(b'"'); + self.row_buffer.put_slice( + QUOTE_ESCAPE + .replace_all(&encoded_value_as_str, r#"\$1"#) + .as_bytes(), + ); + self.row_buffer.put_u8(b'"'); + } else { + self.row_buffer.put_slice(&buf); + } + if self.curr_col == self.num_cols - 1 { + self.row_buffer.put_slice(b")"); + } else { + self.row_buffer.put_slice(b","); + } + } else { + if self.curr_col == 0 && format == FieldFormat::Binary { + // Place Number of fields + self.row_buffer.put_i32(self.num_cols as i32); + } + + self.row_buffer.put_u32(data_type.oid()); + // remember the position of the 4-byte length field + let prev_index = self.row_buffer.len(); + // write value length as -1 ahead of time + self.row_buffer.put_i32(-1); + let is_null = value.to_sql(data_type, &mut self.row_buffer)?; + if let IsNull::No = is_null { + let value_length = self.row_buffer.len() - prev_index - 4; + let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)]; + length_bytes.put_i32(value_length as i32); + } + } + self.curr_col += 1; + Ok(()) + } +} diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index c560526..b677a2a 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -1,4 +1,5 @@ mod datatypes; +mod encoder; mod handlers; mod information_schema; diff --git a/tests-integration/all_types.parquet b/tests-integration/all_types.parquet new file mode 100644 index 0000000..293b853 Binary files /dev/null and b/tests-integration/all_types.parquet differ diff --git a/tests-integration/create_arrow_testfile.py b/tests-integration/create_arrow_testfile.py new file mode 100644 index 0000000..6fcf4a2 --- /dev/null +++ b/tests-integration/create_arrow_testfile.py @@ -0,0 +1,42 @@ +import pyarrow as pa +from pyarrow import StructArray, ListArray +import pyarrow.parquet as pq +from datetime import date, datetime + +base: list[pa.Field] = [ + pa.field("int32", pa.int32(), nullable=True), + pa.field("float64", pa.float64(), nullable=True), + pa.field("string", pa.string(), nullable=True), + pa.field("bool", pa.bool_(), nullable=True), + pa.field("date32", pa.date32(), nullable=True), + pa.field("timestamp", pa.timestamp("ms"), nullable=True), +] +list_type = [ + pa.field(str(inner.name) + "_list", pa.list_(inner), nullable=True) + for inner in base +] +struct_type = [pa.field("struct", pa.struct(base), nullable=True)] +list_struct_type = [pa.field("list_struct", pa.list_(struct_type[0]), nullable=True)] +fields = base + list_type + struct_type + list_struct_type +schema = pa.schema(fields) + +base_data = [ + pa.array([1, None, 2], type=pa.int32()), + pa.array([1.0, None, 2.0], type=pa.float64()), + pa.array(["a", None, "b"], type=pa.string()), + pa.array([True, None, False], type=pa.bool_()), + pa.array([date(2012, 1, 1), None, date(2012, 1, 2)], type=pa.date32()), + pa.array( + [datetime(2012, 1, 1), None, datetime(2012, 1, 2)], type=pa.timestamp("ms") + ), +] +list_data = [pa.array([x.to_pylist(), None, None]) for x in base_data] +struct_data = [StructArray.from_arrays(base_data, fields=struct_type[0].type.fields)] +list_struct_data = [ListArray.from_arrays(pa.array([0, 1, 2, 3]), struct_data[0])] + +arrays = base_data + list_data + struct_data + list_struct_data + +# Create a table +table = pa.Table.from_arrays(arrays, schema=schema) + +pq.write_table(table, "all_types.parquet") diff --git a/tests-integration/test.sh b/tests-integration/test.sh index 87fe775..6158786 100755 --- a/tests-integration/test.sh +++ b/tests-integration/test.sh @@ -8,3 +8,9 @@ PID=$! sleep 3 python tests-integration/test.py kill -9 $PID 2>/dev/null + +./target/debug/datafusion-postgres-cli --parquet all_types:tests-integration/all_types.parquet & +PID=$! +sleep 3 +python tests-integration/test_all_types.py +kill -9 $PID 2>/dev/null \ No newline at end of file diff --git a/tests-integration/test_all_types.py b/tests-integration/test_all_types.py new file mode 100644 index 0000000..d4d9240 --- /dev/null +++ b/tests-integration/test_all_types.py @@ -0,0 +1,152 @@ +import psycopg +from datetime import date, datetime + +conn: psycopg.connection.Connection = psycopg.connect( + "host=127.0.0.1 port=5432 user=tom password=pencil dbname=public" +) +conn.autocommit = True + + +def data(format: str): + return [ + ( + 1, + 1.0, + "a", + True, + date(2012, 1, 1), + datetime(2012, 1, 1), + [1, None, 2], + [1.0, None, 2.0], + ["a", None, "b"], + [True, None, False], + [date(2012, 1, 1), None, date(2012, 1, 2)], + [datetime(2012, 1, 1), None, datetime(2012, 1, 2)], + ( + (1, 1.0, "a", True, date(2012, 1, 1), datetime(2012, 1, 1)) + if format == "text" + else ( + "1", + "1", + "a", + "t", + "2012-01-01", + "2012-01-01 00:00:00.000000", + ) + ), + ( + [(1, 1.0, "a", True, date(2012, 1, 1), datetime(2012, 1, 1))] + if format == "text" + else [ + ( + "1", + "1", + "a", + "t", + "2012-01-01", + "2012-01-01 00:00:00.000000", + ) + ] + ), + ), + ( + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ( + (None, None, None, None, None, None) + if format == "text" + else ("", "", "", "", "", "") + ), + ( + [(None, None, None, None, None, None)] + if format == "text" + else [("", "", "", "", "", "")] + ), + ), + ( + 2, + 2.0, + "b", + False, + date(2012, 1, 2), + datetime(2012, 1, 2), + None, + None, + None, + None, + None, + None, + ( + (2, 2.0, "b", False, date(2012, 1, 2), datetime(2012, 1, 2)) + if format == "text" + else ( + "2", + "2", + "b", + "f", + "2012-01-02", + "2012-01-02 00:00:00.000000", + ) + ), + ( + [(2, 2.0, "b", False, date(2012, 1, 2), datetime(2012, 1, 2))] + if format == "text" + else [ + ( + "2", + "2", + "b", + "f", + "2012-01-02", + "2012-01-02 00:00:00.000000", + ) + ] + ), + ), + ] + + +def assert_select_all(results: list[psycopg.rows.Row], format: str): + expected = data(format) + + assert len(results) == len( + expected + ), f"Expected {len(expected)} rows, got {len(results)}" + + for i, (res_row, exp_row) in enumerate(zip(results, expected)): + assert len(res_row) == len(exp_row), f"Row {i} column count mismatch" + for j, (res_val, exp_val) in enumerate(zip(res_row, exp_row)): + assert ( + res_val == exp_val + ), f"Mismatch at row {i}, column {j}: expected {exp_val}, got {res_val}" + + +with conn.cursor(binary=True) as cur: + cur.execute("SELECT count(*) FROM all_types") + results = cur.fetchone() + assert results[0] == 3 + +with conn.cursor(binary=False) as cur: + cur.execute("SELECT count(*) FROM all_types") + results = cur.fetchone() + assert results[0] == 3 + +with conn.cursor(binary=True) as cur: + cur.execute("SELECT * FROM all_types") + results = cur.fetchall() + assert_select_all(results, "text") + +with conn.cursor(binary=False) as cur: + cur.execute("SELECT * FROM all_types") + results = cur.fetchall() + assert_select_all(results, "binary")