Skip to content

Commit 7f966ba

Browse files
committed
support dictionary keyed by integers via a macro
1 parent b3a441c commit 7f966ba

File tree

1 file changed

+43
-43
lines changed

1 file changed

+43
-43
lines changed

datafusion-postgres/src/datatypes.rs

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -258,50 +258,37 @@ fn encode_dictionary_value(
258258
arr: &Arc<dyn Array>,
259259
idx: usize,
260260
) -> Option<PgWireResult<()>> {
261-
// Use pattern matching to handle dictionary arrays with different key types
261+
// First, extract the dictionary value type
262+
let value_type = match arr.data_type() {
263+
DataType::Dictionary(_, value_type) => value_type.as_ref(),
264+
_ => return None,
265+
};
266+
267+
// Handle different key types with a macro to reduce repetition
268+
macro_rules! handle_dict_type {
269+
($key_type:ty) => {{
270+
let dict = arr.as_any().downcast_ref::<DictionaryArray<$key_type>>()?;
271+
let key = dict.keys().value(idx) as usize;
272+
// If the dictionary value is out of bounds, return None
273+
if key >= dict.values().len() {
274+
return None;
275+
}
276+
Some(encode_value(encoder, dict.values(), key))
277+
}};
278+
}
279+
280+
// Dispatch based on the key type
262281
match arr.data_type() {
263282
DataType::Dictionary(key_type, _) => {
264283
match key_type.as_ref() {
265-
DataType::Int8 => {
266-
let dict = arr.as_any().downcast_ref::<DictionaryArray<Int8Type>>()?;
267-
let key = dict.keys().value(idx) as usize;
268-
Some(encode_value(encoder, dict.values(), key))
269-
}
270-
DataType::Int16 => {
271-
let dict = arr.as_any().downcast_ref::<DictionaryArray<Int16Type>>()?;
272-
let key = dict.keys().value(idx) as usize;
273-
Some(encode_value(encoder, dict.values(), key))
274-
}
275-
DataType::Int32 => {
276-
let dict = arr.as_any().downcast_ref::<DictionaryArray<Int32Type>>()?;
277-
let key = dict.keys().value(idx) as usize;
278-
Some(encode_value(encoder, dict.values(), key))
279-
}
280-
DataType::Int64 => {
281-
let dict = arr.as_any().downcast_ref::<DictionaryArray<Int64Type>>()?;
282-
let key = dict.keys().value(idx) as usize;
283-
Some(encode_value(encoder, dict.values(), key))
284-
}
285-
DataType::UInt8 => {
286-
let dict = arr.as_any().downcast_ref::<DictionaryArray<UInt8Type>>()?;
287-
let key = dict.keys().value(idx) as usize;
288-
Some(encode_value(encoder, dict.values(), key))
289-
}
290-
DataType::UInt16 => {
291-
let dict = arr.as_any().downcast_ref::<DictionaryArray<UInt16Type>>()?;
292-
let key = dict.keys().value(idx) as usize;
293-
Some(encode_value(encoder, dict.values(), key))
294-
}
295-
DataType::UInt32 => {
296-
let dict = arr.as_any().downcast_ref::<DictionaryArray<UInt32Type>>()?;
297-
let key = dict.keys().value(idx) as usize;
298-
Some(encode_value(encoder, dict.values(), key))
299-
}
300-
DataType::UInt64 => {
301-
let dict = arr.as_any().downcast_ref::<DictionaryArray<UInt64Type>>()?;
302-
let key = dict.keys().value(idx) as usize;
303-
Some(encode_value(encoder, dict.values(), key))
304-
}
284+
DataType::Int8 => handle_dict_type!(Int8Type),
285+
DataType::Int16 => handle_dict_type!(Int16Type),
286+
DataType::Int32 => handle_dict_type!(Int32Type),
287+
DataType::Int64 => handle_dict_type!(Int64Type),
288+
DataType::UInt8 => handle_dict_type!(UInt8Type),
289+
DataType::UInt16 => handle_dict_type!(UInt16Type),
290+
DataType::UInt32 => handle_dict_type!(UInt32Type),
291+
DataType::UInt64 => handle_dict_type!(UInt64Type),
305292
_ => None
306293
}
307294
}
@@ -714,7 +701,15 @@ pub(crate) fn df_schema_to_pg_fields(
714701
.iter()
715702
.enumerate()
716703
.map(|(idx, f)| {
717-
let pg_type = into_pg_type(f.data_type())?;
704+
// Get the actual data type, unwrapping any dictionary type
705+
let data_type = match f.data_type() {
706+
DataType::Dictionary(_, value_type) => value_type.as_ref(),
707+
other_type => other_type,
708+
};
709+
710+
// Convert to PostgreSQL type using the unwrapped type
711+
let pg_type = into_pg_type(data_type)?;
712+
718713
Ok(FieldInfo::new(
719714
f.name().into(),
720715
None,
@@ -792,7 +787,12 @@ where
792787
if let Some(ty) = pg_type_hint {
793788
Ok(ty.clone())
794789
} else if let Some(infer_type) = inferenced_type {
795-
into_pg_type(infer_type)
790+
// If inferenced type is a dictionary, use the value type
791+
let actual_type = match infer_type {
792+
DataType::Dictionary(_, value_type) => value_type.as_ref(),
793+
other_type => other_type,
794+
};
795+
into_pg_type(actual_type)
796796
} else {
797797
Err(PgWireError::UserError(Box::new(ErrorInfo::new(
798798
"FATAL".to_string(),

0 commit comments

Comments
 (0)