Skip to content
89 changes: 85 additions & 4 deletions datafusion-postgres/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use timezone::Tz;

pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
// Handle Dictionary types specially
if let DataType::Dictionary(_, value_type) = df_type {
// For Dictionary types, use the value type for mapping to Postgres types
return into_pg_type(value_type);
}

Ok(match df_type {
DataType::Null => Type::UNKNOWN,
DataType::Boolean => Type::BOOL,
Expand All @@ -41,7 +47,15 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
DataType::Utf8 => Type::VARCHAR,
DataType::LargeUtf8 => Type::TEXT,
DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
match field.data_type() {
let field_type = field.data_type();

// Handle dictionary types in lists
let actual_type = match field_type {
DataType::Dictionary(_, value_type) => value_type.as_ref(),
_ => field_type,
};

match actual_type {
DataType::Boolean => Type::BOOL_ARRAY,
DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
Expand Down Expand Up @@ -239,11 +253,59 @@ fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<Naive
.value_as_datetime(idx)
}

fn encode_dictionary_value(
encoder: &mut DataRowEncoder,
arr: &Arc<dyn Array>,
idx: usize,
) -> Option<PgWireResult<()>> {
// First, extract the dictionary value type
let value_type = match arr.data_type() {
DataType::Dictionary(_, value_type) => value_type.as_ref(),
_ => return None,
};

// Handle different key types with a macro to reduce repetition
macro_rules! handle_dict_type {
($key_type:ty) => {{
let dict = arr.as_any().downcast_ref::<DictionaryArray<$key_type>>()?;
let key = dict.keys().value(idx) as usize;
// If the dictionary value is out of bounds, return None
if key >= dict.values().len() {
return None;
}
Some(encode_value(encoder, dict.values(), key))
}};
}

// Dispatch based on the key type
match arr.data_type() {
DataType::Dictionary(key_type, _) => {
match key_type.as_ref() {
DataType::Int8 => handle_dict_type!(Int8Type),
DataType::Int16 => handle_dict_type!(Int16Type),
DataType::Int32 => handle_dict_type!(Int32Type),
DataType::Int64 => handle_dict_type!(Int64Type),
DataType::UInt8 => handle_dict_type!(UInt8Type),
DataType::UInt16 => handle_dict_type!(UInt16Type),
DataType::UInt32 => handle_dict_type!(UInt32Type),
DataType::UInt64 => handle_dict_type!(UInt64Type),
_ => None
}
}
_ => None
}
}

fn encode_value(
encoder: &mut DataRowEncoder,
arr: &Arc<dyn Array>,
idx: usize,
) -> PgWireResult<()> {
// Handle dictionary encoding by extracting the actual value from the dictionary
if let Some(result) = encode_dictionary_value(encoder, arr, idx) {
return result;
}

match arr.data_type() {
DataType::Null => encoder.encode_field(&None::<i8>)?,
DataType::Boolean => encoder.encode_field(&get_bool_value(arr, idx))?,
Expand Down Expand Up @@ -347,7 +409,13 @@ fn encode_value(
},

DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
match field.data_type() {
// Extract the inner type, handling dictionaries by getting the value type
let field_type = match field.data_type() {
DataType::Dictionary(_, value_type) => value_type.as_ref(),
data_type => data_type,
};

match field_type {
DataType::Null => encoder.encode_field(&None::<i8>)?,
DataType::Boolean => encoder.encode_field(&get_bool_list_value(arr, idx))?,
DataType::Int8 => encoder.encode_field(&get_i8_list_value(arr, idx))?,
Expand Down Expand Up @@ -633,7 +701,15 @@ pub(crate) fn df_schema_to_pg_fields(
.iter()
.enumerate()
.map(|(idx, f)| {
let pg_type = into_pg_type(f.data_type())?;
// Get the actual data type, unwrapping any dictionary type
let data_type = match f.data_type() {
DataType::Dictionary(_, value_type) => value_type.as_ref(),
other_type => other_type,
};

// Convert to PostgreSQL type using the unwrapped type
let pg_type = into_pg_type(data_type)?;

Ok(FieldInfo::new(
f.name().into(),
None,
Expand Down Expand Up @@ -711,7 +787,12 @@ where
if let Some(ty) = pg_type_hint {
Ok(ty.clone())
} else if let Some(infer_type) = inferenced_type {
into_pg_type(infer_type)
// If inferenced type is a dictionary, use the value type
let actual_type = match infer_type {
DataType::Dictionary(_, value_type) => value_type.as_ref(),
other_type => other_type,
};
into_pg_type(actual_type)
} else {
Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"FATAL".to_string(),
Expand Down
Loading