Skip to content

Commit 86c5d2a

Browse files
committed
fix(cubesql): properly convert type from arrow to dataframe
1 parent 0604df1 commit 86c5d2a

File tree

1 file changed

+238
-17
lines changed

1 file changed

+238
-17
lines changed

rust/cubesql/cubesql/src/sql/dataframe.rs

Lines changed: 238 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -409,14 +409,15 @@ pub fn arrow_to_column_type(arrow_type: DataType) -> Result<ColumnType, CubeErro
409409
DataType::Float16 | DataType::Float32 | DataType::Float64 => Ok(ColumnType::Double),
410410
DataType::Boolean => Ok(ColumnType::Boolean),
411411
DataType::List(field) => Ok(ColumnType::List(field)),
412-
DataType::Int32 | DataType::UInt32 => Ok(ColumnType::Int32),
413412
DataType::Decimal(_, _) => Ok(ColumnType::Int32),
414-
DataType::Int8
415-
| DataType::Int16
416-
| DataType::Int64
417-
| DataType::UInt8
413+
DataType::Int8 //we are missing TableValue::Int8 type to use ColumnType::Int8
414+
| DataType::UInt8 //we are missing ColumnType::Int16 type
415+
| DataType::Int16 //we are missing ColumnType::Int16 type
418416
| DataType::UInt16
419-
| DataType::UInt64 => Ok(ColumnType::Int64),
417+
| DataType::Int32 => Ok(ColumnType::Int32),
418+
DataType::UInt32
419+
| DataType::Int64 => Ok(ColumnType::Int64),
420+
DataType::UInt64 => Ok(ColumnType::Decimal(128, 0)),
420421
DataType::Null => Ok(ColumnType::String),
421422
x => Err(CubeError::internal(format!("unsupported type {:?}", x))),
422423
}
@@ -452,12 +453,23 @@ pub fn batches_to_dataframe(
452453
let array = batch.column(column_index);
453454
let num_rows = batch.num_rows();
454455
match array.data_type() {
455-
DataType::UInt16 => convert_array!(array, num_rows, rows, UInt16Array, Int16, i16),
456+
DataType::Int8 => convert_array!(array, num_rows, rows, Int8Array, Int16, i16),
457+
DataType::UInt8 => convert_array!(array, num_rows, rows, UInt8Array, Int16, i16),
456458
DataType::Int16 => convert_array!(array, num_rows, rows, Int16Array, Int16, i16),
457-
DataType::UInt32 => convert_array!(array, num_rows, rows, UInt32Array, Int32, i32),
459+
DataType::UInt16 => convert_array!(array, num_rows, rows, UInt16Array, Int32, i32),
458460
DataType::Int32 => convert_array!(array, num_rows, rows, Int32Array, Int32, i32),
459-
DataType::UInt64 => convert_array!(array, num_rows, rows, UInt64Array, Int64, i64),
461+
DataType::UInt32 => convert_array!(array, num_rows, rows, UInt32Array, Int64, i64),
460462
DataType::Int64 => convert_array!(array, num_rows, rows, Int64Array, Int64, i64),
463+
DataType::UInt64 => {
464+
let a = array.as_any().downcast_ref::<UInt64Array>().unwrap();
465+
for i in 0..num_rows {
466+
rows[i].push(if a.is_null(i) {
467+
TableValue::Null
468+
} else {
469+
TableValue::Decimal128(Decimal128Value::new(a.value(i) as i128, 0))
470+
});
471+
}
472+
}
461473
DataType::Boolean => {
462474
convert_array!(array, num_rows, rows, BooleanArray, Boolean, bool)
463475
}
@@ -685,7 +697,17 @@ pub fn batches_to_dataframe(
685697

686698
#[cfg(test)]
687699
mod tests {
700+
use std::sync::Arc;
701+
702+
use datafusion::arrow::{array::PrimitiveArray, datatypes::TimestampMicrosecondType};
703+
use datafusion::scalar::ScalarValue::TimestampMillisecond;
704+
use itertools::Itertools;
705+
688706
use super::*;
707+
use crate::compile::arrow::{
708+
datatypes::{ArrowPrimitiveType, Field, TimestampNanosecondType},
709+
record_batch::RecordBatchOptions,
710+
};
689711

690712
#[test]
691713
fn test_dataframe_print() {
@@ -803,20 +825,26 @@ mod tests {
803825
(DataType::LargeUtf8, ColumnType::String),
804826
(DataType::Date32, ColumnType::Date(false)),
805827
(DataType::Date64, ColumnType::Date(true)),
806-
(DataType::Timestamp(TimeUnit::Second, None), ColumnType::Timestamp),
807-
(DataType::Interval(IntervalUnit::YearMonth), ColumnType::Interval(IntervalUnit::YearMonth)),
828+
(
829+
DataType::Timestamp(TimeUnit::Second, None),
830+
ColumnType::Timestamp,
831+
),
832+
(
833+
DataType::Interval(IntervalUnit::YearMonth),
834+
ColumnType::Interval(IntervalUnit::YearMonth),
835+
),
808836
(DataType::Float16, ColumnType::Double),
809837
(DataType::Float32, ColumnType::Double),
810838
(DataType::Float64, ColumnType::Double),
811839
(DataType::Boolean, ColumnType::Boolean),
812840
(DataType::Int32, ColumnType::Int32),
813-
(DataType::UInt32, ColumnType::Int32),
814-
(DataType::Int8, ColumnType::Int64),
815-
(DataType::Int16, ColumnType::Int64),
841+
(DataType::UInt32, ColumnType::Int64),
842+
(DataType::Int8, ColumnType::Int32),
843+
(DataType::Int16, ColumnType::Int32),
816844
(DataType::Int64, ColumnType::Int64),
817-
(DataType::UInt8, ColumnType::Int64),
818-
(DataType::UInt16, ColumnType::Int64),
819-
(DataType::UInt64, ColumnType::Int64),
845+
(DataType::UInt8, ColumnType::Int32),
846+
(DataType::UInt16, ColumnType::Int32),
847+
(DataType::UInt64, ColumnType::Decimal(128, 0)),
820848
(DataType::Null, ColumnType::String),
821849
];
822850

@@ -825,4 +853,197 @@ mod tests {
825853
assert_eq!(result, expected_column_type, "Failed for {:?}", arrow_type);
826854
}
827855
}
856+
857+
/// Test helper: wraps `value` in a single-column record batch, converts it with
/// `batches_to_dataframe`, and asserts on the resulting data frame.
///
/// Despite the name, nothing is returned to the caller except success/failure;
/// the batch is built and checked entirely inside this function.
///
/// * `data_type` - Arrow type declared for the single field, which is named `"data"`.
/// * `value` - the primitive array used as the column's contents.
/// * `expected_data_type` - column type the data frame must report for that column.
/// * `expected_data` - values the data frame must contain, in order.
///
/// Returns `Err` when building the batch or converting it fails; any mismatch
/// panics through the `assert_eq!` calls.
fn create_record_batch<T: ArrowPrimitiveType>(
858+
data_type: DataType,
859+
value: PrimitiveArray<T>,
860+
expected_data_type: ColumnType,
861+
expected_data: Vec<TableValue>,
862+
) -> Result<(), CubeError> {
863+
// One field called "data"; the `false` flag declares it non-nullable.
let batch = RecordBatch::try_new_with_options(
864+
Arc::new(Schema::new(vec![Field::new("data", data_type, false)])),
865+
vec![Arc::new(value)],
866+
&RecordBatchOptions::default(),
867+
)
868+
.map_err(|e| CubeError::from(e))?;
869+
870+
let df = batches_to_dataframe(&batch.schema(), vec![batch.clone()])?;
871+
let colums = df.get_columns().clone();
872+
let data = df.data;
873+
// The batch had exactly one field, so exactly one column is expected back.
assert_eq!(
874+
colums.len(),
875+
1,
876+
"Expecting one column in DF, but: {:?}",
877+
colums
878+
);
879+
assert_eq!(expected_data_type, colums.get(0).unwrap().column_type);
880+
// NOTE(review): the message says "columns", but `data` looks like the list of
// rows (each entry exposes `.values` below) - confirm and reword if so.
assert_eq!(
881+
data.len(),
882+
expected_data.len(),
883+
"Expecting {} columns in DF data, but: {:?}",
884+
expected_data.len(),
885+
data
886+
);
887+
// Flatten every entry's values into one list so it can be compared directly
// with `expected_data`.
let vec1 = data.into_iter().map(|r| r.values).flatten().collect_vec();
888+
assert_eq!(
889+
vec1.len(),
890+
expected_data.len(),
891+
"Data len {} != {}",
892+
vec1.len(),
893+
expected_data.len()
894+
);
895+
assert_eq!(vec1, expected_data);
896+
Ok(())
897+
}
898+
899+
/// Checks that nanosecond, microsecond and millisecond timestamp arrays all
/// convert to `ColumnType::Timestamp`.
///
/// The three inputs encode the same instant in their respective units. The
/// expected `TimestampValue` is the input multiplied by 1, 1000 and 1000000, so
/// every case expects the same number (presumably nanoseconds - confirm against
/// `TimestampValue`).
#[test]
900+
fn test_timestamp_conversion() -> Result<(), CubeError> {
901+
// Nanosecond input: expected value is the raw number, unscaled.
let data_nano = vec![Some(1640995200000000000)];
902+
create_record_batch(
903+
DataType::Timestamp(TimeUnit::Nanosecond, None),
904+
TimestampNanosecondArray::from(data_nano.clone()),
905+
ColumnType::Timestamp,
906+
data_nano
907+
.into_iter()
908+
.map(|e| TableValue::Timestamp(TimestampValue::new(e.unwrap(), None)))
909+
.collect::<Vec<_>>(),
910+
)?;
911+
912+
// Microsecond input: expected value is scaled by 1000.
let data_micro = vec![Some(1640995200000000)];
913+
create_record_batch(
914+
DataType::Timestamp(TimeUnit::Microsecond, None),
915+
TimestampMicrosecondArray::from(data_micro.clone()),
916+
ColumnType::Timestamp,
917+
data_micro
918+
.into_iter()
919+
.map(|e| TableValue::Timestamp(TimestampValue::new(e.unwrap() * 1000, None)))
920+
.collect::<Vec<_>>(),
921+
)?;
922+
923+
// Millisecond input: expected value is scaled by 1000000.
let data_milli = vec![Some(1640995200000)];
924+
create_record_batch(
925+
DataType::Timestamp(TimeUnit::Millisecond, None),
926+
TimestampMillisecondArray::from(data_milli.clone()),
927+
ColumnType::Timestamp,
928+
data_milli
929+
.into_iter()
930+
.map(|e| TableValue::Timestamp(TimestampValue::new(e.unwrap() * 1000000, None)))
931+
.collect::<Vec<_>>(),
932+
)
933+
}
934+
935+
/// Checks conversion of signed integer arrays (`Int8`, `Int16`, `Int32`, `Int64`).
///
/// Each case includes the type's `MIN` and `MAX` alongside small values around
/// zero. `Int8` and `Int16` data is expected as `TableValue::Int16` under an
/// `Int32` column type; `Int32` and `Int64` keep their own width for both the
/// value and the column type.
#[test]
936+
fn test_signed_conversion() -> Result<(), CubeError> {
937+
let data8 = vec![i8::MIN, -1, 0, 1, 2, 3, 4, i8::MAX];
938+
create_record_batch(
939+
DataType::Int8,
940+
Int8Array::from(data8.clone()),
941+
ColumnType::Int32, // TableValue::Int8 is missing, so values are expected as Int16 while the column is reported as Int32
942+
data8
943+
.into_iter()
944+
.map(|e| TableValue::Int16(e as i16))
945+
.collect::<Vec<_>>(),
946+
)?;
947+
948+
let data16 = vec![i16::MIN, -1, 0, 1, 2, 3, 4, i16::MAX];
949+
create_record_batch(
950+
DataType::Int16,
951+
Int16Array::from(data16.clone()),
952+
ColumnType::Int32, // ColumnType::Int16 is missing, so the column is reported as Int32
953+
data16
954+
.into_iter()
955+
.map(|e| TableValue::Int16(e))
956+
.collect::<Vec<_>>(),
957+
)?;
958+
959+
let data32 = vec![i32::MIN, -1, 0, 1, 2, 3, 4, i32::MAX];
960+
create_record_batch(
961+
DataType::Int32,
962+
Int32Array::from(data32.clone()),
963+
ColumnType::Int32,
964+
data32
965+
.into_iter()
966+
.map(|e| TableValue::Int32(e))
967+
.collect::<Vec<_>>(),
968+
)?;
969+
970+
let data64 = vec![i64::MIN, -1, 0, 1, 2, 3, 4, i64::MAX];
971+
create_record_batch(
972+
DataType::Int64,
973+
Int64Array::from(data64.clone()),
974+
ColumnType::Int64,
975+
data64
976+
.into_iter()
977+
.map(|e| TableValue::Int64(e))
978+
.collect::<Vec<_>>(),
979+
)
980+
}
981+
982+
/// Checks conversion of unsigned integer arrays (`UInt8`, `UInt16`, `UInt32`,
/// `UInt64`).
///
/// Each case includes the type's `MAX`. The expected value is always held in a
/// wider signed representation than the input: `UInt8` as `Int16`, `UInt16` as
/// `Int32`, `UInt32` as `Int64`, and `UInt64` as `Decimal128` under a
/// `ColumnType::Decimal(128, 0)` column.
#[test]
983+
fn test_unsigned_conversion() -> Result<(), CubeError> {
984+
let data8 = vec![0, 1, 2, 3, 4, u8::MAX];
985+
create_record_batch(
986+
DataType::UInt8,
987+
UInt8Array::from(data8.clone()),
988+
ColumnType::Int32, // ColumnType::Int16 is missing, so the column is reported as Int32
989+
data8
990+
.into_iter()
991+
.map(|e| TableValue::Int16(e as i16))
992+
.collect::<Vec<_>>(),
993+
)?;
994+
995+
let data16 = vec![0, 1, 2, 3, 4, u16::MAX];
996+
create_record_batch(
997+
DataType::UInt16,
998+
UInt16Array::from(data16.clone()),
999+
ColumnType::Int32,
1000+
data16
1001+
.into_iter()
1002+
.map(|e| TableValue::Int32(e as i32))
1003+
.collect::<Vec<_>>(),
1004+
)?;
1005+
1006+
let data32 = vec![0, 1, 2, 3, 4, u32::MAX];
1007+
create_record_batch(
1008+
DataType::UInt32,
1009+
UInt32Array::from(data32.clone()),
1010+
ColumnType::Int64,
1011+
data32
1012+
.into_iter()
1013+
.map(|e| TableValue::Int64(e as i64))
1014+
.collect::<Vec<_>>(),
1015+
)?;
1016+
1017+
// u64::MAX does not fit in i64, hence the i128-backed decimal; the trailing 0
// is presumably the scale - confirm against Decimal128Value::new.
let data64 = vec![0, 1, 2, 3, 4, u64::MAX];
1018+
create_record_batch(
1019+
DataType::UInt64,
1020+
UInt64Array::from(data64.clone()),
1021+
ColumnType::Decimal(128, 0),
1022+
data64
1023+
.into_iter()
1024+
.map(|e| TableValue::Decimal128(Decimal128Value::new(e as i128, 0)))
1025+
.collect::<Vec<_>>(),
1026+
)
1027+
}
1028+
1029+
/// Equality for `TableValue`, so the assertions in these tests can compare
/// converted values with `assert_eq!`.
///
/// Two values are equal when they are the same variant and their payloads are
/// equal; values of different variants are never equal.
///
/// # Panics
///
/// Comparing two `List` values or two `Interval` values panics with
/// `"unsupported"` instead of returning a result.
impl PartialEq for TableValue {
1030+
fn eq(&self, other: &Self) -> bool {
1031+
match (self, other) {
1032+
(TableValue::Null, TableValue::Null) => true,
1033+
(TableValue::String(a), TableValue::String(b)) => a == b,
1034+
(TableValue::Int16(a), TableValue::Int16(b)) => a == b,
1035+
(TableValue::Int32(a), TableValue::Int32(b)) => a == b,
1036+
(TableValue::Int64(a), TableValue::Int64(b)) => a == b,
1037+
(TableValue::Boolean(a), TableValue::Boolean(b)) => a == b,
1038+
(TableValue::Float32(a), TableValue::Float32(b)) => a == b,
1039+
(TableValue::Float64(a), TableValue::Float64(b)) => a == b,
1040+
// List comparison is not implemented; fail loudly rather than report a result.
(TableValue::List(_), TableValue::List(_)) => panic!("unsupported"),
1041+
(TableValue::Decimal128(a), TableValue::Decimal128(b)) => a == b,
1042+
(TableValue::Date(a), TableValue::Date(b)) => a == b,
1043+
(TableValue::Timestamp(a), TableValue::Timestamp(b)) => a == b,
1044+
// Interval comparison is not implemented either.
(TableValue::Interval(_), TableValue::Interval(_)) => panic!("unsupported"),
1045+
// Any pairing of different variants is unequal.
_ => false,
1046+
}
1047+
}
1048+
}
8281049
}

0 commit comments

Comments
 (0)