Skip to content

Commit 2972e80

Browse files
committed
feat: simplify Encoder trait function
1 parent 2c82432 commit 2972e80

File tree

2 files changed

+70
-158
lines changed

2 files changed

+70
-158
lines changed

arrow-pg/src/encoder.rs

Lines changed: 62 additions & 148 deletions
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,7 @@ use bytes::BytesMut;
1010
use chrono::{NaiveDate, NaiveDateTime};
1111
#[cfg(feature = "datafusion")]
1212
use datafusion::arrow::{array::*, datatypes::*};
13-
use pgwire::api::results::DataRowEncoder;
14-
use pgwire::api::results::FieldFormat;
15-
use pgwire::api::results::FieldInfo;
13+
use pgwire::api::results::{DataRowEncoder, FieldInfo};
1614
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
1715
use pgwire::types::ToSqlText;
1816
use postgres_types::{ToSql, Type};
@@ -24,27 +22,17 @@ use crate::list_encoder::encode_list;
2422
use crate::struct_encoder::encode_struct;
2523

2624
pub trait Encoder {
27-
fn encode_field_with_type_and_format<T>(
28-
&mut self,
29-
value: &T,
30-
data_type: &Type,
31-
format: FieldFormat,
32-
) -> PgWireResult<()>
25+
fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
3326
where
3427
T: ToSql + ToSqlText + Sized;
3528
}
3629

3730
impl Encoder for DataRowEncoder {
38-
fn encode_field_with_type_and_format<T>(
39-
&mut self,
40-
value: &T,
41-
data_type: &Type,
42-
format: FieldFormat,
43-
) -> PgWireResult<()>
31+
fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
4432
where
4533
T: ToSql + ToSqlText + Sized,
4634
{
47-
self.encode_field_with_type_and_format(value, data_type, format)
35+
self.encode_field_with_type_and_format(value, pg_field.datatype(), pg_field.format())
4836
}
4937
}
5038

@@ -295,120 +283,61 @@ pub fn encode_value<T: Encoder>(
295283
pg_field: &FieldInfo,
296284
) -> PgWireResult<()> {
297285
let type_ = pg_field.datatype();
298-
let format = pg_field.format();
299286

300287
match arr.data_type() {
301-
DataType::Null => encoder.encode_field_with_type_and_format(&None::<i8>, type_, format)?,
302-
DataType::Boolean => {
303-
encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)?
288+
DataType::Null => encoder.encode_field(&None::<i8>, pg_field)?,
289+
DataType::Boolean => encoder.encode_field(&get_bool_value(arr, idx), pg_field)?,
290+
DataType::Int8 => encoder.encode_field(&get_i8_value(arr, idx), pg_field)?,
291+
DataType::Int16 => encoder.encode_field(&get_i16_value(arr, idx), pg_field)?,
292+
DataType::Int32 => encoder.encode_field(&get_i32_value(arr, idx), pg_field)?,
293+
DataType::Int64 => encoder.encode_field(&get_i64_value(arr, idx), pg_field)?,
294+
DataType::UInt8 => {
295+
encoder.encode_field(&(get_u8_value(arr, idx).map(|x| x as i8)), pg_field)?
304296
}
305-
DataType::Int8 => {
306-
encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)?
297+
DataType::UInt16 => {
298+
encoder.encode_field(&(get_u16_value(arr, idx).map(|x| x as i16)), pg_field)?
307299
}
308-
DataType::Int16 => {
309-
encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)?
300+
DataType::UInt32 => encoder.encode_field(&get_u32_value(arr, idx), pg_field)?,
301+
DataType::UInt64 => {
302+
encoder.encode_field(&(get_u64_value(arr, idx).map(|x| x as i64)), pg_field)?
310303
}
311-
DataType::Int32 => {
312-
encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)?
304+
DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx), pg_field)?,
305+
DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx), pg_field)?,
306+
DataType::Decimal128(_, s) => {
307+
encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32)?, pg_field)?
313308
}
314-
DataType::Int64 => {
315-
encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)?
316-
}
317-
DataType::UInt8 => encoder.encode_field_with_type_and_format(
318-
&(get_u8_value(arr, idx).map(|x| x as i8)),
319-
type_,
320-
format,
321-
)?,
322-
DataType::UInt16 => encoder.encode_field_with_type_and_format(
323-
&(get_u16_value(arr, idx).map(|x| x as i16)),
324-
type_,
325-
format,
326-
)?,
327-
DataType::UInt32 => {
328-
encoder.encode_field_with_type_and_format(&get_u32_value(arr, idx), type_, format)?
329-
}
330-
DataType::UInt64 => encoder.encode_field_with_type_and_format(
331-
&(get_u64_value(arr, idx).map(|x| x as i64)),
332-
type_,
333-
format,
334-
)?,
335-
DataType::Float32 => {
336-
encoder.encode_field_with_type_and_format(&get_f32_value(arr, idx), type_, format)?
337-
}
338-
DataType::Float64 => {
339-
encoder.encode_field_with_type_and_format(&get_f64_value(arr, idx), type_, format)?
340-
}
341-
DataType::Decimal128(_, s) => encoder.encode_field_with_type_and_format(
342-
&get_numeric_128_value(arr, idx, *s as u32)?,
343-
type_,
344-
format,
345-
)?,
346-
DataType::Utf8 => {
347-
encoder.encode_field_with_type_and_format(&get_utf8_value(arr, idx), type_, format)?
348-
}
349-
DataType::Utf8View => encoder.encode_field_with_type_and_format(
350-
&get_utf8_view_value(arr, idx),
351-
type_,
352-
format,
353-
)?,
354-
DataType::BinaryView => encoder.encode_field_with_type_and_format(
355-
&get_binary_view_value(arr, idx),
356-
type_,
357-
format,
358-
)?,
359-
DataType::LargeUtf8 => encoder.encode_field_with_type_and_format(
360-
&get_large_utf8_value(arr, idx),
361-
type_,
362-
format,
363-
)?,
364-
DataType::Binary => {
365-
encoder.encode_field_with_type_and_format(&get_binary_value(arr, idx), type_, format)?
366-
}
367-
DataType::LargeBinary => encoder.encode_field_with_type_and_format(
368-
&get_large_binary_value(arr, idx),
369-
type_,
370-
format,
371-
)?,
372-
DataType::Date32 => {
373-
encoder.encode_field_with_type_and_format(&get_date32_value(arr, idx), type_, format)?
374-
}
375-
DataType::Date64 => {
376-
encoder.encode_field_with_type_and_format(&get_date64_value(arr, idx), type_, format)?
309+
DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx), pg_field)?,
310+
DataType::Utf8View => encoder.encode_field(&get_utf8_view_value(arr, idx), pg_field)?,
311+
DataType::BinaryView => encoder.encode_field(&get_binary_view_value(arr, idx), pg_field)?,
312+
DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx), pg_field)?,
313+
DataType::Binary => encoder.encode_field(&get_binary_value(arr, idx), pg_field)?,
314+
DataType::LargeBinary => {
315+
encoder.encode_field(&get_large_binary_value(arr, idx), pg_field)?
377316
}
317+
DataType::Date32 => encoder.encode_field(&get_date32_value(arr, idx), pg_field)?,
318+
DataType::Date64 => encoder.encode_field(&get_date64_value(arr, idx), pg_field)?,
378319
DataType::Time32(unit) => match unit {
379-
TimeUnit::Second => encoder.encode_field_with_type_and_format(
380-
&get_time32_second_value(arr, idx),
381-
type_,
382-
format,
383-
)?,
384-
TimeUnit::Millisecond => encoder.encode_field_with_type_and_format(
385-
&get_time32_millisecond_value(arr, idx),
386-
type_,
387-
format,
388-
)?,
320+
TimeUnit::Second => {
321+
encoder.encode_field(&get_time32_second_value(arr, idx), pg_field)?
322+
}
323+
TimeUnit::Millisecond => {
324+
encoder.encode_field(&get_time32_millisecond_value(arr, idx), pg_field)?
325+
}
389326
_ => {}
390327
},
391328
DataType::Time64(unit) => match unit {
392-
TimeUnit::Microsecond => encoder.encode_field_with_type_and_format(
393-
&get_time64_microsecond_value(arr, idx),
394-
type_,
395-
format,
396-
)?,
397-
TimeUnit::Nanosecond => encoder.encode_field_with_type_and_format(
398-
&get_time64_nanosecond_value(arr, idx),
399-
type_,
400-
format,
401-
)?,
329+
TimeUnit::Microsecond => {
330+
encoder.encode_field(&get_time64_microsecond_value(arr, idx), pg_field)?
331+
}
332+
TimeUnit::Nanosecond => {
333+
encoder.encode_field(&get_time64_nanosecond_value(arr, idx), pg_field)?
334+
}
402335
_ => {}
403336
},
404337
DataType::Timestamp(unit, timezone) => match unit {
405338
TimeUnit::Second => {
406339
if arr.is_null(idx) {
407-
return encoder.encode_field_with_type_and_format(
408-
&None::<NaiveDateTime>,
409-
type_,
410-
format,
411-
);
340+
return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
412341
}
413342
let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
414343
if let Some(tz) = timezone {
@@ -417,19 +346,15 @@ pub fn encode_value<T: Encoder>(
417346
.value_as_datetime_with_tz(idx, tz)
418347
.map(|d| d.fixed_offset());
419348

420-
encoder.encode_field_with_type_and_format(&value, type_, format)?;
349+
encoder.encode_field(&value, pg_field)?;
421350
} else {
422351
let value = ts_array.value_as_datetime(idx);
423-
encoder.encode_field_with_type_and_format(&value, type_, format)?;
352+
encoder.encode_field(&value, pg_field)?;
424353
}
425354
}
426355
TimeUnit::Millisecond => {
427356
if arr.is_null(idx) {
428-
return encoder.encode_field_with_type_and_format(
429-
&None::<NaiveDateTime>,
430-
type_,
431-
format,
432-
);
357+
return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
433358
}
434359
let ts_array = arr
435360
.as_any()
@@ -440,19 +365,15 @@ pub fn encode_value<T: Encoder>(
440365
let value = ts_array
441366
.value_as_datetime_with_tz(idx, tz)
442367
.map(|d| d.fixed_offset());
443-
encoder.encode_field_with_type_and_format(&value, type_, format)?;
368+
encoder.encode_field(&value, pg_field)?;
444369
} else {
445370
let value = ts_array.value_as_datetime(idx);
446-
encoder.encode_field_with_type_and_format(&value, type_, format)?;
371+
encoder.encode_field(&value, pg_field)?;
447372
}
448373
}
449374
TimeUnit::Microsecond => {
450375
if arr.is_null(idx) {
451-
return encoder.encode_field_with_type_and_format(
452-
&None::<NaiveDateTime>,
453-
type_,
454-
format,
455-
);
376+
return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
456377
}
457378
let ts_array = arr
458379
.as_any()
@@ -463,19 +384,15 @@ pub fn encode_value<T: Encoder>(
463384
let value = ts_array
464385
.value_as_datetime_with_tz(idx, tz)
465386
.map(|d| d.fixed_offset());
466-
encoder.encode_field_with_type_and_format(&value, type_, format)?;
387+
encoder.encode_field(&value, pg_field)?;
467388
} else {
468389
let value = ts_array.value_as_datetime(idx);
469-
encoder.encode_field_with_type_and_format(&value, type_, format)?;
390+
encoder.encode_field(&value, pg_field)?;
470391
}
471392
}
472393
TimeUnit::Nanosecond => {
473394
if arr.is_null(idx) {
474-
return encoder.encode_field_with_type_and_format(
475-
&None::<NaiveDateTime>,
476-
type_,
477-
format,
478-
);
395+
return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
479396
}
480397
let ts_array = arr
481398
.as_any()
@@ -486,20 +403,20 @@ pub fn encode_value<T: Encoder>(
486403
let value = ts_array
487404
.value_as_datetime_with_tz(idx, tz)
488405
.map(|d| d.fixed_offset());
489-
encoder.encode_field_with_type_and_format(&value, type_, format)?;
406+
encoder.encode_field(&value, pg_field)?;
490407
} else {
491408
let value = ts_array.value_as_datetime(idx);
492-
encoder.encode_field_with_type_and_format(&value, type_, format)?;
409+
encoder.encode_field(&value, pg_field)?;
493410
}
494411
}
495412
},
496413
DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
497414
if arr.is_null(idx) {
498-
return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format);
415+
return encoder.encode_field(&None::<&[i8]>, pg_field);
499416
}
500417
let array = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
501418
let value = encode_list(array, pg_field)?;
502-
encoder.encode_field_with_type_and_format(&value, type_, format)?
419+
encoder.encode_field(&value, pg_field)?
503420
}
504421
DataType::Struct(_) => {
505422
let fields = match type_.kind() {
@@ -511,11 +428,11 @@ pub fn encode_value<T: Encoder>(
511428
}
512429
};
513430
let value = encode_struct(arr, idx, fields, pg_field)?;
514-
encoder.encode_field_with_type_and_format(&value, type_, format)?
431+
encoder.encode_field(&value, pg_field)?
515432
}
516433
DataType::Dictionary(_, value_type) => {
517434
if arr.is_null(idx) {
518-
return encoder.encode_field_with_type_and_format(&None::<i8>, type_, format);
435+
return encoder.encode_field(&None::<i8>, pg_field);
519436
}
520437
// Get the dictionary values and the mapped row index
521438
macro_rules! get_dict_values_and_index {
@@ -557,6 +474,8 @@ pub fn encode_value<T: Encoder>(
557474

558475
#[cfg(test)]
559476
mod tests {
477+
use pgwire::api::results::FieldFormat;
478+
560479
use super::*;
561480

562481
#[test]
@@ -567,17 +486,12 @@ mod tests {
567486
}
568487

569488
impl Encoder for MockEncoder {
570-
fn encode_field_with_type_and_format<T>(
571-
&mut self,
572-
value: &T,
573-
data_type: &Type,
574-
_format: FieldFormat,
575-
) -> PgWireResult<()>
489+
fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
576490
where
577491
T: ToSql + ToSqlText + Sized,
578492
{
579493
let mut bytes = BytesMut::new();
580-
let _sql_text = value.to_sql_text(data_type, &mut bytes);
494+
let _sql_text = value.to_sql_text(pg_field.datatype(), &mut bytes);
581495
let string = String::from_utf8(bytes.to_vec());
582496
self.encoded_value = string.unwrap();
583497
Ok(())

arrow-pg/src/struct_encoder.rs

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ use bytes::{BufMut, BytesMut};
99
use pgwire::api::results::{FieldFormat, FieldInfo};
1010
use pgwire::error::PgWireResult;
1111
use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE};
12-
use postgres_types::{Field, IsNull, ToSql, Type};
12+
use postgres_types::{Field, IsNull, ToSql};
1313

1414
use crate::encoder::{encode_value, EncodedValue, Encoder};
1515

@@ -63,22 +63,20 @@ impl StructEncoder {
6363
}
6464

6565
impl Encoder for StructEncoder {
66-
fn encode_field_with_type_and_format<T>(
67-
&mut self,
68-
value: &T,
69-
data_type: &Type,
70-
format: FieldFormat,
71-
) -> PgWireResult<()>
66+
fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
7267
where
7368
T: ToSql + ToSqlText + Sized,
7469
{
70+
let datatype = pg_field.datatype();
71+
let format = pg_field.format();
72+
7573
if format == FieldFormat::Text {
7674
if self.curr_col == 0 {
7775
self.row_buffer.put_slice(b"(");
7876
}
7977
// encode value in an intermediate buf
8078
let mut buf = BytesMut::new();
81-
value.to_sql_text(data_type, &mut buf)?;
79+
value.to_sql_text(datatype, &mut buf)?;
8280
let encoded_value_as_str = String::from_utf8_lossy(&buf);
8381
if QUOTE_CHECK.is_match(&encoded_value_as_str) {
8482
self.row_buffer.put_u8(b'"');
@@ -102,12 +100,12 @@ impl Encoder for StructEncoder {
102100
self.row_buffer.put_i32(self.num_cols as i32);
103101
}
104102

105-
self.row_buffer.put_u32(data_type.oid());
103+
self.row_buffer.put_u32(datatype.oid());
106104
// remember the position of the 4-byte length field
107105
let prev_index = self.row_buffer.len();
108106
// write value length as -1 ahead of time
109107
self.row_buffer.put_i32(-1);
110-
let is_null = value.to_sql(data_type, &mut self.row_buffer)?;
108+
let is_null = value.to_sql(datatype, &mut self.row_buffer)?;
111109
if let IsNull::No = is_null {
112110
let value_length = self.row_buffer.len() - prev_index - 4;
113111
let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)];

0 commit comments

Comments
 (0)