Skip to content

Commit 0de2318

Browse files
committed
refactor dictionary type handling
1 parent 7f966ba commit 0de2318

File tree

1 file changed

+39
-55
lines changed

1 file changed

+39
-55
lines changed

datafusion-postgres/src/datatypes.rs

Lines changed: 39 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -18,12 +18,6 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
1818
use timezone::Tz;
1919

2020
pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
21-
// Handle Dictionary types specially
22-
if let DataType::Dictionary(_, value_type) = df_type {
23-
// For Dictionary types, use the value type for mapping to Postgres types
24-
return into_pg_type(value_type);
25-
}
26-
2721
Ok(match df_type {
2822
DataType::Null => Type::UNKNOWN,
2923
DataType::Boolean => Type::BOOL,
@@ -86,6 +80,7 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
8680
}
8781
}
8882
DataType::Utf8View => Type::TEXT,
83+
DataType::Dictionary(_, value_type) => into_pg_type(value_type)?,
8984
_ => {
9085
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
9186
"ERROR".to_owned(),
@@ -253,59 +248,12 @@ fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<Naive
253248
.value_as_datetime(idx)
254249
}
255250

256-
fn encode_dictionary_value(
257-
encoder: &mut DataRowEncoder,
258-
arr: &Arc<dyn Array>,
259-
idx: usize,
260-
) -> Option<PgWireResult<()>> {
261-
// First, extract the dictionary value type
262-
let value_type = match arr.data_type() {
263-
DataType::Dictionary(_, value_type) => value_type.as_ref(),
264-
_ => return None,
265-
};
266-
267-
// Handle different key types with a macro to reduce repetition
268-
macro_rules! handle_dict_type {
269-
($key_type:ty) => {{
270-
let dict = arr.as_any().downcast_ref::<DictionaryArray<$key_type>>()?;
271-
let key = dict.keys().value(idx) as usize;
272-
// If the dictionary value is out of bounds, return None
273-
if key >= dict.values().len() {
274-
return None;
275-
}
276-
Some(encode_value(encoder, dict.values(), key))
277-
}};
278-
}
279-
280-
// Dispatch based on the key type
281-
match arr.data_type() {
282-
DataType::Dictionary(key_type, _) => {
283-
match key_type.as_ref() {
284-
DataType::Int8 => handle_dict_type!(Int8Type),
285-
DataType::Int16 => handle_dict_type!(Int16Type),
286-
DataType::Int32 => handle_dict_type!(Int32Type),
287-
DataType::Int64 => handle_dict_type!(Int64Type),
288-
DataType::UInt8 => handle_dict_type!(UInt8Type),
289-
DataType::UInt16 => handle_dict_type!(UInt16Type),
290-
DataType::UInt32 => handle_dict_type!(UInt32Type),
291-
DataType::UInt64 => handle_dict_type!(UInt64Type),
292-
_ => None
293-
}
294-
}
295-
_ => None
296-
}
297-
}
298251

299252
fn encode_value(
300253
encoder: &mut DataRowEncoder,
301254
arr: &Arc<dyn Array>,
302255
idx: usize,
303256
) -> PgWireResult<()> {
304-
// Handle dictionary encoding by extracting the actual value from the dictionary
305-
if let Some(result) = encode_dictionary_value(encoder, arr, idx) {
306-
return result;
307-
}
308-
309257
match arr.data_type() {
310258
DataType::Null => encoder.encode_field(&None::<i8>)?,
311259
DataType::Boolean => encoder.encode_field(&get_bool_value(arr, idx))?,
@@ -677,6 +625,42 @@ fn encode_value(
677625
}
678626
}
679627
}
628+
DataType::Dictionary(_, value_type) => {
629+
// Get the dictionary values, ignoring keys
630+
// The concrete key type is not known here, so each supported key type is tried in turn; only the values array is used
631+
macro_rules! get_dict_values {
632+
($key_type:ty) => {
633+
arr.as_any()
634+
.downcast_ref::<DictionaryArray<$key_type>>()
635+
.map(|dict| dict.values())
636+
};
637+
}
638+
639+
// Try to extract values using different key types
640+
let values = get_dict_values!(Int8Type)
641+
.or_else(|| get_dict_values!(Int16Type))
642+
.or_else(|| get_dict_values!(Int32Type))
643+
.or_else(|| get_dict_values!(Int64Type))
644+
.or_else(|| get_dict_values!(UInt8Type))
645+
.or_else(|| get_dict_values!(UInt16Type))
646+
.or_else(|| get_dict_values!(UInt32Type))
647+
.or_else(|| get_dict_values!(UInt64Type))
648+
.ok_or_else(|| {
649+
PgWireError::UserError(Box::new(ErrorInfo::new(
650+
"ERROR".to_owned(),
651+
"XX000".to_owned(),
652+
format!("Unsupported dictionary key type for value type {}", value_type),
653+
)))
654+
})?;
655+
656+
// If the dictionary has only one value, treat it as a primitive
657+
if values.len() == 1 {
658+
encode_value(encoder, values, 0)?
659+
} else {
660+
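// Otherwise, index the values array directly with the row index `idx` (NOTE(review): the dictionary keys are never consulted here — confirm this is intended)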
661+
encode_value(encoder, values, idx)?
662+
}
663+
}
680664
_ => {
681665
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
682666
"ERROR".to_owned(),
@@ -706,10 +690,10 @@ pub(crate) fn df_schema_to_pg_fields(
706690
DataType::Dictionary(_, value_type) => value_type.as_ref(),
707691
other_type => other_type,
708692
};
709-
693+
710694
// Convert to PostgreSQL type using the unwrapped type
711695
let pg_type = into_pg_type(data_type)?;
712-
696+
713697
Ok(FieldInfo::new(
714698
f.name().into(),
715699
None,

0 commit comments

Comments
 (0)