Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 57 additions & 19 deletions arrow-pg/src/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,38 +513,31 @@ pub fn encode_value<T: Encoder>(
if arr.is_null(idx) {
return encoder.encode_field_with_type_and_format(&None::<i8>, type_, format);
}
// Get the dictionary values, ignoring keys
// We'll use Int32Type as a common key type, but we're only interested in values
macro_rules! get_dict_values {
// Get the dictionary values and the mapped row index
macro_rules! get_dict_values_and_index {
($key_type:ty) => {
arr.as_any()
.downcast_ref::<DictionaryArray<$key_type>>()
.map(|dict| dict.values())
.map(|dict| (dict.values(), dict.keys().value(idx) as usize))
};
}

// Try to extract values using different key types
let values = get_dict_values!(Int8Type)
.or_else(|| get_dict_values!(Int16Type))
.or_else(|| get_dict_values!(Int32Type))
.or_else(|| get_dict_values!(Int64Type))
.or_else(|| get_dict_values!(UInt8Type))
.or_else(|| get_dict_values!(UInt16Type))
.or_else(|| get_dict_values!(UInt32Type))
.or_else(|| get_dict_values!(UInt64Type))
let (values, idx) = get_dict_values_and_index!(Int8Type)
.or_else(|| get_dict_values_and_index!(Int16Type))
.or_else(|| get_dict_values_and_index!(Int32Type))
.or_else(|| get_dict_values_and_index!(Int64Type))
.or_else(|| get_dict_values_and_index!(UInt8Type))
.or_else(|| get_dict_values_and_index!(UInt16Type))
.or_else(|| get_dict_values_and_index!(UInt32Type))
.or_else(|| get_dict_values_and_index!(UInt64Type))
.ok_or_else(|| {
ToSqlError::from(format!(
"Unsupported dictionary key type for value type {value_type}"
))
})?;

// If the dictionary has only one value, treat it as a primitive
if values.len() == 1 {
encode_value(encoder, values, 0, type_, format)?
} else {
// Otherwise, use value directly indexed by values array
encode_value(encoder, values, idx, type_, format)?
}
encode_value(encoder, values, idx, type_, format)?
}
_ => {
return Err(PgWireError::ApiError(ToSqlError::from(format!(
Expand All @@ -557,3 +550,48 @@ pub fn encode_value<T: Encoder>(

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn encodes_dictionary_array() {
#[derive(Default)]
struct MockEncoder {
encoded_value: String,
}

impl Encoder for MockEncoder {
fn encode_field_with_type_and_format<T>(
&mut self,
value: &T,
data_type: &Type,
_format: FieldFormat,
) -> PgWireResult<()>
where
T: ToSql + ToSqlText + Sized,
{
let mut bytes = BytesMut::new();
let _sql_text = value.to_sql_text(data_type, &mut bytes);
let string = String::from_utf8((&bytes).to_vec());
self.encoded_value = string.unwrap();
Ok(())
}
}

let val = "~!@&$[]()@@!!";
let value = StringArray::from_iter_values([val]);
let keys = Int8Array::from_iter_values([0, 0, 0, 0]);
let dict_arr: Arc<dyn Array> =
Arc::new(DictionaryArray::<Int8Type>::try_new(keys, Arc::new(value)).unwrap());

let mut encoder = MockEncoder::default();

let result = encode_value(&mut encoder, &dict_arr, 2, &Type::TEXT, FieldFormat::Text);

assert!(result.is_ok());

assert!(encoder.encoded_value == val);
}
}
Loading