diff --git a/arrow-pg/src/encoder.rs b/arrow-pg/src/encoder.rs index 2de82f4..5490e1f 100644 --- a/arrow-pg/src/encoder.rs +++ b/arrow-pg/src/encoder.rs @@ -513,38 +513,31 @@ pub fn encode_value( if arr.is_null(idx) { return encoder.encode_field_with_type_and_format(&None::, type_, format); } - // Get the dictionary values, ignoring keys - // We'll use Int32Type as a common key type, but we're only interested in values - macro_rules! get_dict_values { + // Get the dictionary values and the mapped row index + macro_rules! get_dict_values_and_index { ($key_type:ty) => { arr.as_any() .downcast_ref::>() - .map(|dict| dict.values()) + .map(|dict| (dict.values(), dict.keys().value(idx) as usize)) }; } // Try to extract values using different key types - let values = get_dict_values!(Int8Type) - .or_else(|| get_dict_values!(Int16Type)) - .or_else(|| get_dict_values!(Int32Type)) - .or_else(|| get_dict_values!(Int64Type)) - .or_else(|| get_dict_values!(UInt8Type)) - .or_else(|| get_dict_values!(UInt16Type)) - .or_else(|| get_dict_values!(UInt32Type)) - .or_else(|| get_dict_values!(UInt64Type)) + let (values, idx) = get_dict_values_and_index!(Int8Type) + .or_else(|| get_dict_values_and_index!(Int16Type)) + .or_else(|| get_dict_values_and_index!(Int32Type)) + .or_else(|| get_dict_values_and_index!(Int64Type)) + .or_else(|| get_dict_values_and_index!(UInt8Type)) + .or_else(|| get_dict_values_and_index!(UInt16Type)) + .or_else(|| get_dict_values_and_index!(UInt32Type)) + .or_else(|| get_dict_values_and_index!(UInt64Type)) .ok_or_else(|| { ToSqlError::from(format!( "Unsupported dictionary key type for value type {value_type}" )) })?; - // If the dictionary has only one value, treat it as a primitive - if values.len() == 1 { - encode_value(encoder, values, 0, type_, format)? - } else { - // Otherwise, use value directly indexed by values array - encode_value(encoder, values, idx, type_, format)? - } + encode_value(encoder, values, idx, type_, format)? } _ => { return Err(PgWireError::ApiError(ToSqlError::from(format!( @@ -557,3 +550,48 @@ pub fn encode_value( Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn encodes_dictionary_array() { + #[derive(Default)] + struct MockEncoder { + encoded_value: String, + } + + impl Encoder for MockEncoder { + fn encode_field_with_type_and_format( + &mut self, + value: &T, + data_type: &Type, + _format: FieldFormat, + ) -> PgWireResult<()> + where + T: ToSql + ToSqlText + Sized, + { + let mut bytes = BytesMut::new(); + let _sql_text = value.to_sql_text(data_type, &mut bytes); + let string = String::from_utf8((&bytes).to_vec()); + self.encoded_value = string.unwrap(); + Ok(()) + } + } + + let val = "~!@&$[]()@@!!"; + let value = StringArray::from_iter_values([val]); + let keys = Int8Array::from_iter_values([0, 0, 0, 0]); + let dict_arr: Arc = + Arc::new(DictionaryArray::::try_new(keys, Arc::new(value)).unwrap()); + + let mut encoder = MockEncoder::default(); + + let result = encode_value(&mut encoder, &dict_arr, 2, &Type::TEXT, FieldFormat::Text); + + assert!(result.is_ok()); + + assert!(encoder.encoded_value == val); + } +}