Skip to content

Commit 21e4ccd

Browse files
committed
fixed dict bug
1 parent 9f95f56 commit 21e4ccd

File tree

1 file changed

+66
-53
lines changed

1 file changed

+66
-53
lines changed

arrow-pg/src/encoder.rs

Lines changed: 66 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -294,27 +294,21 @@ pub fn encode_value<T: Encoder>(
294294
type_: &Type,
295295
format: FieldFormat,
296296
) -> PgWireResult<()> {
297-
println!("[DATAFUSION POSTGRES ENCODER] encode_value: idx={}, type={:?}", idx, arr.data_type());
298297
match arr.data_type() {
299298
DataType::Null => encoder.encode_field_with_type_and_format(&None::<i8>, type_, format)?,
300299
DataType::Boolean => {
301-
println!("BOOLEAN BRANCH");
302300
encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)?
303301
}
304302
DataType::Int8 => {
305-
println!("BOOLEAN BRANCH");
306303
encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)?
307304
}
308305
DataType::Int16 => {
309-
println!("BOOLEAN BRANCH");
310306
encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)?
311307
}
312308
DataType::Int32 => {
313-
println!("BOOLEAN BRANCH");
314309
encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)?
315310
}
316311
DataType::Int64 => {
317-
println!("BOOLEAN BRANCH");
318312
encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)?
319313
}
320314
DataType::UInt8 => encoder.encode_field_with_type_and_format(
@@ -496,7 +490,6 @@ pub fn encode_value<T: Encoder>(
496490
}
497491
},
498492
DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
499-
println!("LIST BRANCH");
500493
if arr.is_null(idx) {
501494
return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format);
502495
}
@@ -505,7 +498,6 @@ pub fn encode_value<T: Encoder>(
505498
encoder.encode_field_with_type_and_format(&value, type_, format)?
506499
}
507500
DataType::Struct(_) => {
508-
println!("STRUCT BRANCH");
509501
let fields = match type_.kind() {
510502
postgres_types::Kind::Composite(fields) => fields,
511503
_ => {
@@ -518,66 +510,34 @@ pub fn encode_value<T: Encoder>(
518510
encoder.encode_field_with_type_and_format(&value, type_, format)?
519511
}
520512
DataType::Dictionary(_, value_type) => {
521-
println!("DICT BRANCH");
522513
if arr.is_null(idx) {
523514
return encoder.encode_field_with_type_and_format(&None::<i8>, type_, format);
524515
}
525-
526-
macro_rules! get_value_index {
527-
($key_ty:ty) => {
528-
arr.as_any()
529-
.downcast_ref::<DictionaryArray<$key_ty>>()
530-
.map(|dict| dict.keys().value(idx) as usize)
531-
};
532-
}
533-
534-
let value_idx = get_value_index!(Int8Type)
535-
.or_else(|| get_value_index!(Int16Type))
536-
.or_else(|| get_value_index!(Int32Type))
537-
.or_else(|| get_value_index!(Int64Type))
538-
.or_else(|| get_value_index!(UInt8Type))
539-
.or_else(|| get_value_index!(UInt16Type))
540-
.or_else(|| get_value_index!(UInt32Type))
541-
.or_else(|| get_value_index!(UInt64Type))
542-
.ok_or_else(|| {
543-
ToSqlError::from(format!(
544-
"Unsupported dictionary key type"
545-
))
546-
})?;
547-
548-
549-
// Get the dictionary values, ignoring keys
550-
// We'll use Int32Type as a common key type, but we're only interested in values
551-
macro_rules! get_dict_values {
516+
// Get the dictionary values and the mapped row index
517+
macro_rules! get_dict_values_and_index {
552518
($key_type:ty) => {
553519
arr.as_any()
554520
.downcast_ref::<DictionaryArray<$key_type>>()
555-
.map(|dict| dict.values().clone())
521+
.map(|dict| (dict.values(), dict.keys().value(idx) as usize))
556522
};
557523
}
558524

559525
// Try to extract values using different key types
560-
let values = get_dict_values!(Int8Type)
561-
.or_else(|| get_dict_values!(Int16Type))
562-
.or_else(|| get_dict_values!(Int32Type))
563-
.or_else(|| get_dict_values!(Int64Type))
564-
.or_else(|| get_dict_values!(UInt8Type))
565-
.or_else(|| get_dict_values!(UInt16Type))
566-
.or_else(|| get_dict_values!(UInt32Type))
567-
.or_else(|| get_dict_values!(UInt64Type))
526+
let (values, idx) = get_dict_values_and_index!(Int8Type)
527+
.or_else(|| get_dict_values_and_index!(Int16Type))
528+
.or_else(|| get_dict_values_and_index!(Int32Type))
529+
.or_else(|| get_dict_values_and_index!(Int64Type))
530+
.or_else(|| get_dict_values_and_index!(UInt8Type))
531+
.or_else(|| get_dict_values_and_index!(UInt16Type))
532+
.or_else(|| get_dict_values_and_index!(UInt32Type))
533+
.or_else(|| get_dict_values_and_index!(UInt64Type))
568534
.ok_or_else(|| {
569535
ToSqlError::from(format!(
570-
"Unsupported dictionary key type"
536+
"Unsupported dictionary key type for value type {value_type}"
571537
))
572538
})?;
573539

574-
// If the dictionary has only one value, treat it as a primitive
575-
if values.len() == 1 {
576-
encode_value(encoder, &values, 0, type_, format)?
577-
} else {
578-
// Otherwise, use value directly indexed by values array
579-
encode_value(encoder, &values, idx, type_, format)?
580-
}
540+
encode_value(encoder, values, idx, type_, format)?
581541
}
582542
_ => {
583543
return Err(PgWireError::ApiError(ToSqlError::from(format!(
@@ -590,3 +550,56 @@ pub fn encode_value<T: Encoder>(
590550

591551
Ok(())
592552
}
553+
554+
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that records the text rendering of the most recently encoded field.
    #[derive(Default)]
    struct MockEncoder {
        encoded_value: String,
    }

    impl Encoder for MockEncoder {
        fn encode_field_with_type_and_format<T>(
            &mut self,
            value: &T,
            data_type: &Type,
            _format: FieldFormat,
        ) -> PgWireResult<()>
        where
            T: ToSql + ToSqlText + Sized,
        {
            let mut bytes = BytesMut::new();
            // Fail loudly instead of discarding the result: a silent encoding error would
            // leave a stale or empty string behind and make the assertion below misleading.
            value
                .to_sql_text(data_type, &mut bytes)
                .expect("text encoding of the test value should succeed");
            self.encoded_value =
                String::from_utf8(bytes.to_vec()).expect("encoded text should be valid UTF-8");
            Ok(())
        }
    }

    /// Builds an `Int8`-keyed dictionary array of strings behind an `Arc<dyn Array>`.
    fn string_dictionary<const V: usize, const K: usize>(
        values: [&str; V],
        keys: [i8; K],
    ) -> Arc<dyn Array> {
        let values = StringArray::from_iter_values(values);
        let keys = Int8Array::from_iter_values(keys);
        Arc::new(DictionaryArray::<Int8Type>::try_new(keys, Arc::new(values)).unwrap())
    }

    /// A single-value dictionary: every row resolves to the same string.
    #[test]
    fn encodes_dictionary_array() {
        let val = "~!@&$[]()@@!!";
        let dict_arr = string_dictionary([val], [0, 0, 0, 0]);

        let mut encoder = MockEncoder::default();
        encode_value(&mut encoder, &dict_arr, 2, &Type::TEXT, FieldFormat::Text)
            .expect("encoding a dictionary row should succeed");

        assert_eq!(encoder.encoded_value, val);
    }

    /// A multi-value dictionary whose keys differ from the row index: the encoded value
    /// must be looked up through the key, not by indexing the values array with the row.
    #[test]
    fn encodes_dictionary_value_selected_by_key() {
        let dict_arr = string_dictionary(["zero", "one", "two"], [2, 0, 1, 1]);
        let expected_rows = ["two", "zero", "one", "one"];

        for (row, expected) in expected_rows.into_iter().enumerate() {
            let mut encoder = MockEncoder::default();
            encode_value(&mut encoder, &dict_arr, row, &Type::TEXT, FieldFormat::Text)
                .expect("encoding a dictionary row should succeed");

            assert_eq!(encoder.encoded_value, expected, "row {row}");
        }
    }
}
605+

0 commit comments

Comments
 (0)