Skip to content

Commit a7e2f9a

Browse files
committed
feat: include FormatOptions in encoder api
1 parent efc3b1c commit a7e2f9a

File tree

5 files changed

+64
-56
lines changed

5 files changed

+64
-56
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ chrono = { version = "0.4", features = ["std"] }
2020
datafusion = { version = "50", default-features = false }
2121
futures = "0.3"
2222
#pgwire = { version = "0.34", default-features = false }
23-
pgwire = { git = "https://github.com/sunng87/pgwire", rev = "37a32a05d2aed55bd013f3c6f93d786368350e0b", default-features = false }
23+
pgwire = { git = "https://github.com/sunng87/pgwire", rev = "ad1e31a4b3fdc325eeb57d0cfe0dc1798b6687a6", default-features = false }
2424
postgres-types = "0.2"
2525
rust_decimal = { version = "1.39", features = ["db-postgres"] }
2626
tokio = { version = "1", default-features = false }

arrow-pg/src/encoder.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use chrono::{NaiveDate, NaiveDateTime};
1212
use datafusion::arrow::{array::*, datatypes::*};
1313
use pgwire::api::results::{DataRowEncoder, FieldInfo};
1414
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
15+
use pgwire::types::format::FormatOptions;
1516
use pgwire::types::ToSqlText;
1617
use postgres_types::{ToSql, Type};
1718
use rust_decimal::Decimal;
@@ -32,7 +33,12 @@ impl Encoder for DataRowEncoder {
3233
where
3334
T: ToSql + ToSqlText + Sized,
3435
{
35-
self.encode_field_with_type_and_format(value, pg_field.datatype(), pg_field.format())
36+
self.encode_field_with_type_and_format(
37+
value,
38+
pg_field.datatype(),
39+
pg_field.format(),
40+
pg_field.format_options(),
41+
)
3642
}
3743
}
3844

@@ -80,6 +86,7 @@ impl ToSqlText for EncodedValue {
8086
&self,
8187
_ty: &Type,
8288
out: &mut BytesMut,
89+
_format_options: &FormatOptions,
8390
) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
8491
where
8592
Self: Sized,
@@ -491,7 +498,8 @@ mod tests {
491498
T: ToSql + ToSqlText + Sized,
492499
{
493500
let mut bytes = BytesMut::new();
494-
let _sql_text = value.to_sql_text(pg_field.datatype(), &mut bytes);
501+
let _sql_text =
502+
value.to_sql_text(pg_field.datatype(), &mut bytes, &FormatOptions::default());
495503
let string = String::from_utf8(bytes.to_vec());
496504
self.encoded_value = string.unwrap();
497505
Ok(())

arrow-pg/src/list_encoder.rs

Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ use chrono::{DateTime, TimeZone, Utc};
4040
use pgwire::api::results::{FieldFormat, FieldInfo};
4141
use pgwire::error::{PgWireError, PgWireResult};
4242
use pgwire::types::{ToSqlText, QUOTE_ESCAPE};
43-
use postgres_types::{ToSql, Type};
43+
use postgres_types::ToSql;
4444
use rust_decimal::Decimal;
4545

4646
use crate::encoder::EncodedValue;
@@ -93,14 +93,13 @@ get_primitive_list_value!(get_u64_list_value, UInt64Type, i64, |val: u64| {
9393
get_primitive_list_value!(get_f32_list_value, Float32Type, f32);
9494
get_primitive_list_value!(get_f64_list_value, Float64Type, f64);
9595

96-
fn encode_field<T: ToSql + ToSqlText>(
97-
t: &[T],
98-
type_: &Type,
99-
format: FieldFormat,
100-
) -> PgWireResult<EncodedValue> {
96+
fn encode_field<T: ToSql + ToSqlText>(t: &[T], field: &FieldInfo) -> PgWireResult<EncodedValue> {
10197
let mut bytes = BytesMut::new();
98+
99+
let format = field.format();
100+
let type_ = field.datatype();
102101
match format {
103-
FieldFormat::Text => t.to_sql_text(type_, &mut bytes)?,
102+
FieldFormat::Text => t.to_sql_text(type_, &mut bytes, field.format_options().as_ref())?,
104103
FieldFormat::Binary => t.to_sql(type_, &mut bytes)?,
105104
};
106105
Ok(EncodedValue { bytes })
@@ -114,22 +113,24 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
114113
DataType::Null => {
115114
let mut bytes = BytesMut::new();
116115
match format {
117-
FieldFormat::Text => None::<i8>.to_sql_text(type_, &mut bytes),
116+
FieldFormat::Text => {
117+
None::<i8>.to_sql_text(type_, &mut bytes, pg_field.format_options().as_ref())
118+
}
118119
FieldFormat::Binary => None::<i8>.to_sql(type_, &mut bytes),
119120
}?;
120121
Ok(EncodedValue { bytes })
121122
}
122-
DataType::Boolean => encode_field(&get_bool_list_value(&arr), type_, format),
123-
DataType::Int8 => encode_field(&get_i8_list_value(&arr), type_, format),
124-
DataType::Int16 => encode_field(&get_i16_list_value(&arr), type_, format),
125-
DataType::Int32 => encode_field(&get_i32_list_value(&arr), type_, format),
126-
DataType::Int64 => encode_field(&get_i64_list_value(&arr), type_, format),
127-
DataType::UInt8 => encode_field(&get_u8_list_value(&arr), type_, format),
128-
DataType::UInt16 => encode_field(&get_u16_list_value(&arr), type_, format),
129-
DataType::UInt32 => encode_field(&get_u32_list_value(&arr), type_, format),
130-
DataType::UInt64 => encode_field(&get_u64_list_value(&arr), type_, format),
131-
DataType::Float32 => encode_field(&get_f32_list_value(&arr), type_, format),
132-
DataType::Float64 => encode_field(&get_f64_list_value(&arr), type_, format),
123+
DataType::Boolean => encode_field(&get_bool_list_value(&arr), pg_field),
124+
DataType::Int8 => encode_field(&get_i8_list_value(&arr), pg_field),
125+
DataType::Int16 => encode_field(&get_i16_list_value(&arr), pg_field),
126+
DataType::Int32 => encode_field(&get_i32_list_value(&arr), pg_field),
127+
DataType::Int64 => encode_field(&get_i64_list_value(&arr), pg_field),
128+
DataType::UInt8 => encode_field(&get_u8_list_value(&arr), pg_field),
129+
DataType::UInt16 => encode_field(&get_u16_list_value(&arr), pg_field),
130+
DataType::UInt32 => encode_field(&get_u32_list_value(&arr), pg_field),
131+
DataType::UInt64 => encode_field(&get_u64_list_value(&arr), pg_field),
132+
DataType::Float32 => encode_field(&get_f32_list_value(&arr), pg_field),
133+
DataType::Float64 => encode_field(&get_f64_list_value(&arr), pg_field),
133134
DataType::Decimal128(_, s) => {
134135
let value: Vec<_> = arr
135136
.as_any()
@@ -138,7 +139,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
138139
.iter()
139140
.map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32)))
140141
.collect();
141-
encode_field(&value, type_, format)
142+
encode_field(&value, pg_field)
142143
}
143144
DataType::Utf8 => {
144145
let value: Vec<Option<&str>> = arr
@@ -147,7 +148,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
147148
.unwrap()
148149
.iter()
149150
.collect();
150-
encode_field(&value, type_, format)
151+
encode_field(&value, pg_field)
151152
}
152153
DataType::Utf8View => {
153154
let value: Vec<Option<&str>> = arr
@@ -156,7 +157,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
156157
.unwrap()
157158
.iter()
158159
.collect();
159-
encode_field(&value, type_, format)
160+
encode_field(&value, pg_field)
160161
}
161162
DataType::Binary => {
162163
let value: Vec<Option<_>> = arr
@@ -165,7 +166,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
165166
.unwrap()
166167
.iter()
167168
.collect();
168-
encode_field(&value, type_, format)
169+
encode_field(&value, pg_field)
169170
}
170171
DataType::LargeBinary => {
171172
let value: Vec<Option<_>> = arr
@@ -174,7 +175,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
174175
.unwrap()
175176
.iter()
176177
.collect();
177-
encode_field(&value, type_, format)
178+
encode_field(&value, pg_field)
178179
}
179180
DataType::BinaryView => {
180181
let value: Vec<Option<_>> = arr
@@ -183,7 +184,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
183184
.unwrap()
184185
.iter()
185186
.collect();
186-
encode_field(&value, type_, format)
187+
encode_field(&value, pg_field)
187188
}
188189

189190
DataType::Date32 => {
@@ -194,7 +195,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
194195
.iter()
195196
.map(|val| val.and_then(|x| as_date::<Date32Type>(x as i64)))
196197
.collect();
197-
encode_field(&value, type_, format)
198+
encode_field(&value, pg_field)
198199
}
199200
DataType::Date64 => {
200201
let value: Vec<Option<_>> = arr
@@ -204,7 +205,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
204205
.iter()
205206
.map(|val| val.and_then(as_date::<Date64Type>))
206207
.collect();
207-
encode_field(&value, type_, format)
208+
encode_field(&value, pg_field)
208209
}
209210
DataType::Time32(unit) => match unit {
210211
TimeUnit::Second => {
@@ -215,7 +216,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
215216
.iter()
216217
.map(|val| val.and_then(|x| as_time::<Time32SecondType>(x as i64)))
217218
.collect();
218-
encode_field(&value, type_, format)
219+
encode_field(&value, pg_field)
219220
}
220221
TimeUnit::Millisecond => {
221222
let value: Vec<Option<_>> = arr
@@ -225,7 +226,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
225226
.iter()
226227
.map(|val| val.and_then(|x| as_time::<Time32MillisecondType>(x as i64)))
227228
.collect();
228-
encode_field(&value, type_, format)
229+
encode_field(&value, pg_field)
229230
}
230231
_ => {
231232
// Time32 only supports Second and Millisecond in Arrow
@@ -242,7 +243,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
242243
.iter()
243244
.map(|val| val.and_then(as_time::<Time64MicrosecondType>))
244245
.collect();
245-
encode_field(&value, type_, format)
246+
encode_field(&value, pg_field)
246247
}
247248
TimeUnit::Nanosecond => {
248249
let value: Vec<Option<_>> = arr
@@ -252,7 +253,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
252253
.iter()
253254
.map(|val| val.and_then(as_time::<Time64NanosecondType>))
254255
.collect();
255-
encode_field(&value, type_, format)
256+
encode_field(&value, pg_field)
256257
}
257258
_ => {
258259
// Time64 only supports Microsecond and Nanosecond in Arrow
@@ -282,14 +283,14 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
282283
})
283284
})
284285
.collect();
285-
encode_field(&value, type_, format)
286+
encode_field(&value, pg_field)
286287
} else {
287288
let value: Vec<_> = array_iter
288289
.map(|i| {
289290
i.and_then(|i| DateTime::from_timestamp(i, 0).map(|dt| dt.naive_utc()))
290291
})
291292
.collect();
292-
encode_field(&value, type_, format)
293+
encode_field(&value, pg_field)
293294
}
294295
}
295296
TimeUnit::Millisecond => {
@@ -312,7 +313,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
312313
})
313314
})
314315
.collect();
315-
encode_field(&value, type_, format)
316+
encode_field(&value, pg_field)
316317
} else {
317318
let value: Vec<_> = array_iter
318319
.map(|i| {
@@ -321,7 +322,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
321322
})
322323
})
323324
.collect();
324-
encode_field(&value, type_, format)
325+
encode_field(&value, pg_field)
325326
}
326327
}
327328
TimeUnit::Microsecond => {
@@ -344,7 +345,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
344345
})
345346
})
346347
.collect();
347-
encode_field(&value, type_, format)
348+
encode_field(&value, pg_field)
348349
} else {
349350
let value: Vec<_> = array_iter
350351
.map(|i| {
@@ -353,7 +354,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
353354
})
354355
})
355356
.collect();
356-
encode_field(&value, type_, format)
357+
encode_field(&value, pg_field)
357358
}
358359
}
359360
TimeUnit::Nanosecond => {
@@ -376,12 +377,12 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
376377
})
377378
})
378379
.collect();
379-
encode_field(&value, type_, format)
380+
encode_field(&value, pg_field)
380381
} else {
381382
let value: Vec<_> = array_iter
382383
.map(|i| i.map(|i| DateTime::from_timestamp_nanos(i).naive_utc()))
383384
.collect();
384-
encode_field(&value, type_, format)
385+
encode_field(&value, pg_field)
385386
}
386387
}
387388
},
@@ -429,7 +430,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
429430
}
430431
})
431432
.collect();
432-
encode_field(&values?, type_, format)
433+
encode_field(&values?, pg_field)
433434
}
434435
DataType::LargeUtf8 => {
435436
let value: Vec<Option<&str>> = arr
@@ -438,7 +439,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
438439
.unwrap()
439440
.iter()
440441
.collect();
441-
encode_field(&value, type_, format)
442+
encode_field(&value, pg_field)
442443
}
443444
DataType::Decimal256(_, s) => {
444445
// Convert Decimal256 to string representation for now
@@ -473,7 +474,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
473474
}
474475
})
475476
.collect();
476-
encode_field(&value, type_, format)
477+
encode_field(&value, pg_field)
477478
}
478479
DataType::Duration(_) => {
479480
// Convert duration to microseconds for now
@@ -483,7 +484,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
483484
.unwrap()
484485
.iter()
485486
.collect();
486-
encode_field(&value, type_, format)
487+
encode_field(&value, pg_field)
487488
}
488489
DataType::List(_) => {
489490
// Support for nested lists (list of lists)
@@ -499,7 +500,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
499500
}
500501
})
501502
.collect();
502-
encode_field(&value, type_, format)
503+
encode_field(&value, pg_field)
503504
}
504505
DataType::LargeList(_) => {
505506
// Support for large lists
@@ -513,7 +514,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
513514
}
514515
})
515516
.collect();
516-
encode_field(&value, type_, format)
517+
encode_field(&value, pg_field)
517518
}
518519
DataType::Map(_, _) => {
519520
// Support for map types
@@ -527,7 +528,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
527528
}
528529
})
529530
.collect();
530-
encode_field(&value, type_, format)
531+
encode_field(&value, pg_field)
531532
}
532533

533534
DataType::Union(_, _) => {
@@ -541,7 +542,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
541542
}
542543
})
543544
.collect();
544-
encode_field(&value, type_, format)
545+
encode_field(&value, pg_field)
545546
}
546547
DataType::Dictionary(_, _) => {
547548
// Support for dictionary types
@@ -554,7 +555,7 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
554555
}
555556
})
556557
.collect();
557-
encode_field(&value, type_, format)
558+
encode_field(&value, pg_field)
558559
}
559560
// TODO: add support for more advanced types (fixed size lists, etc.)
560561
list_type => Err(PgWireError::ApiError(ToSqlError::from(format!(

0 commit comments

Comments
 (0)