Skip to content

Commit 2cca406

Browse files
committed
feat: include optional FormatOptions in arrow encode
1 parent 8e2e651 commit 2cca406

File tree

6 files changed

+50
-17
lines changed

6 files changed

+50
-17
lines changed

Cargo.lock

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ bytes = "1.10.1"
1919
chrono = { version = "0.4", features = ["std"] }
2020
datafusion = { version = "50", default-features = false }
2121
futures = "0.3"
22-
pgwire = { version = "0.34", default-features = false }
22+
#pgwire = { version = "0.34", default-features = false }
23+
pgwire = { git = "https://github.com/sunng87/pgwire", rev = "d89089947a56ebfe2e89631925facdd7a85c25b4", default-features = false }
2324
postgres-types = "0.2"
2425
rust_decimal = { version = "1.39", features = ["db-postgres"] }
2526
tokio = { version = "1", default-features = false }

arrow-pg/src/datatypes.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use pgwire::api::results::FieldInfo;
1010
use pgwire::api::Type;
1111
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
1212
use pgwire::messages::data::DataRow;
13+
use pgwire::types::format::FormatOptions;
1314
use postgres_types::Kind;
1415

1516
use crate::row_encoder::RowEncoder;
@@ -111,20 +112,25 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
111112
})
112113
}
113114

114-
pub fn arrow_schema_to_pg_fields(schema: &Schema, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
115+
pub fn arrow_schema_to_pg_fields(
116+
schema: &Schema,
117+
format: &Format,
118+
data_format_options: Option<Arc<FormatOptions>>,
119+
) -> PgWireResult<Vec<FieldInfo>> {
120+
let _ = data_format_options;
115121
schema
116122
.fields()
117123
.iter()
118124
.enumerate()
119125
.map(|(idx, f)| {
120126
let pg_type = into_pg_type(f.data_type())?;
121-
Ok(FieldInfo::new(
122-
f.name().into(),
123-
None,
124-
None,
125-
pg_type,
126-
format.format_for(idx),
127-
))
127+
let mut field_info =
128+
FieldInfo::new(f.name().into(), None, None, pg_type, format.format_for(idx));
129+
if let Some(data_format_options) = &data_format_options {
130+
field_info = field_info.with_format_options(data_format_options.clone());
131+
}
132+
133+
Ok(field_info)
128134
})
129135
.collect::<PgWireResult<Vec<FieldInfo>>>()
130136
}

arrow-pg/src/datatypes/df.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,22 @@ use pgwire::api::results::QueryResponse;
1313
use pgwire::api::Type;
1414
use pgwire::error::{PgWireError, PgWireResult};
1515
use pgwire::messages::data::DataRow;
16+
use pgwire::types::format::FormatOptions;
1617
use rust_decimal::prelude::ToPrimitive;
1718
use rust_decimal::Decimal;
1819

1920
use super::{arrow_schema_to_pg_fields, encode_recordbatch, into_pg_type};
2021

21-
pub async fn encode_dataframe(df: DataFrame, format: &Format) -> PgWireResult<QueryResponse> {
22-
let fields = Arc::new(arrow_schema_to_pg_fields(df.schema().as_arrow(), format)?);
22+
pub async fn encode_dataframe(
23+
df: DataFrame,
24+
format: &Format,
25+
data_format_options: Option<Arc<FormatOptions>>,
26+
) -> PgWireResult<QueryResponse> {
27+
let fields = Arc::new(arrow_schema_to_pg_fields(
28+
df.schema().as_arrow(),
29+
format,
30+
data_format_options,
31+
)?);
2332

2433
let recordbatch_stream = df
2534
.execute_stream()

arrow-pg/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
pub mod datatypes;
77
pub mod encoder;
88
mod error;
9+
pub mod format;
910
pub mod list_encoder;
1011
pub mod row_encoder;
1112
pub mod struct_encoder;
13+
14+
#[cfg(feature = "datafusion")]
15+
pub use datatypes::df::encode_dataframe;
16+
17+
pub use datatypes::encode_recordbatch;

datafusion-postgres/src/handlers.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use pgwire::api::stmt::StoredStatement;
2222
use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
2323
use pgwire::error::{PgWireError, PgWireResult};
2424
use pgwire::messages::response::TransactionStatus;
25+
use pgwire::types::format::FormatOptions;
2526

2627
use crate::auth::AuthManager;
2728
use crate::client;
@@ -355,7 +356,10 @@ impl SimpleQueryHandler for DfSessionService {
355356
results.push(resp);
356357
} else {
357358
// For non-INSERT queries, return a regular Query response
358-
let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
359+
let format_options =
360+
Arc::new(FormatOptions::from_client_metadata(client.metadata()));
361+
let resp =
362+
df::encode_dataframe(df, &Format::UnifiedText, Some(format_options)).await?;
359363
results.push(Response::Query(resp));
360364
}
361365
}
@@ -382,7 +386,8 @@ impl ExtendedQueryHandler for DfSessionService {
382386
{
383387
if let (_, Some((_, plan))) = &target.statement {
384388
let schema = plan.schema();
385-
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
389+
let fields =
390+
arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary, None)?;
386391
let params = plan
387392
.get_parameter_types()
388393
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
@@ -415,7 +420,7 @@ impl ExtendedQueryHandler for DfSessionService {
415420
if let (_, Some((_, plan))) = &target.statement.statement {
416421
let format = &target.result_column_format;
417422
let schema = plan.schema();
418-
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
423+
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format, None)?;
419424

420425
Ok(DescribePortalResponse::new(fields))
421426
} else {
@@ -543,7 +548,14 @@ impl ExtendedQueryHandler for DfSessionService {
543548
Ok(resp)
544549
} else {
545550
// For non-INSERT queries, return a regular Query response
546-
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
551+
let format_options =
552+
Arc::new(FormatOptions::from_client_metadata(client.metadata()));
553+
let resp = df::encode_dataframe(
554+
dataframe,
555+
&portal.result_column_format,
556+
Some(format_options),
557+
)
558+
.await?;
547559
Ok(Response::Query(resp))
548560
}
549561
} else {

0 commit comments

Comments
 (0)