Skip to content

Commit 9c42bcf

Browse files
committed
added dict bug fix
1 parent d1b0db0 commit 9c42bcf

File tree

1 file changed

+57
-19
lines changed

1 file changed

+57
-19
lines changed

arrow-pg/src/encoder.rs

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -513,38 +513,31 @@ pub fn encode_value<T: Encoder>(
513513
if arr.is_null(idx) {
514514
return encoder.encode_field_with_type_and_format(&None::<i8>, type_, format);
515515
}
516-
// Get the dictionary values, ignoring keys
517-
// We'll use Int32Type as a common key type, but we're only interested in values
518-
macro_rules! get_dict_values {
516+
// Get the dictionary values and the mapped row index
517+
macro_rules! get_dict_values_and_index {
519518
($key_type:ty) => {
520519
arr.as_any()
521520
.downcast_ref::<DictionaryArray<$key_type>>()
522-
.map(|dict| dict.values())
521+
.map(|dict| (dict.values(), dict.keys().value(idx) as usize))
523522
};
524523
}
525524

526525
// Try to extract values using different key types
527-
let values = get_dict_values!(Int8Type)
528-
.or_else(|| get_dict_values!(Int16Type))
529-
.or_else(|| get_dict_values!(Int32Type))
530-
.or_else(|| get_dict_values!(Int64Type))
531-
.or_else(|| get_dict_values!(UInt8Type))
532-
.or_else(|| get_dict_values!(UInt16Type))
533-
.or_else(|| get_dict_values!(UInt32Type))
534-
.or_else(|| get_dict_values!(UInt64Type))
526+
let (values, idx) = get_dict_values_and_index!(Int8Type)
527+
.or_else(|| get_dict_values_and_index!(Int16Type))
528+
.or_else(|| get_dict_values_and_index!(Int32Type))
529+
.or_else(|| get_dict_values_and_index!(Int64Type))
530+
.or_else(|| get_dict_values_and_index!(UInt8Type))
531+
.or_else(|| get_dict_values_and_index!(UInt16Type))
532+
.or_else(|| get_dict_values_and_index!(UInt32Type))
533+
.or_else(|| get_dict_values_and_index!(UInt64Type))
535534
.ok_or_else(|| {
536535
ToSqlError::from(format!(
537536
"Unsupported dictionary key type for value type {value_type}"
538537
))
539538
})?;
540539

541-
// If the dictionary has only one value, treat it as a primitive
542-
if values.len() == 1 {
543-
encode_value(encoder, values, 0, type_, format)?
544-
} else {
545-
// Otherwise, use value directly indexed by values array
546-
encode_value(encoder, values, idx, type_, format)?
547-
}
540+
encode_value(encoder, values, idx, type_, format)?
548541
}
549542
_ => {
550543
return Err(PgWireError::ApiError(ToSqlError::from(format!(
@@ -557,3 +550,48 @@ pub fn encode_value<T: Encoder>(
557550

558551
Ok(())
559552
}
553+
554+
#[cfg(test)]
555+
mod tests {
556+
use super::*;
557+
558+
#[test]
559+
fn encodes_dictionary_array() {
560+
#[derive(Default)]
561+
struct MockEncoder {
562+
encoded_value: String,
563+
}
564+
565+
impl Encoder for MockEncoder {
566+
fn encode_field_with_type_and_format<T>(
567+
&mut self,
568+
value: &T,
569+
data_type: &Type,
570+
_format: FieldFormat,
571+
) -> PgWireResult<()>
572+
where
573+
T: ToSql + ToSqlText + Sized,
574+
{
575+
let mut bytes = BytesMut::new();
576+
let _sql_text = value.to_sql_text(data_type, &mut bytes);
577+
let string = String::from_utf8((&bytes).to_vec());
578+
self.encoded_value = string.unwrap();
579+
Ok(())
580+
}
581+
}
582+
583+
let val = "~!@&$[]()@@!!";
584+
let value = StringArray::from_iter_values([val]);
585+
let keys = Int8Array::from_iter_values([0, 0, 0, 0]);
586+
let dict_arr: Arc<dyn Array> =
587+
Arc::new(DictionaryArray::<Int8Type>::try_new(keys, Arc::new(value)).unwrap());
588+
589+
let mut encoder = MockEncoder::default();
590+
591+
let result = encode_value(&mut encoder, &dict_arr, 2, &Type::TEXT, FieldFormat::Text);
592+
593+
assert!(result.is_ok());
594+
595+
assert!(encoder.encoded_value == val);
596+
}
597+
}

0 commit comments

Comments
 (0)