Skip to content

Commit 1ffc04e

Browse files
committed
Merge branch 'master' into feature/information-schema-support
2 parents 7e0d3cc + 4b8ea1c commit 1ffc04e

File tree

4 files changed

+124
-6
lines changed

4 files changed

+124
-6
lines changed

Cargo.lock

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-postgres/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,3 +24,4 @@ futures = "0.3"
2424
async-trait = "0.1"
2525
log = "0.4"
2626
chrono = { version = "0.4", features = ["std"] }
27+
rust_decimal = { version = "1.37", features = ["db-postgres"] }

datafusion-postgres/src/datatypes.rs

Lines changed: 91 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@ use pgwire::api::portal::{Format, Portal};
1515
use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse};
1616
use pgwire::api::Type;
1717
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
18+
use rust_decimal::prelude::ToPrimitive;
19+
use rust_decimal::{Decimal, Error};
1820
use timezone::Tz;
1921

2022
pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
@@ -38,6 +40,7 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
3840
DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary => Type::BYTEA,
3941
DataType::Float16 | DataType::Float32 => Type::FLOAT4,
4042
DataType::Float64 => Type::FLOAT8,
43+
DataType::Decimal128(_, _) => Type::NUMERIC,
4144
DataType::Utf8 => Type::VARCHAR,
4245
DataType::LargeUtf8 => Type::TEXT,
4346
DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
@@ -72,6 +75,7 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
7275
}
7376
}
7477
DataType::Utf8View => Type::TEXT,
78+
DataType::Dictionary(_, value_type) => into_pg_type(value_type)?,
7579
_ => {
7680
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
7781
"ERROR".to_owned(),
@@ -82,6 +86,24 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
8286
})
8387
}
8488

89+
fn get_numeric_128_value(arr: &Arc<dyn Array>, idx: usize, scale: u32) -> PgWireResult<Decimal> {
90+
let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
91+
let value = array.value(idx);
92+
Decimal::try_from_i128_with_scale(value, scale).map_err(|e| {
93+
let message = match e {
94+
Error::ExceedsMaximumPossibleValue => "Exceeds maximum possible value",
95+
Error::LessThanMinimumPossibleValue => "Less than minimum possible value",
96+
Error::ScaleExceedsMaximumPrecision(_) => "Scale exceeds maximum precision",
97+
_ => unreachable!(),
98+
};
99+
PgWireError::UserError(Box::new(ErrorInfo::new(
100+
"ERROR".to_owned(),
101+
"XX000".to_owned(),
102+
message.to_owned(),
103+
)))
104+
})
105+
}
106+
85107
fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> bool {
86108
arr.as_any()
87109
.downcast_ref::<BooleanArray>()
@@ -257,6 +279,9 @@ fn encode_value(
257279
DataType::UInt64 => encoder.encode_field(&(get_u64_value(arr, idx) as i64))?,
258280
DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx))?,
259281
DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx))?,
282+
DataType::Decimal128(_, s) => {
283+
encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32)?)?
284+
}
260285
DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx))?,
261286
DataType::Utf8View => encoder.encode_field(&get_utf8_view_value(arr, idx))?,
262287
DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx))?,
@@ -360,6 +385,17 @@ fn encode_value(
360385
DataType::UInt64 => encoder.encode_field(&get_u64_list_value(arr, idx))?,
361386
DataType::Float32 => encoder.encode_field(&get_f32_list_value(arr, idx))?,
362387
DataType::Float64 => encoder.encode_field(&get_f64_list_value(arr, idx))?,
388+
DataType::Decimal128(_, s) => {
389+
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
390+
let value: Vec<_> = list_arr
391+
.as_any()
392+
.downcast_ref::<Decimal128Array>()
393+
.unwrap()
394+
.iter()
395+
.map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32)))
396+
.collect();
397+
encoder.encode_field(&value)?
398+
}
363399
DataType::Utf8 => {
364400
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
365401
let value: Vec<_> = list_arr
@@ -609,6 +645,45 @@ fn encode_value(
609645
}
610646
}
611647
}
648+
DataType::Dictionary(_, value_type) => {
649+
// Get the dictionary values, ignoring keys
650+
// We'll use Int32Type as a common key type, but we're only interested in values
651+
macro_rules! get_dict_values {
652+
($key_type:ty) => {
653+
arr.as_any()
654+
.downcast_ref::<DictionaryArray<$key_type>>()
655+
.map(|dict| dict.values())
656+
};
657+
}
658+
659+
// Try to extract values using different key types
660+
let values = get_dict_values!(Int8Type)
661+
.or_else(|| get_dict_values!(Int16Type))
662+
.or_else(|| get_dict_values!(Int32Type))
663+
.or_else(|| get_dict_values!(Int64Type))
664+
.or_else(|| get_dict_values!(UInt8Type))
665+
.or_else(|| get_dict_values!(UInt16Type))
666+
.or_else(|| get_dict_values!(UInt32Type))
667+
.or_else(|| get_dict_values!(UInt64Type))
668+
.ok_or_else(|| {
669+
PgWireError::UserError(Box::new(ErrorInfo::new(
670+
"ERROR".to_owned(),
671+
"XX000".to_owned(),
672+
format!(
673+
"Unsupported dictionary key type for value type {}",
674+
value_type
675+
),
676+
)))
677+
})?;
678+
679+
// If the dictionary has only one value, treat it as a primitive
680+
if values.len() == 1 {
681+
encode_value(encoder, values, 0)?
682+
} else {
683+
// Otherwise, use value directly indexed by values array
684+
encode_value(encoder, values, idx)?
685+
}
686+
}
612687
_ => {
613688
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
614689
"ERROR".to_owned(),
@@ -671,9 +746,9 @@ pub(crate) async fn encode_dataframe<'a>(
671746
for col in 0..cols {
672747
let array = rb.column(col);
673748
if array.is_null(row) {
674-
encoder.encode_field(&None::<i8>).unwrap();
749+
encoder.encode_field(&None::<i8>)?;
675750
} else {
676-
encode_value(&mut encoder, array, row).unwrap();
751+
encode_value(&mut encoder, array, row)?
677752
}
678753
}
679754
encoder.finish()
@@ -768,6 +843,20 @@ where
768843
let value = portal.parameter::<f64>(i, &pg_type)?;
769844
deserialized_params.push(ScalarValue::Float64(value));
770845
}
846+
Type::NUMERIC => {
847+
let value = match portal.parameter::<Decimal>(i, &pg_type)? {
848+
None => ScalarValue::Decimal128(None, 0, 0),
849+
Some(value) => {
850+
let precision = match value.mantissa() {
851+
0 => 1,
852+
m => (m.abs() as f64).log10().floor() as u8 + 1,
853+
};
854+
let scale = value.scale() as i8;
855+
ScalarValue::Decimal128(value.to_i128(), precision, scale)
856+
}
857+
};
858+
deserialized_params.push(value);
859+
}
771860
Type::TIMESTAMP => {
772861
let value = portal.parameter::<NaiveDateTime>(i, &pg_type)?;
773862
deserialized_params.push(ScalarValue::TimestampMicrosecond(

datafusion-postgres/src/handlers.rs

Lines changed: 29 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -231,8 +231,35 @@ impl SimpleQueryHandler for DfSessionService {
231231
.sql(&qualified_query)
232232
.await
233233
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
234-
let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
235-
Ok(vec![Response::Query(resp)])
234+
235+
let query_lower = query.to_lowercase();
236+
if query_lower.starts_with("insert into") {
237+
// For INSERT queries, we need to execute the query to get the row count
238+
// and return an Execution response with the proper tag
239+
let result = df
240+
.clone()
241+
.collect()
242+
.await
243+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
244+
245+
// Extract count field from the first batch
246+
let rows_affected = result
247+
.first()
248+
.and_then(|batch| batch.column_by_name("count"))
249+
.and_then(|col| {
250+
col.as_any()
251+
.downcast_ref::<datafusion::arrow::array::UInt64Array>()
252+
})
253+
.map_or(0, |array| array.value(0) as usize);
254+
255+
// Create INSERT tag with the affected row count
256+
let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
257+
Ok(vec![Response::Execution(tag)])
258+
} else {
259+
// For non-INSERT queries, return a regular Query response
260+
let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
261+
Ok(vec![Response::Query(resp)])
262+
}
236263
}
237264
}
238265

0 commit comments

Comments
 (0)